comparison: base_model_trainer.py @ 8:1aed7d47c5ec (draft, default branch, tip)

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author: goeckslab
date: Fri, 25 Jul 2025 19:02:32 +0000
parents: f4cb41f458fd
children: (none)
Comparison of revision 7:f4cb41f458fd (parent) with 8:1aed7d47c5ec (tip), rendered below as a unified diff: lines beginning with "-" were removed, lines beginning with "+" were added, unprefixed lines are unchanged context, and "..." marks elided unchanged regions.
 import base64
 import logging
-import os
 import tempfile
+from pathlib import Path

 import h5py
 import joblib
 import numpy as np
 import pandas as pd
 from feature_help_modal import get_feature_metrics_help_modal
 from feature_importance import FeatureImportanceAnalyzer
 from sklearn.metrics import average_precision_score
-from utils import get_html_closing, get_html_template
+from utils import (
+    add_hr_to_html,
+    add_plot_to_html,
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+)

 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger(__name__)

...
         task_type,
         random_seed,
         test_file=None,
         **kwargs,
     ):
-        self.exp = None  # This will be set in the subclass
+        self.exp = None
         self.input_file = input_file
         self.target_col = target_col
         self.output_dir = output_dir
         self.task_type = task_type
         self.random_seed = random_seed
...
         self.target = None
         self.best_model = None
         self.results = None
         self.features_name = None
         self.plots = {}
-        self.expaliner = None
+        self.explainer_plots = {}
         self.plots_explainer_html = None
         self.trees = []
-        for key, value in kwargs.items():
+        self.user_kwargs = kwargs.copy()
+        for key, value in self.user_kwargs.items():
             setattr(self, key, value)
         self.setup_params = {}
         self.test_file = test_file
         self.test_data = None

...

     def load_data(self):
         LOG.info(f"Loading data from {self.input_file}")
         self.data = pd.read_csv(self.input_file, sep=None, engine="python")
         self.data.columns = self.data.columns.str.replace(".", "_")
-
-        # Remove prediction_label if present
         if "prediction_label" in self.data.columns:
             self.data = self.data.drop(columns=["prediction_label"])

         numeric_cols = self.data.select_dtypes(include=["number"]).columns
         non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
-
         self.data[numeric_cols] = self.data[numeric_cols].apply(
             pd.to_numeric, errors="coerce"
         )
-
         if len(non_numeric_cols) > 0:
             LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

         names = self.data.columns.to_list()
         target_index = int(self.target_col) - 1
         self.target = names[target_index]
-        self.features_name = [name for i, name in enumerate(names) if i != target_index]
-        if hasattr(self, "missing_value_strategy"):
-            if self.missing_value_strategy == "mean":
+        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+
+        if getattr(self, "missing_value_strategy", None):
+            strat = self.missing_value_strategy
+            if strat == "mean":
                 self.data = self.data.fillna(self.data.mean(numeric_only=True))
-            elif self.missing_value_strategy == "median":
+            elif strat == "median":
                 self.data = self.data.fillna(self.data.median(numeric_only=True))
-            elif self.missing_value_strategy == "drop":
+            elif strat == "drop":
                 self.data = self.data.dropna()
             else:
-                # Default strategy if not specified
                 self.data = self.data.fillna(self.data.median(numeric_only=True))

         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
-            self.test_data = self.test_data[numeric_cols].apply(
-                pd.to_numeric, errors="coerce"
-            )
-            self.test_data.columns = self.test_data.columns.str.replace(".", "_")
+            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test.columns = df_test.columns.str.replace(".", "_")
+            self.test_data = df_test

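The rewritten missing-value handling dispatches on a single `strat` value and falls back to median imputation. As a minimal standalone sketch of the same pandas calls (toy frame; `strat` is a hypothetical user-supplied setting):

    import pandas as pd

    df = pd.DataFrame({"x": [1.0, None, 3.0], "y": [4.0, 5.0, None]})
    strat = "median"  # hypothetical user-supplied strategy

    if strat == "mean":
        df = df.fillna(df.mean(numeric_only=True))
    elif strat == "median":
        df = df.fillna(df.median(numeric_only=True))
    elif strat == "drop":
        df = df.dropna()
    else:
        df = df.fillna(df.median(numeric_only=True))  # default, as in the diff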
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
         self.setup_params = {
             "target": self.target,
...
             "html": True,
             "log_experiment": False,
             "system_log": False,
             "index": False,
         }
-
         if self.test_data is not None:
             self.setup_params["test_data"] = self.test_data
-
-        if (
-            hasattr(self, "train_size")
-            and self.train_size is not None
-            and self.test_data is None
-        ):
-            self.setup_params["train_size"] = self.train_size
-
-        if hasattr(self, "normalize") and self.normalize is not None:
-            self.setup_params["normalize"] = self.normalize
-
-        if hasattr(self, "feature_selection") and self.feature_selection is not None:
-            self.setup_params["feature_selection"] = self.feature_selection
-
-        if (
-            hasattr(self, "cross_validation")
-            and self.cross_validation is not None
-            and self.cross_validation is False
-        ):
-            logging.info(
-                "cross_validation is set to False. This will disable cross-validation."
-            )
-
-        if hasattr(self, "cross_validation") and self.cross_validation:
-            if hasattr(self, "cross_validation_folds"):
-                self.setup_params["fold"] = self.cross_validation_folds
-
-        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
-            self.setup_params["remove_outliers"] = self.remove_outliers
-
-        if (
-            hasattr(self, "remove_multicollinearity")
-            and self.remove_multicollinearity is not None
-        ):
-            self.setup_params["remove_multicollinearity"] = (
-                self.remove_multicollinearity
-            )
-
-        if (
-            hasattr(self, "polynomial_features")
-            and self.polynomial_features is not None
-        ):
-            self.setup_params["polynomial_features"] = self.polynomial_features
-
-        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
-            self.setup_params["fix_imbalance"] = self.fix_imbalance
-
+        for attr in [
+            "train_size",
+            "normalize",
+            "feature_selection",
+            "remove_outliers",
+            "remove_multicollinearity",
+            "polynomial_features",
+            "feature_interaction",
+            "feature_ratio",
+            "fix_imbalance",
+        ]:
+            val = getattr(self, attr, None)
+            if val is not None:
+                self.setup_params[attr] = val
+        if getattr(self, "cross_validation_folds", None) is not None:
+            self.setup_params["fold"] = self.cross_validation_folds
         LOG.info(self.setup_params)

-        # Solution: instantiate the correct PyCaret experiment based on task_type
         if self.task_type == "classification":
             from pycaret.classification import ClassificationExperiment

             self.exp = ClassificationExperiment()
         elif self.task_type == "regression":
...
             self.exp = RegressionExperiment()
         else:
             raise ValueError("task_type must be 'classification' or 'regression'")

         self.exp.setup(self.data, **self.setup_params)
+        self.setup_params.update(self.user_kwargs)

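The attribute loop replaces nine hasattr ladders with one data-driven pass. In isolation the pattern reduces to the following sketch (toy config object standing in for the trainer, which copies kwargs onto itself):

    class Cfg:  # stand-in for the trainer instance
        train_size = 0.8
        normalize = True

    cfg = Cfg()
    setup_params = {}
    for attr in ["train_size", "normalize", "feature_selection"]:
        val = getattr(cfg, attr, None)  # None when the attribute was never set
        if val is not None:
            setup_params[attr] = val

    print(setup_params)  # {'train_size': 0.8, 'normalize': True}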
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
-            average_displayed = "Weighted"
             self.exp.add_metric(
-                id=f"PR-AUC-{average_displayed}",
-                name=f"PR-AUC-{average_displayed}",
+                id="PR-AUC-Weighted",
+                name="PR-AUC-Weighted",
                 target="pred_proba",
                 score_func=average_precision_score,
                 average="weighted",
             )
-
-        if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
-        else:
-            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
+        # Build arguments for compare_models()
+        compare_kwargs = {}
+        if getattr(self, "models", None):
+            compare_kwargs["include"] = self.models
+
+        # Respect explicit cross-validation flag
+        if getattr(self, "cross_validation", None) is not None:
+            compare_kwargs["cross_validation"] = self.cross_validation
+
+        # Respect explicit fold count
+        if getattr(self, "cross_validation_folds", None) is not None:
+            compare_kwargs["fold"] = self.cross_validation_folds
+
+        LOG.info(f"compare_models kwargs: {compare_kwargs}")
+        self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
+        if getattr(self, "tune_model", False):
+            LOG.info("Tuning hyperparameters of the best model")
+            self.best_model = self.exp.tune_model(self.best_model)
+            self.results = self.exp.pull()

         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
-
         _ = self.exp.predict_model(self.best_model)
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
             self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

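The custom metric registered above wraps scikit-learn's `average_precision_score` with `average="weighted"`. A quick sanity check of what that score function computes on its own (toy labels and probabilities):

    import numpy as np
    from sklearn.metrics import average_precision_score

    y_true = np.array([0, 1, 1, 0, 1])
    y_prob = np.array([0.2, 0.8, 0.6, 0.3, 0.9])  # predicted positive-class probabilities
    print(average_precision_score(y_true, y_prob, average="weighted"))  # 1.0: perfect ranking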
     def save_model(self):
-        hdf5_model_path = "pycaret_model.h5"
-        with h5py.File(hdf5_model_path, "w") as f:
-            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-                joblib.dump(self.best_model, temp_file.name)
-                temp_file.seek(0)
-                model_bytes = temp_file.read()
+        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
+        with h5py.File(hdf5_path, "w") as f:
+            with tempfile.NamedTemporaryFile(delete=False) as tmp:
+                joblib.dump(self.best_model, tmp.name)
+                tmp.seek(0)
+                model_bytes = tmp.read()
                 f.create_dataset("model", data=np.void(model_bytes))

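For reference, a model serialized this way can be read back by reversing the steps; a minimal sketch, assuming the single "model" dataset layout written above and the file sitting in the current directory:

    import io

    import h5py
    import joblib

    with h5py.File("pycaret_model.h5", "r") as f:
        model_bytes = f["model"][()].tobytes()  # np.void scalar back to raw bytes
    model = joblib.load(io.BytesIO(model_bytes))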
     def generate_plots(self):
-        raise NotImplementedError("Subclasses should implement this method")
+        LOG.info("Generating PyCaret diagnostic plots")

+        # choose the right plots based on task
+        if self.task_type == "classification":
+            plot_names = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+                "class_report",
+                "pr_auc",
+                "roc_auc",
+            ]
+        else:
+            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
+        for name in plot_names:
+            try:
+                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
+                out_path = Path(self.output_dir) / f"plot_{name}.png"
+                fig = ax.get_figure()
+                fig.savefig(out_path, bbox_inches="tight")
+                self.plots[name] = str(out_path)
+            except Exception as e:
+                LOG.warning(f"Could not generate {name} plot: {e}")
+
-    def encode_image_to_base64(self, img_path):
+    def encode_image_to_base64(self, img_path: str) -> str:
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")

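`encode_image_to_base64` feeds the data-URI `<img>` embedding used throughout the report. A minimal sketch of the round trip (the plot file name is hypothetical):

    import base64

    def encode_image_to_base64(img_path: str) -> str:
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    b64 = encode_image_to_base64("plot_learning.png")  # hypothetical saved plot
    img_tag = f'<img src="data:image/png;base64,{b64}" alt="learning curve"/>'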
     def save_html_report(self):
         LOG.info("Saving HTML report")

-        if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
-
-        model_name = type(self.best_model).__name__
-        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
-        filtered_setup_params = {
-            k: v for k, v in self.setup_params.items() if k not in excluded_params
+        # 1) Determine best model name
+        try:
+            best_model_name = str(self.results.iloc[0]["Model"])
+        except Exception:
+            best_model_name = type(self.best_model).__name__
+        LOG.info(f"Best model determined as: {best_model_name}")
+
+        # 2) Compute training sample count
+        try:
+            n_train = self.exp.X_train.shape[0]
+        except Exception:
+            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+        total_rows = self.data.shape[0]
+
+        # 3) Build setup parameters table
+        all_params = self.setup_params
+        display_keys = [
+            "Target",
+            "Session ID",
+            "Train Size",
+            "Normalize",
+            "Feature Selection",
+            "Cross Validation",
+            "Cross Validation Folds",
+            "Remove Outliers",
+            "Remove Multicollinearity",
+            "Polynomial Features",
+            "Fix Imbalance",
+            "Models",
+        ]
+        setup_rows = []
+        for key in display_keys:
+            pk = key.lower().replace(" ", "_")
+            v = all_params.get(pk)
+            if key == "Train Size":
+                frac = (
+                    float(v)
+                    if v is not None
+                    else (n_train / total_rows if total_rows else 0)
+                )
+                dv = f"{frac:.2f} ({n_train} rows)"
+            elif key in {
+                "Normalize",
+                "Feature Selection",
+                "Cross Validation",
+                "Remove Outliers",
+                "Remove Multicollinearity",
+                "Polynomial Features",
+                "Fix Imbalance",
+            }:
+                dv = bool(v)
+            elif key == "Cross Validation Folds":
+                dv = v if v is not None else "None"
+            elif key == "Models":
+                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+            else:
+                dv = v if v is not None else "None"
+            setup_rows.append([key, dv])
+        if hasattr(self.exp, "_fold_metric"):
+            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+
+        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
+        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+
+        # 4) Persist CSVs
+        self.results.to_csv(
+            Path(self.output_dir) / "comparison_results.csv", index=False
+        )
+        self.test_result_df.to_csv(
+            Path(self.output_dir) / "test_results.csv", index=False
+        )
+        pd.DataFrame(
+            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
+
+        # 5) Header
+        header = f"<h2>Best Model: {best_model_name}</h2>"
+
+        # — Validation Summary & Configuration —
+        val_df = self.results.copy()
+        # mapping raw plot keys to user-friendly titles
+        plot_title_map = {
+            "learning": "Learning Curve",
+            "vc": "Validation Curve",
+            "calibration": "Calibration Curve",
+            "dimension": "Dimensionality Reduction",
+            "manifold": "Manifold Learning",
+            "rfe": "Recursive Feature Elimination",
+            "threshold": "Threshold Plot",
+            "percentage_above_below": "Percentage Above vs. Below Cutoff",
+            "class_report": "Classification Report",
+            "pr_auc": "Precision-Recall AUC",
+            "roc_auc": "Receiver Operating Characteristic AUC",
+            "residuals": "Residuals Distribution",
+            "error": "Prediction Error Distribution",
         }
-        setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
+        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
+        summary_html = (
+            header
+            + "<h2>Train & Validation Summary</h2>"
+            + '<div class="table-wrapper">'
+            + val_df.to_html(index=False, classes="table sortable")
+            + "</div>"
+            + "<h2>Setup Parameters</h2>"
+            + '<div class="table-wrapper">'
+            + df_setup.to_html(index=False, classes="table sortable")
+            + "</div>"
+            # — Hyperparameters
+            + "<h2>Best Model Hyperparameters</h2>"
+            + '<div class="table-wrapper">'
+            + pd.DataFrame(
+                self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            ).to_html(index=False, classes="table sortable")
+            + "</div>"
         )

-        best_model_params = pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+        # choose summary plots based on task type
+        if self.task_type == "classification":
+            summary_plots = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+            ]
+        else:
+            summary_plots = ["learning", "vc", "parameter", "residuals"]
+
+        for name in summary_plots:
+            if name in self.plots:
+                summary_html += "<hr>"
+                b64 = encode_image_to_base64(self.plots[name])
+                title = plot_title_map.get(name, name.replace("_", " ").title())
+                summary_html += (
+                    '<div class="plot">'
+                    f"<h2>{title}</h2>"
+                    f'<img src="data:image/png;base64,{b64}" '
+                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
+                    "</div>"
+                )
+
+        # — Test Summary —
+        test_html = (
+            header
+            + '<div class="table-wrapper">'
+            + self.test_result_df.to_html(index=False, classes="table sortable")
+            + "</div>"
         )
-        best_model_params.to_csv(
-            os.path.join(self.output_dir, "best_model.csv"), index=False
-        )
-        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
-        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))
-
-        plots_html = ""
-        length = len(self.plots)
-        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
-            encoded_image = self.encode_image_to_base64(plot_path)
-            plots_html += (
-                f'<div class="plot">'
-                f"<h3>{plot_name.capitalize()}</h3>"
-                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
-                f"</div>"
-            )
-            if i < length - 1:
-                plots_html += "<hr>"
-
-        tree_plots = ""
-        for i, tree in enumerate(self.trees):
-            if tree:
-                tree_plots += (
-                    f'<div class="plot">'
-                    f"<h3>Tree {i + 1}</h3>"
-                    f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
-                    f"</div>"
-                )
-
-        analyzer = FeatureImportanceAnalyzer(
+        if self.task_type == "regression":
+            try:
+                y_true = (
+                    pd.Series(self.exp.y_test_transformed)
+                    .reset_index(drop=True)
+                    .rename("True")
+                )
+                y_pred = pd.Series(
+                    self.best_model.predict(self.exp.X_test_transformed)
+                ).rename("Predicted")
+                df_tp = pd.concat([y_true, y_pred], axis=1)
+                test_html += "<h2>True vs Predicted Values</h2>"
+                test_html += (
+                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
+                    + df_tp.head(50).to_html(index=False, classes="table sortable")
+                    + "</div>"
+                    + add_hr_to_html()
+                )
+            except Exception as e:
+                LOG.warning(f"Could not generate True vs Predicted table: {e}")
+
+        # 5a) Explainer-substituted plots in order
+        if self.task_type == "regression":
+            test_order = ["residuals"]
+        else:
+            test_order = [
+                "confusion_matrix",
+                "roc_auc",
+                "pr_auc",
+                "lift_curve",
+                "threshold",
+                "cumulative_precision",
+            ]
+        for key in test_order:
+            fig_or_fn = self.explainer_plots.pop(key, None)
+            if fig_or_fn is not None:
+                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                title = plot_title_map.get(key, key.replace("_", " ").title())
+                test_html += (
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                )
+
+        # 5b) Remaining PyCaret test plots
+        for name, path in self.plots.items():
+            # classification: include only the small extras, before skipping anything
+            if self.task_type == "classification" and name in {
+                "threshold",
+                "pr_auc",
+                "class_report",
+            }:
+                title = plot_title_map.get(name, name.replace("_", " ").title())
+                b64 = encode_image_to_base64(path)
+                test_html += (
+                    f"<h2>{title}</h2>"
+                    "<div class='plot'>"
+                    f"<img src='data:image/png;base64,{b64}' "
+                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "</div>" + add_hr_to_html()
+                )
+                continue
+
+            # regression: explicitly include the 'error' plot, before skipping
+            if self.task_type == "regression" and name == "error":
+                title = plot_title_map.get("error", "Prediction Error Distribution")
+                b64 = encode_image_to_base64(path)
+                test_html += (
+                    f"<h2>{title}</h2>"
+                    "<div class='plot'>"
+                    f"<img src='data:image/png;base64,{b64}' "
+                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "</div>" + add_hr_to_html()
+                )
+                continue
+
+            # now skip any plots already rendered via test_order
+            if name in test_order:
+                continue
+
+        # — Feature Importance —
+        feature_html = header
+
+        # 6a) PyCaret's default feature importances
+        feature_html += FeatureImportanceAnalyzer(
             data=self.data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
             exp=self.exp,
             best_model=self.best_model,
+        ).run()
+
+        # 6b) Explainer SHAP importances
+        for key in ["shap_mean", "shap_perm"]:
+            fig_or_fn = self.explainer_plots.pop(key, None)
+            if fig_or_fn is not None:
+                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                # give SHAP plots explicit titles
+                title = (
+                    "Mean Absolute SHAP Value Impact"
+                    if key == "shap_mean"
+                    else "Permutation Feature Importance"
+                )
+                feature_html += (
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                )
+
+        # 6c) PDPs last
+        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
+        for k in pdp_keys:
+            fig_or_fn = self.explainer_plots[k]
+            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+            # extract feature name
+            feature = k.split("__", 1)[1]
+            title = f"Partial Dependence for {feature}"
+            feature_html += (
+                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+            )
+
+        # 7) Assemble final HTML (three tabs)
+        html = get_html_template()
+        html += "<h1>Tabular Learner Model Report</h1>"
+        html += build_tabbed_html(summary_html, test_html, feature_html)
+        html += get_feature_metrics_help_modal()
+        html += get_html_closing()
+
+        # 8) Write out
+        (Path(self.output_dir) / "comparison_result.html").write_text(
+            html, encoding="utf-8"
         )
-        feature_importance_html = analyzer.run()
-
+        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
-        # --- Feature Metrics Help Button ---
-        feature_metrics_button_html = (
-            '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
-            "Help: Metrics Guide"
-            "</button>"
-            "<style>"
-            ".help-modal-btn {"
-            "background-color: #17623b;"
-            "color: #fff;"
-            "border: none;"
-            "border-radius: 24px;"
-            "padding: 10px 28px;"
-            "font-size: 1.1rem;"
-            "font-weight: bold;"
-            "letter-spacing: 0.03em;"
-            "cursor: pointer;"
-            "transition: background 0.2s, box-shadow 0.2s;"
-            "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
-            "}"
-            ".help-modal-btn:hover, .help-modal-btn:focus {"
-            "background-color: #21895e;"
-            "outline: none;"
-            "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
-            "}"
-            "</style>"
-        )
-
-        html_content = (
-            f"{get_html_template()}"
-            "<h1>Tabular Learner Model Report</h1>"
-            f"{feature_metrics_button_html}"
-            '<div class="tabs">'
-            '<div class="tab" onclick="openTab(event, \'summary\')">'
-            "Validation Result Summary & Config</div>"
-            '<div class="tab" onclick="openTab(event, \'plots\')">'
-            "Test Results</div>"
-            '<div class="tab" onclick="openTab(event, \'feature\')">'
-            "Feature Importance</div>"
-        )
-        if self.plots_explainer_html:
-            html_content += (
-                '<div class="tab" onclick="openTab(event, \'explainer\')">'
-                "Explainer Plots</div>"
-            )
-        html_content += (
-            "</div>"
-            '<div id="summary" class="tab-content">'
-            f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
-            f"<h2>Best Model: {model_name}</h2>"
-            "<h5>The best model is selected by: Accuracy (Classification)"
-            " or R2 (Regression).</h5>"
-            f"{self.results.to_html(index=False, classes='table sortable')}"
-            "<h2>Best Model's Hyperparameters</h2>"
-            f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}"
-            "<h2>Setup Parameters</h2>"
-            f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}"
-            "<h5>If you want to know all the experiment setup parameters,"
-            " please check the PyCaret documentation for"
-            " the classification/regression <code>exp</code> function.</h5>"
-            "</div>"
-            '<div id="plots" class="tab-content">'
-            f"<h2>Best Model: {model_name}</h2>"
-            "<h5>The best model is selected by: Accuracy (Classification)"
-            " or R2 (Regression).</h5>"
-            "<h2>Test Metrics</h2>"
-            f"{self.test_result_df.to_html(index=False)}"
-            "<h2>Test Results</h2>"
-            f"{plots_html}"
-            "</div>"
-            '<div id="feature" class="tab-content">'
-            f"{feature_importance_html}"
-            "</div>"
-        )
-        if self.plots_explainer_html:
-            html_content += (
-                '<div id="explainer" class="tab-content">'
-                f"{self.plots_explainer_html}"
-                f"{tree_plots}"
-                "</div>"
-            )
-        html_content += (
-            "<script>"
-            "document.addEventListener(\"DOMContentLoaded\", function() {"
-            "var tables = document.querySelectorAll(\"table.sortable\");"
-            "tables.forEach(function(table) {"
-            "var headers = table.querySelectorAll(\"th\");"
-            "headers.forEach(function(header, index) {"
-            "header.style.cursor = \"pointer\";"
-            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)"
-            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';"
-            "header.addEventListener(\"click\", function() {"
-            "var direction = this.getAttribute("
-            "\"data-sort-direction\""
-            ") || \"asc\";"
-            "// Reset arrows in all headers of this table"
-            "headers.forEach(function(h) {"
-            "var arrow = h.querySelector(\".sort-arrow\");"
-            "if (arrow) arrow.textContent = \" ↑\";"
-            "});"
-            "// Set arrow for clicked header"
-            "var arrow = this.querySelector(\".sort-arrow\");"
-            "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";"
-            "sortTable(table, index, direction);"
-            "this.setAttribute(\"data-sort-direction\","
-            "direction === \"asc\" ? \"desc\" : \"asc\");"
-            "});"
-            "});"
-            "});"
-            "});"
-            "function sortTable(table, colNum, direction) {"
-            "var tb = table.tBodies[0];"
-            "var tr = Array.prototype.slice.call(tb.rows, 0);"
-            "var multiplier = direction === \"asc\" ? 1 : -1;"
-            "tr = tr.sort(function(a, b) {"
-            "var aText = a.cells[colNum].textContent.trim();"
-            "var bText = b.cells[colNum].textContent.trim();"
-            "// Remove arrow from text comparison"
-            "aText = aText.replace(/[↑↓]/g, '').trim();"
-            "bText = bText.replace(/[↑↓]/g, '').trim();"
-            "if (!isNaN(aText) && !isNaN(bText)) {"
-            "return multiplier * ("
-            "parseFloat(aText) - parseFloat(bText)"
-            ");"
-            "} else {"
-            "return multiplier * aText.localeCompare(bText);"
-            "}"
-            "});"
-            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);"
-            "}"
-            "</script>"
-        )
-        # --- Add the Feature Metrics Help Modal ---
-        html_content += get_feature_metrics_help_modal()
-        html_content += f"{get_html_closing()}"
-        with open(
-            os.path.join(self.output_dir, "comparison_result.html"),
-            "w",
-            encoding="utf-8",
-        ) as file:
-            file.write(html_content)

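Both the old and new report builders lean on pandas for every table: `DataFrame.to_html(classes="table sortable")` emits a `<table>` whose class list the report's JavaScript sorter keys on. A minimal sketch of that call:

    import pandas as pd

    df = pd.DataFrame({"Model": ["rf", "xgboost"], "R2": [0.91, 0.89]})
    html_table = df.to_html(index=False, classes="table sortable")
    # produces e.g. '<table ... class="dataframe table sortable">...</table>'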
     def save_dashboard(self):
         raise NotImplementedError("Subclasses should implement this method")

     def generate_plots_explainer(self):
...

         LOG.info("Generating tree plots")
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed

-        is_rf = isinstance(
-            self.best_model, (RandomForestClassifier, RandomForestRegressor)
-        )
-        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
-
-        num_trees = None
-        if is_rf:
-            num_trees = self.best_model.n_estimators
-        elif is_xgb:
-            num_trees = len(self.best_model.get_booster().get_dump())
+        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+            n_trees = self.best_model.n_estimators
+        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
+            n_trees = len(self.best_model.get_booster().get_dump())
         else:
             LOG.warning("Tree plots not supported for this model type.")
             return

-        try:
-            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
-            for i in range(num_trees):
-                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
-                LOG.info(f"Tree {i + 1}")
-                LOG.info(fig)
-                self.trees.append(fig)
-        except Exception as e:
-            LOG.error(f"Error generating tree plots: {e}")
+        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
+        for i in range(n_trees):
+            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
+            self.trees.append(fig)

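The tree count is model-specific: scikit-learn forests expose `n_estimators` directly, while XGBoost models report one dump string per boosted tree. A standalone sketch of the scikit-learn side (toy data):

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    X, y = make_classification(n_samples=100, random_state=0)
    rf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
    print(rf.n_estimators)      # 10
    print(len(rf.estimators_))  # 10 fitted trees
    # XGBoost analogue: len(model.get_booster().get_dump())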
     def run(self):
         self.load_data()
         self.setup_pycaret()
         self.train_model()