Mercurial > repos > goeckslab > pycaret_predict
comparison base_model_trainer.py @ 3:ccd798db5abb draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit cf47efb521b91a9cb44ae5c5ade860627f9b9030
| author | goeckslab |
|---|---|
| date | Tue, 03 Jun 2025 19:31:06 +0000 |
| parents | 0314dad38aaa |
| children | a32ff7201629 |
Comparison legend: equal, deleted, inserted, replaced

| 2:0314dad38aaa | 3:ccd798db5abb |
|---|---|
| 1 import base64 | 1 import base64 |
| 2 import logging | 2 import logging |
| 3 import os | 3 import os |
| 4 import tempfile | 4 import tempfile |
| 5 | 5 |
| 6 import h5py | |
| 7 import joblib | |
| 8 import numpy as np | |
| 9 import pandas as pd | |
| 6 from feature_importance import FeatureImportanceAnalyzer | 10 from feature_importance import FeatureImportanceAnalyzer |
| 7 | |
| 8 import h5py | |
| 9 | |
| 10 import joblib | |
| 11 | |
| 12 import numpy as np | |
| 13 | |
| 14 import pandas as pd | |
| 15 | |
| 16 from sklearn.metrics import average_precision_score | 11 from sklearn.metrics import average_precision_score |
| 17 | |
| 18 from utils import get_html_closing, get_html_template | 12 from utils import get_html_closing, get_html_template |
| 19 | 13 |
| 20 logging.basicConfig(level=logging.DEBUG) | 14 logging.basicConfig(level=logging.DEBUG) |
| 21 LOG = logging.getLogger(__name__) | 15 LOG = logging.getLogger(__name__) |
| 22 | 16 |
| 29 target_col, | 23 target_col, |
| 30 output_dir, | 24 output_dir, |
| 31 task_type, | 25 task_type, |
| 32 random_seed, | 26 random_seed, |
| 33 test_file=None, | 27 test_file=None, |
| 34 **kwargs | 28 **kwargs): |
| 35 ): | |
| 36 self.exp = None # This will be set in the subclass | 29 self.exp = None # This will be set in the subclass |
| 37 self.input_file = input_file | 30 self.input_file = input_file |
| 38 self.target_col = target_col | 31 self.target_col = target_col |
| 39 self.output_dir = output_dir | 32 self.output_dir = output_dir |
| 40 self.task_type = task_type | 33 self.task_type = task_type |
| 69 | 62 |
| 70 if len(non_numeric_cols) > 0: | 63 if len(non_numeric_cols) > 0: |
| 71 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | 64 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") |
| 72 | 65 |
| 73 names = self.data.columns.to_list() | 66 names = self.data.columns.to_list() |
| 74 target_index = int(self.target_col)-1 | 67 target_index = int(self.target_col) - 1 |
| 75 self.target = names[target_index] | 68 self.target = names[target_index] |
| 76 self.features_name = [name | 69 self.features_name = [name |
| 77 for i, name in enumerate(names) | 70 for i, name in enumerate(names) |
| 78 if i != target_index] | 71 if i != target_index] |
| 79 if hasattr(self, 'missing_value_strategy'): | 72 if hasattr(self, 'missing_value_strategy'): |
| 95 self.test_file, sep=None, engine='python') | 88 self.test_file, sep=None, engine='python') |
| 96 self.test_data = self.test_data[numeric_cols].apply( | 89 self.test_data = self.test_data[numeric_cols].apply( |
| 97 pd.to_numeric, errors='coerce') | 90 pd.to_numeric, errors='coerce') |
| 98 self.test_data.columns = self.test_data.columns.str.replace( | 91 self.test_data.columns = self.test_data.columns.str.replace( |
| 99 '.', '_' | 92 '.', '_' |
| 100 ) | 93 ) |
| 101 | 94 |
| 102 def setup_pycaret(self): | 95 def setup_pycaret(self): |
| 103 LOG.info("Initializing PyCaret") | 96 LOG.info("Initializing PyCaret") |
| 104 self.setup_params = { | 97 self.setup_params = { |
| 105 'target': self.target, | 98 'target': self.target, |
| 204 filtered_setup_params = { | 197 filtered_setup_params = { |
| 205 k: v | 198 k: v |
| 206 for k, v in self.setup_params.items() if k not in excluded_params | 199 for k, v in self.setup_params.items() if k not in excluded_params |
| 207 } | 200 } |
| 208 setup_params_table = pd.DataFrame( | 201 setup_params_table = pd.DataFrame( |
| 209 list(filtered_setup_params.items()), | 202 list(filtered_setup_params.items()), columns=['Parameter', 'Value'] |
| 210 columns=['Parameter', 'Value']) | 203 ) |
| 211 | 204 |
| 212 best_model_params = pd.DataFrame( | 205 best_model_params = pd.DataFrame( |
| 213 self.best_model.get_params().items(), | 206 self.best_model.get_params().items(), |
| 214 columns=['Parameter', 'Value']) | 207 columns=['Parameter', 'Value'] |
| 208 ) | |
| 215 best_model_params.to_csv( | 209 best_model_params.to_csv( |
| 216 os.path.join(self.output_dir, 'best_model.csv'), | 210 os.path.join(self.output_dir, "best_model.csv"), index=False |
| 217 index=False) | 211 ) |
| 218 self.results.to_csv(os.path.join( | 212 self.results.to_csv( |
| 219 self.output_dir, "comparison_results.csv")) | 213 os.path.join(self.output_dir, "comparison_results.csv") |
| 220 self.test_result_df.to_csv(os.path.join( | 214 ) |
| 221 self.output_dir, "test_results.csv")) | 215 self.test_result_df.to_csv( |
| 216 os.path.join(self.output_dir, "test_results.csv") | |
| 217 ) | |
| 222 | 218 |
| 223 plots_html = "" | 219 plots_html = "" |
| 224 length = len(self.plots) | 220 length = len(self.plots) |
| 225 for i, (plot_name, plot_path) in enumerate(self.plots.items()): | 221 for i, (plot_name, plot_path) in enumerate(self.plots.items()): |
| 226 encoded_image = self.encode_image_to_base64(plot_path) | 222 encoded_image = self.encode_image_to_base64(plot_path) |
| 248 | 244 |
| 249 analyzer = FeatureImportanceAnalyzer( | 245 analyzer = FeatureImportanceAnalyzer( |
| 250 data=self.data, | 246 data=self.data, |
| 251 target_col=self.target_col, | 247 target_col=self.target_col, |
| 252 task_type=self.task_type, | 248 task_type=self.task_type, |
| 253 output_dir=self.output_dir) | 249 output_dir=self.output_dir, |
| 250 ) | |
| 254 feature_importance_html = analyzer.run() | 251 feature_importance_html = analyzer.run() |
| 255 | 252 |
| 256 html_content = f""" | 253 html_content = f""" |
| 257 {get_html_template()} | 254 {get_html_template()} |
| 258 <h1>PyCaret Model Training Report</h1> | 255 <h1>PyCaret Model Training Report</h1> |
| 261 Setup & Best Model</div> | 258 Setup & Best Model</div> |
| 262 <div class="tab" onclick="openTab(event, 'plots')"> | 259 <div class="tab" onclick="openTab(event, 'plots')"> |
| 263 Best Model Plots</div> | 260 Best Model Plots</div> |
| 264 <div class="tab" onclick="openTab(event, 'feature')"> | 261 <div class="tab" onclick="openTab(event, 'feature')"> |
| 265 Feature Importance</div> | 262 Feature Importance</div> |
| 266 """ | 263 """ |
| 267 if self.plots_explainer_html: | 264 if self.plots_explainer_html: |
| 268 html_content += """ | 265 html_content += """ |
| 269 "<div class="tab" onclick="openTab(event, 'explainer')">" | 266 <div class="tab" onclick="openTab(event, 'explainer')"> |
| 270 Explainer Plots</div> | 267 Explainer Plots</div> |
| 271 """ | 268 """ |
| 272 html_content += f""" | 269 html_content += f""" |
| 273 </div> | 270 </div> |
| 274 <div id="summary" class="tab-content"> | 271 <div id="summary" class="tab-content"> |
| 275 <h2>Setup Parameters</h2> | 272 <h2>Setup Parameters</h2> |
| 276 <table> | 273 {setup_params_table.to_html( |
| 277 <tr><th>Parameter</th><th>Value</th></tr> | 274 index=False, |
| 278 {setup_params_table.to_html( | 275 header=True, |
| 279 index=False, header=False, classes='table')} | 276 classes='table sortable' |
| 280 </table> | 277 )} |
| 281 <h5>If you want to know all the experiment setup parameters, | 278 <h5>If you want to know all the experiment setup parameters, |
| 282 please check the PyCaret documentation for | 279 please check the PyCaret documentation for |
| 283 the classification/regression <code>exp</code> function.</h5> | 280 the classification/regression <code>exp</code> function.</h5> |
| 284 <h2>Best Model: {model_name}</h2> | 281 <h2>Best Model: {model_name}</h2> |
| 285 <table> | 282 {best_model_params.to_html( |
| 286 <tr><th>Parameter</th><th>Value</th></tr> | 283 index=False, |
| 287 {best_model_params.to_html( | 284 header=True, |
| 288 index=False, header=False, classes='table')} | 285 classes='table sortable' |
| 289 </table> | 286 )} |
| 290 <h2>Comparison Results on the Cross-Validation Set</h2> | 287 <h2>Comparison Results on the Cross-Validation Set</h2> |
| 291 <table> | 288 {self.results.to_html(index=False, classes='table sortable')} |
| 292 {self.results.to_html(index=False, classes='table')} | |
| 293 </table> | |
| 294 <h2>Results on the Test Set for the best model</h2> | 289 <h2>Results on the Test Set for the best model</h2> |
| 295 <table> | 290 {self.test_result_df.to_html( |
| 296 {self.test_result_df.to_html(index=False, classes='table')} | 291 index=False, |
| 297 </table> | 292 classes='table sortable' |
| 293 )} | |
| 298 </div> | 294 </div> |
| 299 <div id="plots" class="tab-content"> | 295 <div id="plots" class="tab-content"> |
| 300 <h2>Best Model Plots on the testing set</h2> | 296 <h2>Best Model Plots on the testing set</h2> |
| 301 {plots_html} | 297 {plots_html} |
| 302 </div> | 298 </div> |
| 308 html_content += f""" | 304 html_content += f""" |
| 309 <div id="explainer" class="tab-content"> | 305 <div id="explainer" class="tab-content"> |
| 310 {self.plots_explainer_html} | 306 {self.plots_explainer_html} |
| 311 {tree_plots} | 307 {tree_plots} |
| 312 </div> | 308 </div> |
| 313 {get_html_closing()} | |
| 314 """ | 309 """ |
| 315 else: | 310 html_content += """ |
| 316 html_content += f""" | 311 <script> |
| 317 {get_html_closing()} | 312 document.addEventListener("DOMContentLoaded", function() { |
| 318 """ | 313 var tables = document.querySelectorAll("table.sortable"); |
| 319 with open(os.path.join( | 314 tables.forEach(function(table) { |
| 320 self.output_dir, "comparison_result.html"), "w") as file: | 315 var headers = table.querySelectorAll("th"); |
| 316 headers.forEach(function(header, index) { | |
| 317 header.style.cursor = "pointer"; | |
| 318 // Add initial arrow (up) to indicate sortability | |
| 319 header.innerHTML += '<span class="sort-arrow"> ↑</span>'; | |
| 320 header.addEventListener("click", function() { | |
| 321 var direction = this.getAttribute( | |
| 322 "data-sort-direction" | |
| 323 ) || "asc"; | |
| 324 // Reset arrows in all headers of this table | |
| 325 headers.forEach(function(h) { | |
| 326 var arrow = h.querySelector(".sort-arrow"); | |
| 327 if (arrow) arrow.textContent = " ↑"; | |
| 328 }); | |
| 329 // Set arrow for clicked header | |
| 330 var arrow = this.querySelector(".sort-arrow"); | |
| 331 arrow.textContent = direction === "asc" ? " ↓" : " ↑"; | |
| 332 sortTable(table, index, direction); | |
| 333 this.setAttribute("data-sort-direction", | |
| 334 direction === "asc" ? "desc" : "asc"); | |
| 335 }); | |
| 336 }); | |
| 337 }); | |
| 338 }); | |
| 339 | |
| 340 function sortTable(table, colNum, direction) { | |
| 341 var tb = table.tBodies[0]; | |
| 342 var tr = Array.prototype.slice.call(tb.rows, 0); | |
| 343 var multiplier = direction === "asc" ? 1 : -1; | |
| 344 tr = tr.sort(function(a, b) { | |
| 345 var aText = a.cells[colNum].textContent.trim(); | |
| 346 var bText = b.cells[colNum].textContent.trim(); | |
| 347 // Remove arrow from text comparison | |
| 348 aText = aText.replace(/[↑↓]/g, '').trim(); | |
| 349 bText = bText.replace(/[↑↓]/g, '').trim(); | |
| 350 if (!isNaN(aText) && !isNaN(bText)) { | |
| 351 return multiplier * ( | |
| 352 parseFloat(aText) - parseFloat(bText) | |
| 353 ); | |
| 354 } else { | |
| 355 return multiplier * aText.localeCompare(bText); | |
| 356 } | |
| 357 }); | |
| 358 for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]); | |
| 359 } | |
| 360 </script> | |
| 361 """ | |
| 362 html_content += f""" | |
| 363 {get_html_closing()} | |
| 364 """ | |
| 365 with open( | |
| 366 os.path.join(self.output_dir, "comparison_result.html"), | |
| 367 "w" | |
| 368 ) as file: | |
| 321 file.write(html_content) | 369 file.write(html_content) |
| 322 | 370 |
| 323 def save_dashboard(self): | 371 def save_dashboard(self): |
| 324 raise NotImplementedError("Subclasses should implement this method") | 372 raise NotImplementedError("Subclasses should implement this method") |
| 325 | 373 |
