comparison base_model_trainer.py @ 0:209b663a4f62 draft

planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author goeckslab
date Wed, 18 Jun 2025 15:38:19 +0000
parents
children 77c88226bfde
comparison
equal deleted inserted replaced
-1:000000000000 0:209b663a4f62
1 import base64
2 import logging
3 import os
4 import tempfile
5
6 import h5py
7 import joblib
8 import numpy as np
9 import pandas as pd
10 from feature_importance import FeatureImportanceAnalyzer
11 from sklearn.metrics import average_precision_score
12 from utils import get_html_closing, get_html_template
13
# Module-level logger shared by every method of the trainer below.
LOG = logging.getLogger(name=__name__)
# Root handler configuration at import time: emit records from DEBUG upwards.
logging.basicConfig(level=logging.DEBUG)
16
17
class BaseModelTrainer:
    """Common PyCaret workflow shared by the task-specific trainers.

    ``run`` executes the pipeline: load data, configure PyCaret, compare
    models, persist the best model, build plots and write an HTML report.
    Subclasses must assign ``self.exp`` (a PyCaret experiment object) before
    ``setup_pycaret`` runs and implement ``generate_plots``,
    ``generate_plots_explainer`` and ``save_dashboard``.
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Record the configuration; no data is read until ``load_data``.

        Args:
            input_file: Path to the training table (delimiter is sniffed).
            target_col: 1-based position of the target column; any value
                ``int()`` accepts.
            output_dir: Directory receiving the CSV and HTML outputs.
            task_type: ``"classification"`` enables classification-only
                metrics and column renames; other values (presumably
                ``"regression"``) skip them.
            random_seed: Passed to PyCaret as ``session_id``.
            test_file: Optional path to a held-out table passed to PyCaret
                as ``test_data`` instead of letting it split the data.
            **kwargs: Optional settings (e.g. ``missing_value_strategy``,
                ``train_size``, ``normalize``, ``models``) stored verbatim
                as instance attributes and consulted later.
        """
        self.exp = None  # This will be set in the subclass
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.features_name = None
        self.plots = {}
        # NOTE: the attribute name is misspelled ("expaliner"); it is kept
        # unchanged because subclasses may reference it.
        self.expaliner = None
        self.plots_explainer_html = None
        self.trees = []
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
        """Read the training (and optional test) table and pick the target.

        Column names have ``.`` replaced by ``_`` in both tables.  Missing
        values in the training data are imputed with the column mean/median
        or the rows are dropped, according to ``missing_value_strategy``
        (median when the option was not given at all).  The test table keeps
        all of its columns; only the columns that are numeric in the
        training data are coerced to numbers (unparsable values -> NaN).

        Raises:
            IndexError: if ``target_col`` does not address an existing column.
        """
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine='python')
        # regex=False: '.' must be replaced literally on every pandas version.
        self.data.columns = self.data.columns.str.replace(
            '.', '_', regex=False)

        numeric_cols = self.data.select_dtypes(include=['number']).columns
        non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns

        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors='coerce')

        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        target_index = int(self.target_col) - 1
        # Without this check, 0 or a negative value would silently select a
        # column counted from the end of the table.
        if not 0 <= target_index < len(names):
            raise IndexError(
                f"target_col {self.target_col} is out of range: expected a "
                f"1-based column index between 1 and {len(names)}")
        self.target = names[target_index]
        self.features_name = [name
                              for i, name in enumerate(names)
                              if i != target_index]
        if hasattr(self, 'missing_value_strategy'):
            if self.missing_value_strategy == 'mean':
                self.data = self.data.fillna(
                    self.data.mean(numeric_only=True))
            elif self.missing_value_strategy == 'median':
                self.data = self.data.fillna(
                    self.data.median(numeric_only=True))
            elif self.missing_value_strategy == 'drop':
                self.data = self.data.dropna()
        else:
            # Default strategy if not specified
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            self.test_data = pd.read_csv(
                self.test_file, sep=None, engine='python')
            # Rename first so the names line up with the (already renamed)
            # training columns, then coerce only the shared numeric columns.
            # Every column is kept: PyCaret needs the same feature and target
            # columns in test_data as in the training data.
            self.test_data.columns = self.test_data.columns.str.replace(
                '.', '_', regex=False)
            shared_numeric = [col for col in numeric_cols
                              if col in self.test_data.columns]
            if shared_numeric:
                self.test_data[shared_numeric] = \
                    self.test_data[shared_numeric].apply(
                        pd.to_numeric, errors='coerce')

    def setup_pycaret(self):
        """Build ``self.setup_params`` and call ``self.exp.setup`` with it.

        Optional settings are forwarded only when present and not ``None``,
        so PyCaret's own defaults apply otherwise.
        """
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            'target': self.target,
            'session_id': self.random_seed,
            'html': True,
            'log_experiment': False,
            'system_log': False,
            'index': False,
        }

        if self.test_data is not None:
            self.setup_params['test_data'] = self.test_data
        elif getattr(self, 'train_size', None) is not None:
            # train_size only matters when PyCaret performs the split itself.
            self.setup_params['train_size'] = self.train_size

        for option in ('normalize', 'feature_selection'):
            value = getattr(self, option, None)
            if value is not None:
                self.setup_params[option] = value

        # PyCaret's setup() has no ``cross_validation`` parameter (it is a
        # compare_models() option, applied in train_model); here it only
        # gates whether the requested number of folds is forwarded.
        folds = getattr(self, 'cross_validation_folds', None)
        if getattr(self, 'cross_validation', None) is not None \
                and folds is not None:
            self.setup_params['fold'] = folds

        for option in ('remove_outliers',
                       'remove_multicollinearity',
                       'polynomial_features',
                       'fix_imbalance'):
            value = getattr(self, option, None)
            if value is not None:
                self.setup_params[option] = value

        LOG.info(self.setup_params)
        self.exp.setup(self.data, **self.setup_params)

    def train_model(self):
        """Compare candidate models, keep the best and score it on test data.

        Stores the comparison table in ``self.results`` and the best model's
        test-set metrics in ``self.test_result_df``.  For classification a
        weighted PR-AUC metric is registered first, and the ``AUC`` column
        is renamed to ``ROC-AUC`` in both tables.
        """
        LOG.info("Training and selecting the best model")
        is_classification = self.task_type == "classification"
        if is_classification:
            average_displayed = "Weighted"
            self.exp.add_metric(id=f'PR-AUC-{average_displayed}',
                                name=f'PR-AUC-{average_displayed}',
                                target='pred_proba',
                                score_func=average_precision_score,
                                average='weighted'
                                )

        compare_kwargs = {}
        if getattr(self, 'models', None) is not None:
            compare_kwargs['include'] = self.models
        if getattr(self, 'cross_validation', None) is False:
            # Disabling CV is a compare_models() option, not a setup() one.
            compare_kwargs['cross_validation'] = False
        self.best_model = self.exp.compare_models(**compare_kwargs)
        self.results = self.exp.pull()
        if is_classification:
            self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True)

        # The predictions are discarded; the metrics are read via pull().
        self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if is_classification:
            self.test_result_df.rename(
                columns={'AUC': 'ROC-AUC'}, inplace=True)

    def save_model(self):
        """Serialize ``self.best_model`` with joblib into an HDF5 file.

        The joblib bytes are stored as an opaque ``np.void`` scalar in the
        ``model`` dataset of ``pycaret_model.h5``, which is written to the
        current working directory (not ``output_dir``).
        """
        hdf5_model_path = "pycaret_model.h5"
        # Dump into a private temporary directory so the intermediate file
        # is always cleaned up (NamedTemporaryFile(delete=False) would be
        # left behind) and re-opening it by path is portable.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "model.joblib")
            joblib.dump(self.best_model, tmp_path)
            with open(tmp_path, 'rb') as tmp_file:
                model_bytes = tmp_file.read()
        with h5py.File(hdf5_model_path, 'w') as f:
            f.create_dataset('model', data=np.void(model_bytes))

    def generate_plots(self):
        """Populate ``self.plots`` ({name: image path}); subclass hook."""
        raise NotImplementedError("Subclasses should implement this method")

    def encode_image_to_base64(self, img_path):
        """Return the file at ``img_path`` as a base64 ``str``."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def save_html_report(self):
        """Write CSV summaries and ``comparison_result.html`` to output_dir.

        Writes ``best_model.csv``, ``comparison_results.csv`` and
        ``test_results.csv``, then a tabbed HTML report that embeds the
        plots as base64 images and makes every table sortable by clicking
        its headers.
        """
        LOG.info("Saving HTML report")

        model_name = type(self.best_model).__name__
        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
        filtered_setup_params = {
            k: v
            for k, v in self.setup_params.items() if k not in excluded_params
        }
        setup_params_table = pd.DataFrame(
            list(filtered_setup_params.items()), columns=['Parameter', 'Value']
        )

        best_model_params = pd.DataFrame(
            list(self.best_model.get_params().items()),
            columns=['Parameter', 'Value']
        )
        best_model_params.to_csv(
            os.path.join(self.output_dir, "best_model.csv"), index=False
        )
        self.results.to_csv(
            os.path.join(self.output_dir, "comparison_results.csv")
        )
        self.test_result_df.to_csv(
            os.path.join(self.output_dir, "test_results.csv")
        )

        plot_sections = []
        for plot_name, plot_path in self.plots.items():
            encoded_image = self.encode_image_to_base64(plot_path)
            plot_sections.append(f"""
            <div class="plot">
                <h3>{plot_name.capitalize()}</h3>
                <img src="data:image/png;base64,{encoded_image}"
                    alt="{plot_name}">
            </div>
            """)
        # A rule between consecutive plots, none after the last one.
        plots_html = "<hr>".join(plot_sections)

        # Numbering follows the position in self.trees, so empty entries
        # are skipped without renumbering the others.  The data URI is kept
        # on one line (no embedded whitespace inside the src attribute).
        tree_plots = "".join(
            f"""
            <div class="plot">
                <h3>Tree {i + 1}</h3>
                <img src="data:image/png;base64,{tree}"
                    alt="tree {i + 1}">
            </div>
            """
            for i, tree in enumerate(self.trees)
            if tree
        )

        analyzer = FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir,
        )
        feature_importance_html = analyzer.run()

        setup_params_html = setup_params_table.to_html(
            index=False, header=True, classes='table sortable')
        best_model_params_html = best_model_params.to_html(
            index=False, header=True, classes='table sortable')
        results_html = self.results.to_html(
            index=False, classes='table sortable')
        test_results_html = self.test_result_df.to_html(
            index=False, classes='table sortable')

        html_content = f"""
        {get_html_template()}
        <h1>PyCaret Model Training Report</h1>
        <div class="tabs">
            <div class="tab" onclick="openTab(event, 'summary')">
            Setup & Best Model</div>
            <div class="tab" onclick="openTab(event, 'plots')">
            Best Model Plots</div>
            <div class="tab" onclick="openTab(event, 'feature')">
            Feature Importance</div>
        """
        if self.plots_explainer_html:
            html_content += """
            <div class="tab" onclick="openTab(event, 'explainer')">
            Explainer Plots</div>
            """
        html_content += f"""
        </div>
        <div id="summary" class="tab-content">
            <h2>Setup Parameters</h2>
            {setup_params_html}
            <h5>If you want to know all the experiment setup parameters,
              please check the PyCaret documentation for
              the classification/regression <code>exp</code> function.</h5>
            <h2>Best Model: {model_name}</h2>
            {best_model_params_html}
            <h2>Comparison Results on the Cross-Validation Set</h2>
            {results_html}
            <h2>Results on the Test Set for the best model</h2>
            {test_results_html}
        </div>
        <div id="plots" class="tab-content">
            <h2>Best Model Plots on the testing set</h2>
            {plots_html}
        </div>
        <div id="feature" class="tab-content">
            {feature_importance_html}
        </div>
        """
        if self.plots_explainer_html:
            html_content += f"""
            <div id="explainer" class="tab-content">
                {self.plots_explainer_html}
                {tree_plots}
            </div>
            """
        html_content += """
        <script>
        document.addEventListener("DOMContentLoaded", function() {
            var tables = document.querySelectorAll("table.sortable");
            tables.forEach(function(table) {
                var headers = table.querySelectorAll("th");
                headers.forEach(function(header, index) {
                    header.style.cursor = "pointer";
                    // Add initial arrow (up) to indicate sortability
                    header.innerHTML += '<span class="sort-arrow"> ↑</span>';
                    header.addEventListener("click", function() {
                        var direction = this.getAttribute(
                            "data-sort-direction"
                        ) || "asc";
                        // Reset arrows in all headers of this table
                        headers.forEach(function(h) {
                            var arrow = h.querySelector(".sort-arrow");
                            if (arrow) arrow.textContent = " ↑";
                        });
                        // Set arrow for clicked header
                        var arrow = this.querySelector(".sort-arrow");
                        arrow.textContent = direction === "asc" ? " ↓" : " ↑";
                        sortTable(table, index, direction);
                        this.setAttribute("data-sort-direction",
                        direction === "asc" ? "desc" : "asc");
                    });
                });
            });
        });

        function sortTable(table, colNum, direction) {
            var tb = table.tBodies[0];
            var tr = Array.prototype.slice.call(tb.rows, 0);
            var multiplier = direction === "asc" ? 1 : -1;
            tr = tr.sort(function(a, b) {
                var aText = a.cells[colNum].textContent.trim();
                var bText = b.cells[colNum].textContent.trim();
                // Remove arrow from text comparison
                aText = aText.replace(/[↑↓]/g, '').trim();
                bText = bText.replace(/[↑↓]/g, '').trim();
                if (!isNaN(aText) && !isNaN(bText)) {
                    return multiplier * (
                        parseFloat(aText) - parseFloat(bText)
                    );
                } else {
                    return multiplier * aText.localeCompare(bText);
                }
            });
            for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
        }
        </script>
        """
        html_content += f"""
        {get_html_closing()}
        """
        # Explicit UTF-8: the report contains non-ASCII arrows (↑/↓) and must
        # not depend on the platform's default locale encoding.
        with open(
            os.path.join(self.output_dir, "comparison_result.html"),
            "w",
            encoding="utf-8",
        ) as file:
            file.write(html_content)

    def save_dashboard(self):
        """Export an interactive dashboard; subclass hook."""
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        """Populate ``self.plots_explainer_html``; subclass hook."""
        raise NotImplementedError("Subclasses should implement this method")

    def generate_tree_plots(self):
        """Best-effort rendering of individual trees into ``self.trees``.

        Only random-forest and XGBoost best models are supported; other
        model types are skipped.  Failures while building the plots are
        logged and swallowed so report generation continues.

        NOTE: marked "not working now" by the original author.
        """
        from sklearn.ensemble import RandomForestClassifier, \
            RandomForestRegressor
        from xgboost import XGBClassifier, XGBRegressor
        from explainerdashboard.explainers import RandomForestExplainer

        LOG.info("Generating tree plots")
        is_rf = isinstance(
            self.best_model, (RandomForestClassifier, RandomForestRegressor))
        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
        if not (is_rf or is_xgb):
            # Other model types have no trees to draw; previously this
            # surfaced as a NameError on num_trees logged as an error.
            LOG.info(
                f"Tree plots not supported for "
                f"{type(self.best_model).__name__}; skipping")
            return

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            if is_rf:
                num_trees = self.best_model.n_estimators
            else:
                num_trees = len(self.best_model.get_booster().get_dump())
            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
            for i in range(num_trees):
                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
                LOG.info(f"Tree {i+1}")
                LOG.info(fig)
                self.trees.append(fig)
        except Exception as e:
            LOG.error(f"Error generating tree plots: {e}")

    def run(self):
        """Execute the full pipeline from loading data to the HTML report."""
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()  # dashboard export currently disabled