Mercurial > repos > goeckslab > tabular_learner
comparison base_model_trainer.py @ 0:209b663a4f62 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
| author | goeckslab |
|---|---|
| date | Wed, 18 Jun 2025 15:38:19 +0000 |
| parents | |
| children | 77c88226bfde |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:209b663a4f62 |
|---|---|
| 1 import base64 | |
| 2 import logging | |
| 3 import os | |
| 4 import tempfile | |
| 5 | |
| 6 import h5py | |
| 7 import joblib | |
| 8 import numpy as np | |
| 9 import pandas as pd | |
| 10 from feature_importance import FeatureImportanceAnalyzer | |
| 11 from sklearn.metrics import average_precision_score | |
| 12 from utils import get_html_closing, get_html_template | |
| 13 | |
| 14 logging.basicConfig(level=logging.DEBUG) | |
| 15 LOG = logging.getLogger(__name__) | |
| 16 | |
| 17 | |
class BaseModelTrainer:
    """Shared scaffolding for PyCaret-based model training.

    Subclasses are expected to set ``self.exp`` to a PyCaret experiment
    object and to implement the plotting hooks (``generate_plots``,
    ``generate_plots_explainer``, ``save_dashboard``).
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Store configuration; extra kwargs become instance attributes.

        :param input_file: path to the training CSV/TSV file
        :param target_col: 1-based index of the target column (see load_data)
        :param output_dir: directory receiving all report artifacts
        :param task_type: "classification" or "regression"
        :param random_seed: forwarded to PyCaret as ``session_id``
        :param test_file: optional path to a separate hold-out data file
        :param kwargs: arbitrary options (e.g. ``normalize``, ``models``,
            ``missing_value_strategy``) set verbatim via ``setattr``
        """
        self.exp = None  # This will be set in the subclass
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None  # resolved target column name (set in load_data)
        self.best_model = None
        self.results = None
        self.features_name = None
        self.plots = {}
        # NOTE(review): "expaliner" looks like a typo for "explainer";
        # kept as-is in case subclasses reference this exact name — confirm.
        self.expaliner = None
        self.plots_explainer_html = None
        self.trees = []
        # kwargs are applied BEFORE the fixed assignments below, so kwargs
        # named setup_params/test_file/test_data would be silently clobbered.
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        LOG.info(f"Model kwargs: {self.__dict__}")
| 52 def load_data(self): | |
| 53 LOG.info(f"Loading data from {self.input_file}") | |
| 54 self.data = pd.read_csv(self.input_file, sep=None, engine='python') | |
| 55 self.data.columns = self.data.columns.str.replace('.', '_') | |
| 56 | |
| 57 numeric_cols = self.data.select_dtypes(include=['number']).columns | |
| 58 non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns | |
| 59 | |
| 60 self.data[numeric_cols] = self.data[numeric_cols].apply( | |
| 61 pd.to_numeric, errors='coerce') | |
| 62 | |
| 63 if len(non_numeric_cols) > 0: | |
| 64 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | |
| 65 | |
| 66 names = self.data.columns.to_list() | |
| 67 target_index = int(self.target_col) - 1 | |
| 68 self.target = names[target_index] | |
| 69 self.features_name = [name | |
| 70 for i, name in enumerate(names) | |
| 71 if i != target_index] | |
| 72 if hasattr(self, 'missing_value_strategy'): | |
| 73 if self.missing_value_strategy == 'mean': | |
| 74 self.data = self.data.fillna( | |
| 75 self.data.mean(numeric_only=True)) | |
| 76 elif self.missing_value_strategy == 'median': | |
| 77 self.data = self.data.fillna( | |
| 78 self.data.median(numeric_only=True)) | |
| 79 elif self.missing_value_strategy == 'drop': | |
| 80 self.data = self.data.dropna() | |
| 81 else: | |
| 82 # Default strategy if not specified | |
| 83 self.data = self.data.fillna(self.data.median(numeric_only=True)) | |
| 84 | |
| 85 if self.test_file: | |
| 86 LOG.info(f"Loading test data from {self.test_file}") | |
| 87 self.test_data = pd.read_csv( | |
| 88 self.test_file, sep=None, engine='python') | |
| 89 self.test_data = self.test_data[numeric_cols].apply( | |
| 90 pd.to_numeric, errors='coerce') | |
| 91 self.test_data.columns = self.test_data.columns.str.replace( | |
| 92 '.', '_' | |
| 93 ) | |
| 94 | |
| 95 def setup_pycaret(self): | |
| 96 LOG.info("Initializing PyCaret") | |
| 97 self.setup_params = { | |
| 98 'target': self.target, | |
| 99 'session_id': self.random_seed, | |
| 100 'html': True, | |
| 101 'log_experiment': False, | |
| 102 'system_log': False, | |
| 103 'index': False, | |
| 104 } | |
| 105 | |
| 106 if self.test_data is not None: | |
| 107 self.setup_params['test_data'] = self.test_data | |
| 108 | |
| 109 if hasattr(self, 'train_size') and self.train_size is not None \ | |
| 110 and self.test_data is None: | |
| 111 self.setup_params['train_size'] = self.train_size | |
| 112 | |
| 113 if hasattr(self, 'normalize') and self.normalize is not None: | |
| 114 self.setup_params['normalize'] = self.normalize | |
| 115 | |
| 116 if hasattr(self, 'feature_selection') and \ | |
| 117 self.feature_selection is not None: | |
| 118 self.setup_params['feature_selection'] = self.feature_selection | |
| 119 | |
| 120 if hasattr(self, 'cross_validation') and \ | |
| 121 self.cross_validation is not None \ | |
| 122 and self.cross_validation is False: | |
| 123 self.setup_params['cross_validation'] = self.cross_validation | |
| 124 | |
| 125 if hasattr(self, 'cross_validation') and \ | |
| 126 self.cross_validation is not None: | |
| 127 if hasattr(self, 'cross_validation_folds'): | |
| 128 self.setup_params['fold'] = self.cross_validation_folds | |
| 129 | |
| 130 if hasattr(self, 'remove_outliers') and \ | |
| 131 self.remove_outliers is not None: | |
| 132 self.setup_params['remove_outliers'] = self.remove_outliers | |
| 133 | |
| 134 if hasattr(self, 'remove_multicollinearity') and \ | |
| 135 self.remove_multicollinearity is not None: | |
| 136 self.setup_params['remove_multicollinearity'] = \ | |
| 137 self.remove_multicollinearity | |
| 138 | |
| 139 if hasattr(self, 'polynomial_features') and \ | |
| 140 self.polynomial_features is not None: | |
| 141 self.setup_params['polynomial_features'] = self.polynomial_features | |
| 142 | |
| 143 if hasattr(self, 'fix_imbalance') and \ | |
| 144 self.fix_imbalance is not None: | |
| 145 self.setup_params['fix_imbalance'] = self.fix_imbalance | |
| 146 | |
| 147 LOG.info(self.setup_params) | |
| 148 self.exp.setup(self.data, **self.setup_params) | |
| 149 | |
| 150 def train_model(self): | |
| 151 LOG.info("Training and selecting the best model") | |
| 152 if self.task_type == "classification": | |
| 153 average_displayed = "Weighted" | |
| 154 self.exp.add_metric(id=f'PR-AUC-{average_displayed}', | |
| 155 name=f'PR-AUC-{average_displayed}', | |
| 156 target='pred_proba', | |
| 157 score_func=average_precision_score, | |
| 158 average='weighted' | |
| 159 ) | |
| 160 | |
| 161 if hasattr(self, 'models') and self.models is not None: | |
| 162 self.best_model = self.exp.compare_models( | |
| 163 include=self.models) | |
| 164 else: | |
| 165 self.best_model = self.exp.compare_models() | |
| 166 self.results = self.exp.pull() | |
| 167 if self.task_type == "classification": | |
| 168 self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True) | |
| 169 | |
| 170 _ = self.exp.predict_model(self.best_model) | |
| 171 self.test_result_df = self.exp.pull() | |
| 172 if self.task_type == "classification": | |
| 173 self.test_result_df.rename( | |
| 174 columns={'AUC': 'ROC-AUC'}, inplace=True) | |
| 175 | |
| 176 def save_model(self): | |
| 177 hdf5_model_path = "pycaret_model.h5" | |
| 178 with h5py.File(hdf5_model_path, 'w') as f: | |
| 179 with tempfile.NamedTemporaryFile(delete=False) as temp_file: | |
| 180 joblib.dump(self.best_model, temp_file.name) | |
| 181 temp_file.seek(0) | |
| 182 model_bytes = temp_file.read() | |
| 183 f.create_dataset('model', data=np.void(model_bytes)) | |
| 184 | |
    def generate_plots(self):
        """Render evaluation plots for the best model (subclass hook)."""
        raise NotImplementedError("Subclasses should implement this method")
| 187 | |
| 188 def encode_image_to_base64(self, img_path): | |
| 189 with open(img_path, 'rb') as img_file: | |
| 190 return base64.b64encode(img_file.read()).decode('utf-8') | |
| 191 | |
    def save_html_report(self):
        """Build and write the tabbed HTML training report.

        Writes ``comparison_result.html`` plus three CSV artifacts
        (``best_model.csv``, ``comparison_results.csv``,
        ``test_results.csv``) into ``self.output_dir``.  The report has tabs
        for setup/best-model info, plots, feature importance and — when
        ``self.plots_explainer_html`` is set — explainer plots, plus an
        inline script that makes every ``table.sortable`` click-sortable.
        """
        LOG.info("Saving HTML report")

        model_name = type(self.best_model).__name__
        # Setup entries that are internal or too bulky to show in the report.
        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
        filtered_setup_params = {
            k: v
            for k, v in self.setup_params.items() if k not in excluded_params
        }
        setup_params_table = pd.DataFrame(
            list(filtered_setup_params.items()), columns=['Parameter', 'Value']
        )

        best_model_params = pd.DataFrame(
            self.best_model.get_params().items(),
            columns=['Parameter', 'Value']
        )
        best_model_params.to_csv(
            os.path.join(self.output_dir, "best_model.csv"), index=False
        )
        self.results.to_csv(
            os.path.join(self.output_dir, "comparison_results.csv")
        )
        self.test_result_df.to_csv(
            os.path.join(self.output_dir, "test_results.csv")
        )

        # Inline every generated plot as a base64 <img>, separated by <hr>.
        plots_html = ""
        length = len(self.plots)
        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
            encoded_image = self.encode_image_to_base64(plot_path)
            plots_html += f"""
            <div class="plot">
                <h3>{plot_name.capitalize()}</h3>
                <img src="data:image/png;base64,{encoded_image}"
                    alt="{plot_name}">
            </div>
            """
            if i < length - 1:
                plots_html += "<hr>"

        # Tree figures collected by generate_tree_plots (already encoded);
        # falsy entries are skipped.
        tree_plots = ""
        for i, tree in enumerate(self.trees):
            if tree:
                tree_plots += f"""
                <div class="plot">
                    <h3>Tree {i+1}</h3>
                    <img src="data:image/png;base64,
                        {tree}"
                        alt="tree {i+1}">
                </div>
                """

        analyzer = FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir,
        )
        feature_importance_html = analyzer.run()

        # Tab headers; the explainer tab only exists when explainer plots
        # were generated.
        html_content = f"""
        {get_html_template()}
        <h1>PyCaret Model Training Report</h1>
        <div class="tabs">
            <div class="tab" onclick="openTab(event, 'summary')">
                Setup & Best Model</div>
            <div class="tab" onclick="openTab(event, 'plots')">
                Best Model Plots</div>
            <div class="tab" onclick="openTab(event, 'feature')">
                Feature Importance</div>
        """
        if self.plots_explainer_html:
            html_content += """
            <div class="tab" onclick="openTab(event, 'explainer')">
                Explainer Plots</div>
            """
        html_content += f"""
        </div>
        <div id="summary" class="tab-content">
            <h2>Setup Parameters</h2>
            {setup_params_table.to_html(
                index=False,
                header=True,
                classes='table sortable'
            )}
            <h5>If you want to know all the experiment setup parameters,
            please check the PyCaret documentation for
            the classification/regression <code>exp</code> function.</h5>
            <h2>Best Model: {model_name}</h2>
            {best_model_params.to_html(
                index=False,
                header=True,
                classes='table sortable'
            )}
            <h2>Comparison Results on the Cross-Validation Set</h2>
            {self.results.to_html(index=False, classes='table sortable')}
            <h2>Results on the Test Set for the best model</h2>
            {self.test_result_df.to_html(
                index=False,
                classes='table sortable'
            )}
        </div>
        <div id="plots" class="tab-content">
            <h2>Best Model Plots on the testing set</h2>
            {plots_html}
        </div>
        <div id="feature" class="tab-content">
            {feature_importance_html}
        </div>
        """
        if self.plots_explainer_html:
            html_content += f"""
            <div id="explainer" class="tab-content">
                {self.plots_explainer_html}
                {tree_plots}
            </div>
            """
        # Plain (non-f) string: the JavaScript below uses literal braces.
        html_content += """
        <script>
        document.addEventListener("DOMContentLoaded", function() {
            var tables = document.querySelectorAll("table.sortable");
            tables.forEach(function(table) {
                var headers = table.querySelectorAll("th");
                headers.forEach(function(header, index) {
                    header.style.cursor = "pointer";
                    // Add initial arrow (up) to indicate sortability
                    header.innerHTML += '<span class="sort-arrow"> ↑</span>';
                    header.addEventListener("click", function() {
                        var direction = this.getAttribute(
                            "data-sort-direction"
                        ) || "asc";
                        // Reset arrows in all headers of this table
                        headers.forEach(function(h) {
                            var arrow = h.querySelector(".sort-arrow");
                            if (arrow) arrow.textContent = " ↑";
                        });
                        // Set arrow for clicked header
                        var arrow = this.querySelector(".sort-arrow");
                        arrow.textContent = direction === "asc" ? " ↓" : " ↑";
                        sortTable(table, index, direction);
                        this.setAttribute("data-sort-direction",
                            direction === "asc" ? "desc" : "asc");
                    });
                });
            });
        });

        function sortTable(table, colNum, direction) {
            var tb = table.tBodies[0];
            var tr = Array.prototype.slice.call(tb.rows, 0);
            var multiplier = direction === "asc" ? 1 : -1;
            tr = tr.sort(function(a, b) {
                var aText = a.cells[colNum].textContent.trim();
                var bText = b.cells[colNum].textContent.trim();
                // Remove arrow from text comparison
                aText = aText.replace(/[↑↓]/g, '').trim();
                bText = bText.replace(/[↑↓]/g, '').trim();
                if (!isNaN(aText) && !isNaN(bText)) {
                    return multiplier * (
                        parseFloat(aText) - parseFloat(bText)
                    );
                } else {
                    return multiplier * aText.localeCompare(bText);
                }
            });
            for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
        }
        </script>
        """
        html_content += f"""
        {get_html_closing()}
        """
        with open(
            os.path.join(self.output_dir, "comparison_result.html"),
            "w"
        ) as file:
            file.write(html_content)
| 370 | |
    def save_dashboard(self):
        """Persist an explainer dashboard (subclass hook)."""
        raise NotImplementedError("Subclasses should implement this method")
| 373 | |
    def generate_plots_explainer(self):
        """Render explainer plots for the best model (subclass hook)."""
        raise NotImplementedError("Subclasses should implement this method")
| 376 | |
| 377 # not working now | |
| 378 def generate_tree_plots(self): | |
| 379 from sklearn.ensemble import RandomForestClassifier, \ | |
| 380 RandomForestRegressor | |
| 381 from xgboost import XGBClassifier, XGBRegressor | |
| 382 from explainerdashboard.explainers import RandomForestExplainer | |
| 383 | |
| 384 LOG.info("Generating tree plots") | |
| 385 X_test = self.exp.X_test_transformed.copy() | |
| 386 y_test = self.exp.y_test_transformed | |
| 387 | |
| 388 is_rf = isinstance(self.best_model, RandomForestClassifier) or \ | |
| 389 isinstance(self.best_model, RandomForestRegressor) | |
| 390 | |
| 391 is_xgb = isinstance(self.best_model, XGBClassifier) or \ | |
| 392 isinstance(self.best_model, XGBRegressor) | |
| 393 | |
| 394 try: | |
| 395 if is_rf: | |
| 396 num_trees = self.best_model.n_estimators | |
| 397 if is_xgb: | |
| 398 num_trees = len(self.best_model.get_booster().get_dump()) | |
| 399 explainer = RandomForestExplainer(self.best_model, X_test, y_test) | |
| 400 for i in range(num_trees): | |
| 401 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) | |
| 402 LOG.info(f"Tree {i+1}") | |
| 403 LOG.info(fig) | |
| 404 self.trees.append(fig) | |
| 405 except Exception as e: | |
| 406 LOG.error(f"Error generating tree plots: {e}") | |
| 407 | |
    def run(self):
        """Execute the full training pipeline end to end.

        Order matters: data must be loaded before PyCaret setup, and all
        plots/trees must exist before the HTML report is assembled.
        """
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()
