comparison base_model_trainer.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author: goeckslab
date: Fri, 25 Jul 2025 19:02:32 +0000
parents: f4cb41f458fd
children: (none)
7:f4cb41f458fd | 8:1aed7d47c5ec |
1 import base64 | 1 import base64 |
2 import logging | 2 import logging |
3 import os | |
4 import tempfile | 3 import tempfile |
4 from pathlib import Path | |
5 | 5 |
6 import h5py | 6 import h5py |
7 import joblib | 7 import joblib |
8 import numpy as np | 8 import numpy as np |
9 import pandas as pd | 9 import pandas as pd |
10 from feature_help_modal import get_feature_metrics_help_modal | 10 from feature_help_modal import get_feature_metrics_help_modal |
11 from feature_importance import FeatureImportanceAnalyzer | 11 from feature_importance import FeatureImportanceAnalyzer |
12 from sklearn.metrics import average_precision_score | 12 from sklearn.metrics import average_precision_score |
13 from utils import get_html_closing, get_html_template | 13 from utils import ( |
14 add_hr_to_html, | |
15 add_plot_to_html, | |
16 build_tabbed_html, | |
17 encode_image_to_base64, | |
18 get_html_closing, | |
19 get_html_template, | |
20 ) | |
14 | 21 |
15 logging.basicConfig(level=logging.DEBUG) | 22 logging.basicConfig(level=logging.DEBUG) |
16 LOG = logging.getLogger(__name__) | 23 LOG = logging.getLogger(__name__) |
17 | 24 |
18 | 25 |
25 task_type, | 32 task_type, |
26 random_seed, | 33 random_seed, |
27 test_file=None, | 34 test_file=None, |
28 **kwargs, | 35 **kwargs, |
29 ): | 36 ): |
30 self.exp = None # This will be set in the subclass | 37 self.exp = None |
31 self.input_file = input_file | 38 self.input_file = input_file |
32 self.target_col = target_col | 39 self.target_col = target_col |
33 self.output_dir = output_dir | 40 self.output_dir = output_dir |
34 self.task_type = task_type | 41 self.task_type = task_type |
35 self.random_seed = random_seed | 42 self.random_seed = random_seed |
37 self.target = None | 44 self.target = None |
38 self.best_model = None | 45 self.best_model = None |
39 self.results = None | 46 self.results = None |
40 self.features_name = None | 47 self.features_name = None |
41 self.plots = {} | 48 self.plots = {} |
42 self.expaliner = None | 49 self.explainer_plots = {} |
43 self.plots_explainer_html = None | 50 self.plots_explainer_html = None |
44 self.trees = [] | 51 self.trees = [] |
45 for key, value in kwargs.items(): | 52 self.user_kwargs = kwargs.copy() |
53 for key, value in self.user_kwargs.items(): | |
46 setattr(self, key, value) | 54 setattr(self, key, value) |
47 self.setup_params = {} | 55 self.setup_params = {} |
48 self.test_file = test_file | 56 self.test_file = test_file |
49 self.test_data = None | 57 self.test_data = None |
50 | 58 |
55 | 63 |
56 def load_data(self): | 64 def load_data(self): |
57 LOG.info(f"Loading data from {self.input_file}") | 65 LOG.info(f"Loading data from {self.input_file}") |
58 self.data = pd.read_csv(self.input_file, sep=None, engine="python") | 66 self.data = pd.read_csv(self.input_file, sep=None, engine="python") |
59 self.data.columns = self.data.columns.str.replace(".", "_") | 67 self.data.columns = self.data.columns.str.replace(".", "_") |
60 | |
61 # Remove prediction_label if present | |
62 if "prediction_label" in self.data.columns: | 68 if "prediction_label" in self.data.columns: |
63 self.data = self.data.drop(columns=["prediction_label"]) | 69 self.data = self.data.drop(columns=["prediction_label"]) |
64 | 70 |
65 numeric_cols = self.data.select_dtypes(include=["number"]).columns | 71 numeric_cols = self.data.select_dtypes(include=["number"]).columns |
66 non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns | 72 non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns |
67 | |
68 self.data[numeric_cols] = self.data[numeric_cols].apply( | 73 self.data[numeric_cols] = self.data[numeric_cols].apply( |
69 pd.to_numeric, errors="coerce" | 74 pd.to_numeric, errors="coerce" |
70 ) | 75 ) |
71 | |
72 if len(non_numeric_cols) > 0: | 76 if len(non_numeric_cols) > 0: |
73 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | 77 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") |
74 | 78 |
75 names = self.data.columns.to_list() | 79 names = self.data.columns.to_list() |
76 target_index = int(self.target_col) - 1 | 80 target_index = int(self.target_col) - 1 |
77 self.target = names[target_index] | 81 self.target = names[target_index] |
78 self.features_name = [name for i, name in enumerate(names) if i != target_index] | 82 self.features_name = [n for i, n in enumerate(names) if i != target_index] |
79 if hasattr(self, "missing_value_strategy"): | 83 |
80 if self.missing_value_strategy == "mean": | 84 if getattr(self, "missing_value_strategy", None): |
85 strat = self.missing_value_strategy | |
86 if strat == "mean": | |
81 self.data = self.data.fillna(self.data.mean(numeric_only=True)) | 87 self.data = self.data.fillna(self.data.mean(numeric_only=True)) |
82 elif self.missing_value_strategy == "median": | 88 elif strat == "median": |
83 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 89 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
84 elif self.missing_value_strategy == "drop": | 90 elif strat == "drop": |
85 self.data = self.data.dropna() | 91 self.data = self.data.dropna() |
86 else: | 92 else: |
87 # Default strategy if not specified | |
88 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 93 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
89 | 94 |
90 if self.test_file: | 95 if self.test_file: |
91 LOG.info(f"Loading test data from {self.test_file}") | 96 LOG.info(f"Loading test data from {self.test_file}") |
92 self.test_data = pd.read_csv(self.test_file, sep=None, engine="python") | 97 df_test = pd.read_csv(self.test_file, sep=None, engine="python") |
93 self.test_data = self.test_data[numeric_cols].apply( | 98 df_test.columns = df_test.columns.str.replace(".", "_") |
94 pd.to_numeric, errors="coerce" | 99 self.test_data = df_test |
95 ) | |
96 self.test_data.columns = self.test_data.columns.str.replace(".", "_") | |
97 | 100 |
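The `sep=None, engine="python"` combination in `load_data` asks pandas to sniff the delimiter via `csv.Sniffer`, which is why the loader accepts comma- and tab-separated inputs alike. A minimal standalone sketch of that behavior, using a hypothetical in-memory TSV:

```python
# Minimal sketch: sep=None with the python engine lets pandas infer the
# delimiter, so tab- and comma-separated files both load the same way.
import io
import pandas as pd

tsv = io.StringIO("a\tb\ttarget\n1\t2\t0\n3\t4\t1\n")
df = pd.read_csv(tsv, sep=None, engine="python")
print(df.columns.tolist())  # ['a', 'b', 'target']
```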
98 def setup_pycaret(self): | 101 def setup_pycaret(self): |
99 LOG.info("Initializing PyCaret") | 102 LOG.info("Initializing PyCaret") |
100 self.setup_params = { | 103 self.setup_params = { |
101 "target": self.target, | 104 "target": self.target, |
103 "html": True, | 106 "html": True, |
104 "log_experiment": False, | 107 "log_experiment": False, |
105 "system_log": False, | 108 "system_log": False, |
106 "index": False, | 109 "index": False, |
107 } | 110 } |
108 | |
109 if self.test_data is not None: | 111 if self.test_data is not None: |
110 self.setup_params["test_data"] = self.test_data | 112 self.setup_params["test_data"] = self.test_data |
111 | 113 for attr in [ |
112 if ( | 114 "train_size", |
113 hasattr(self, "train_size") | 115 "normalize", |
114 and self.train_size is not None | 116 "feature_selection", |
115 and self.test_data is None | 117 "remove_outliers", |
116 ): | 118 "remove_multicollinearity", |
117 self.setup_params["train_size"] = self.train_size | 119 "polynomial_features", |
118 | 120 "feature_interaction", |
119 if hasattr(self, "normalize") and self.normalize is not None: | 121 "feature_ratio", |
120 self.setup_params["normalize"] = self.normalize | 122 "fix_imbalance", |
121 | 123 ]: |
122 if hasattr(self, "feature_selection") and self.feature_selection is not None: | 124 val = getattr(self, attr, None) |
123 self.setup_params["feature_selection"] = self.feature_selection | 125 if val is not None: |
124 | 126 self.setup_params[attr] = val |
125 if ( | 127 if getattr(self, "cross_validation_folds", None) is not None: |
126 hasattr(self, "cross_validation") | 128 self.setup_params["fold"] = self.cross_validation_folds |
127 and self.cross_validation is not None | |
128 and self.cross_validation is False | |
129 ): | |
130 logging.info( | |
131 "cross_validation is set to False. This will disable cross-validation." | |
132 ) | |
133 | |
134 if hasattr(self, "cross_validation") and self.cross_validation: | |
135 if hasattr(self, "cross_validation_folds"): | |
136 self.setup_params["fold"] = self.cross_validation_folds | |
137 | |
138 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: | |
139 self.setup_params["remove_outliers"] = self.remove_outliers | |
140 | |
141 if ( | |
142 hasattr(self, "remove_multicollinearity") | |
143 and self.remove_multicollinearity is not None | |
144 ): | |
145 self.setup_params["remove_multicollinearity"] = ( | |
146 self.remove_multicollinearity | |
147 ) | |
148 | |
149 if ( | |
150 hasattr(self, "polynomial_features") | |
151 and self.polynomial_features is not None | |
152 ): | |
153 self.setup_params["polynomial_features"] = self.polynomial_features | |
154 | |
155 if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None: | |
156 self.setup_params["fix_imbalance"] = self.fix_imbalance | |
157 | |
158 LOG.info(self.setup_params) | 129 LOG.info(self.setup_params) |
159 | 130 |
160 # Solution: instantiate the correct PyCaret experiment based on task_type | |
161 if self.task_type == "classification": | 131 if self.task_type == "classification": |
162 from pycaret.classification import ClassificationExperiment | 132 from pycaret.classification import ClassificationExperiment |
163 | 133 |
164 self.exp = ClassificationExperiment() | 134 self.exp = ClassificationExperiment() |
165 elif self.task_type == "regression": | 135 elif self.task_type == "regression": |
168 self.exp = RegressionExperiment() | 138 self.exp = RegressionExperiment() |
169 else: | 139 else: |
170 raise ValueError("task_type must be 'classification' or 'regression'") | 140 raise ValueError("task_type must be 'classification' or 'regression'") |
171 | 141 |
172 self.exp.setup(self.data, **self.setup_params) | 142 self.exp.setup(self.data, **self.setup_params) |
143 self.setup_params.update(self.user_kwargs) | |
173 | 144 |
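The attribute-forwarding loop in the new `setup_pycaret` includes an option only when the user actually set it, so `None` values never reach PyCaret. For illustration, a hypothetical classification run where the user passed `normalize=True` and `cross_validation_folds=5` might end up with a dict like the following (the values are assumptions, not output from this changeset):

```python
# Hypothetical result of the forwarding loop above; all other optional
# attributes were None and are therefore omitted from setup_params.
setup_params = {
    "target": "diagnosis",   # placeholder target column name
    "html": True,
    "log_experiment": False,
    "system_log": False,
    "index": False,
    "normalize": True,
    "fold": 5,
}
```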
174 def train_model(self): | 145 def train_model(self): |
175 LOG.info("Training and selecting the best model") | 146 LOG.info("Training and selecting the best model") |
176 if self.task_type == "classification": | 147 if self.task_type == "classification": |
177 average_displayed = "Weighted" | |
178 self.exp.add_metric( | 148 self.exp.add_metric( |
179 id=f"PR-AUC-{average_displayed}", | 149 id="PR-AUC-Weighted", |
180 name=f"PR-AUC-{average_displayed}", | 150 name="PR-AUC-Weighted", |
181 target="pred_proba", | 151 target="pred_proba", |
182 score_func=average_precision_score, | 152 score_func=average_precision_score, |
183 average="weighted", | 153 average="weighted", |
184 ) | 154 ) |
185 | 155 # Build arguments for compare_models() |
186 if hasattr(self, "models") and self.models is not None: | 156 compare_kwargs = {} |
187 self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation) | 157 if getattr(self, "models", None): |
188 else: | 158 compare_kwargs["include"] = self.models |
189 self.best_model = self.exp.compare_models(cross_validation=self.cross_validation) | 159 |
160 # Respect explicit cross-validation flag | |
161 if getattr(self, "cross_validation", None) is not None: | |
162 compare_kwargs["cross_validation"] = self.cross_validation | |
163 | |
164 # Respect explicit fold count | |
165 if getattr(self, "cross_validation_folds", None) is not None: | |
166 compare_kwargs["fold"] = self.cross_validation_folds | |
167 | |
168 LOG.info(f"compare_models kwargs: {compare_kwargs}") | |
169 self.best_model = self.exp.compare_models(**compare_kwargs) | |
190 self.results = self.exp.pull() | 170 self.results = self.exp.pull() |
171 if getattr(self, "tune_model", False): | |
172 LOG.info("Tuning hyperparameters of the best model") | |
173 self.best_model = self.exp.tune_model(self.best_model) | |
174 self.results = self.exp.pull() | |
191 | 175 |
192 if self.task_type == "classification": | 176 if self.task_type == "classification": |
193 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
194 | |
195 _ = self.exp.predict_model(self.best_model) | 178 _ = self.exp.predict_model(self.best_model) |
196 self.test_result_df = self.exp.pull() | 179 self.test_result_df = self.exp.pull() |
197 if self.task_type == "classification": | 180 if self.task_type == "classification": |
198 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 181 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
199 | 182 |
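The `add_metric` call at the top of `train_model` registers sklearn's `average_precision_score` as a weighted PR-AUC column in the comparison table. A self-contained sketch of what that scorer returns on toy data:

```python
# Standalone sketch of the scorer registered via exp.add_metric() above.
import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([0, 0, 1, 1])            # ground-truth labels
y_score = np.array([0.1, 0.4, 0.35, 0.8])  # predicted probability of class 1
print(average_precision_score(y_true, y_score, average="weighted"))  # ~0.83
```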
200 def save_model(self): | 183 def save_model(self): |
201 hdf5_model_path = "pycaret_model.h5" | 184 hdf5_path = Path(self.output_dir) / "pycaret_model.h5" |
202 with h5py.File(hdf5_model_path, "w") as f: | 185 with h5py.File(hdf5_path, "w") as f: |
203 with tempfile.NamedTemporaryFile(delete=False) as temp_file: | 186 with tempfile.NamedTemporaryFile(delete=False) as tmp: |
204 joblib.dump(self.best_model, temp_file.name) | 187 joblib.dump(self.best_model, tmp.name) |
205 temp_file.seek(0) | 188 tmp.seek(0) |
206 model_bytes = temp_file.read() | 189 model_bytes = tmp.read() |
207 f.create_dataset("model", data=np.void(model_bytes)) | 190 f.create_dataset("model", data=np.void(model_bytes)) |
208 | 191 |
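`save_model` embeds the joblib-pickled estimator as an opaque byte blob inside an HDF5 dataset. A minimal sketch of the reverse operation (not part of this changeset) that restores the estimator from the file written above:

```python
# Minimal sketch: read back the "model" dataset written by save_model().
# f["model"][()] returns an np.void scalar; .tobytes() recovers the raw
# joblib pickle, which joblib.load() accepts from any file-like object.
import io
import h5py
import joblib

with h5py.File("pycaret_model.h5", "r") as f:
    model_bytes = f["model"][()].tobytes()
best_model = joblib.load(io.BytesIO(model_bytes))
```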
209 def generate_plots(self): | 192 def generate_plots(self): |
210 raise NotImplementedError("Subclasses should implement this method") | 193 LOG.info("Generating PyCaret diagnostic plots") |
211 | 194 |
212 def encode_image_to_base64(self, img_path): | 195 # choose the right plots based on task |
196 if self.task_type == "classification": | |
197 plot_names = [ | |
198 "learning", | |
199 "vc", | |
200 "calibration", | |
201 "dimension", | |
202 "manifold", | |
203 "rfe", | |
204 "threshold", | |
205 "percentage_above_below", | |
206 "class_report", | |
207 "pr_auc", | |
208 "roc_auc", | |
209 ] | |
210 else: | |
211 plot_names = ["residuals", "vc", "parameter", "error", "learning"] | |
212 for name in plot_names: | |
213 try: | |
214 ax = self.exp.plot_model(self.best_model, plot=name, save=False) | |
215 out_path = Path(self.output_dir) / f"plot_{name}.png" | |
216 fig = ax.get_figure() | |
217 fig.savefig(out_path, bbox_inches="tight") | |
218 self.plots[name] = str(out_path) | |
219 except Exception as e: | |
220 LOG.warning(f"Could not generate {name} plot: {e}") | |
221 | |
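The plot loop above assumes `plot_model(..., save=False)` returns a Matplotlib object exposing `get_figure()`, which varies across PyCaret releases. For reference, a sketch of the alternative `save=True` path, assuming PyCaret 3.x where the call writes the figure to the working directory and returns its file path (an assumption about the installed version, not code from this changeset):

```python
# Alternative sketch, assuming PyCaret 3.x plot_model(save=True) returns
# the path of the PNG it wrote; the file is then moved into output_dir.
import shutil
from pathlib import Path

saved_path = exp.plot_model(best_model, plot="learning", save=True)
shutil.move(saved_path, Path(output_dir) / "plot_learning.png")
```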
222 def encode_image_to_base64(self, img_path: str) -> str: | |
213 with open(img_path, "rb") as img_file: | 223 with open(img_path, "rb") as img_file: |
214 return base64.b64encode(img_file.read()).decode("utf-8") | 224 return base64.b64encode(img_file.read()).decode("utf-8") |
215 | 225 |
216 def save_html_report(self): | 226 def save_html_report(self): |
217 LOG.info("Saving HTML report") | 227 LOG.info("Saving HTML report") |
218 | 228 |
219 if not self.output_dir: | 229 # 1) Determine best model name |
220 raise ValueError("output_dir must be specified and not None") | 230 try: |
221 | 231 best_model_name = str(self.results.iloc[0]["Model"]) |
222 model_name = type(self.best_model).__name__ | 232 except Exception: |
223 excluded_params = ["html", "log_experiment", "system_log", "test_data"] | 233 best_model_name = type(self.best_model).__name__ |
224 filtered_setup_params = { | 234 LOG.info(f"Best model determined as: {best_model_name}") |
225 k: v for k, v in self.setup_params.items() if k not in excluded_params | 235 |
236 # 2) Compute training sample count | |
237 try: | |
238 n_train = self.exp.X_train.shape[0] | |
239 except Exception: | |
240 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] | |
241 total_rows = self.data.shape[0] | |
242 | |
243 # 3) Build setup parameters table | |
244 all_params = self.setup_params | |
245 display_keys = [ | |
246 "Target", | |
247 "Session ID", | |
248 "Train Size", | |
249 "Normalize", | |
250 "Feature Selection", | |
251 "Cross Validation", | |
252 "Cross Validation Folds", | |
253 "Remove Outliers", | |
254 "Remove Multicollinearity", | |
255 "Polynomial Features", | |
256 "Fix Imbalance", | |
257 "Models", | |
258 ] | |
259 setup_rows = [] | |
260 for key in display_keys: | |
261 pk = key.lower().replace(" ", "_") | |
262 v = all_params.get(pk) | |
263 if key == "Train Size": | |
264 frac = ( | |
265 float(v) | |
266 if v is not None | |
267 else (n_train / total_rows if total_rows else 0) | |
268 ) | |
269 dv = f"{frac:.2f} ({n_train} rows)" | |
270 elif key in { | |
271 "Normalize", | |
272 "Feature Selection", | |
273 "Cross Validation", | |
274 "Remove Outliers", | |
275 "Remove Multicollinearity", | |
276 "Polynomial Features", | |
277 "Fix Imbalance", | |
278 }: | |
279 dv = bool(v) | |
280 elif key == "Cross Validation Folds": | |
281 dv = v if v is not None else "None" | |
282 elif key == "Models": | |
283 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" | |
284 else: | |
285 dv = v if v is not None else "None" | |
286 setup_rows.append([key, dv]) | |
287 if hasattr(self.exp, "_fold_metric"): | |
288 setup_rows.append(["best_model_metric", self.exp._fold_metric]) | |
289 | |
290 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) | |
291 df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False) | |
292 | |
293 # 4) Persist CSVs | |
294 self.results.to_csv( | |
295 Path(self.output_dir) / "comparison_results.csv", index=False | |
296 ) | |
297 self.test_result_df.to_csv( | |
298 Path(self.output_dir) / "test_results.csv", index=False | |
299 ) | |
300 pd.DataFrame( | |
301 self.best_model.get_params().items(), columns=["Parameter", "Value"] | |
302 ).to_csv(Path(self.output_dir) / "best_model.csv", index=False) | |
303 | |
304 # 5) Header | |
305 header = f"<h2>Best Model: {best_model_name}</h2>" | |
306 | |
307 # — Validation Summary & Configuration — | |
308 val_df = self.results.copy() | |
309 # mapping raw plot keys to user-friendly titles | |
310 plot_title_map = { | |
311 "learning": "Learning Curve", | |
312 "vc": "Validation Curve", | |
313 "calibration": "Calibration Curve", | |
314 "dimension": "Dimensionality Reduction", | |
315 "manifold": "Manifold Learning", | |
316 "rfe": "Recursive Feature Elimination", | |
317 "threshold": "Threshold Plot", | |
318 "percentage_above_below": "Percentage Above vs. Below Cutoff", | |
319 "class_report": "Classification Report", | |
320 "pr_auc": "Precision-Recall AUC", | |
321 "roc_auc": "Receiver Operating Characteristic AUC", | |
322 "residuals": "Residuals Distribution", | |
323 "error": "Prediction Error Distribution", | |
226 } | 324 } |
227 setup_params_table = pd.DataFrame( | 325 val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True) |
228 list(filtered_setup_params.items()), columns=["Parameter", "Value"] | 326 summary_html = ( |
327 header | |
328 + "<h2>Train & Validation Summary</h2>" | |
329 + '<div class="table-wrapper">' | |
330 + val_df.to_html(index=False, classes="table sortable") | |
331 + "</div>" | |
332 + "<h2>Setup Parameters</h2>" | |
333 + '<div class="table-wrapper">' | |
334 + df_setup.to_html(index=False, classes="table sortable") | |
335 + "</div>" | |
336 # — Hyperparameters | |
337 + "<h2>Best Model Hyperparameters</h2>" | |
338 + '<div class="table-wrapper">' | |
339 + pd.DataFrame( | |
340 self.best_model.get_params().items(), columns=["Parameter", "Value"] | |
341 ).to_html(index=False, classes="table sortable") | |
342 + "</div>" | |
229 ) | 343 ) |
230 | 344 |
231 best_model_params = pd.DataFrame( | 345 # choose summary plots based on task type |
232 self.best_model.get_params().items(), columns=["Parameter", "Value"] | 346 if self.task_type == "classification": |
347 summary_plots = [ | |
348 "learning", | |
349 "vc", | |
350 "calibration", | |
351 "dimension", | |
352 "manifold", | |
353 "rfe", | |
354 "threshold", | |
355 "percentage_above_below", | |
356 ] | |
357 else: | |
358 summary_plots = ["learning", "vc", "parameter", "residuals"] | |
359 | |
360 for name in summary_plots: | |
361 if name in self.plots: | |
362 summary_html += "<hr>" | |
363 b64 = encode_image_to_base64(self.plots[name]) | |
364 title = plot_title_map.get(name, name.replace("_", " ").title()) | |
365 summary_html += ( | |
366 '<div class="plot">' | |
367 f"<h2>{title}</h2>" | |
368 f'<img src="data:image/png;base64,{b64}" ' | |
369 'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>' | |
370 "</div>" | |
371 ) | |
372 | |
373 # — Test Summary — | |
374 test_html = ( | |
375 header | |
376 + '<div class="table-wrapper">' | |
377 + self.test_result_df.to_html(index=False, classes="table sortable") | |
378 + "</div>" | |
233 ) | 379 ) |
234 best_model_params.to_csv( | 380 if self.task_type == "regression": |
235 os.path.join(self.output_dir, "best_model.csv"), index=False | 381 try: |
236 ) | 382 y_true = ( |
237 self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv")) | 383 pd.Series(self.exp.y_test_transformed) |
238 self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv")) | 384 .reset_index(drop=True) |
239 | 385 .rename("True") |
240 plots_html = "" | 386 ) |
241 length = len(self.plots) | 387 y_pred = pd.Series( |
242 for i, (plot_name, plot_path) in enumerate(self.plots.items()): | 388 self.best_model.predict(self.exp.X_test_transformed) |
243 encoded_image = self.encode_image_to_base64(plot_path) | 389 ).rename("Predicted") |
244 plots_html += ( | 390 df_tp = pd.concat([y_true, y_pred], axis=1) |
245 f'<div class="plot">' | 391 test_html += "<h2>True vs Predicted Values</h2>" |
246 f"<h3>{plot_name.capitalize()}</h3>" | 392 test_html += ( |
247 f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">' | 393 '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">' |
248 f"</div>" | 394 + df_tp.head(50).to_html(index=False, classes="table sortable") |
249 ) | 395 + "</div>" |
250 if i < length - 1: | 396 + add_hr_to_html() |
251 plots_html += "<hr>" | 397 ) |
252 | 398 except Exception as e: |
253 tree_plots = "" | 399 LOG.warning(f"Could not generate True vs Predicted table: {e}") |
254 for i, tree in enumerate(self.trees): | 400 |
255 if tree: | 401 # 5a) Explainer-substituted plots in order |
256 tree_plots += ( | 402 if self.task_type == "regression": |
257 f'<div class="plot">' | 403 test_order = ["residuals"] |
258 f"<h3>Tree {i + 1}</h3>" | 404 else: |
259 f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">' | 405 test_order = [ |
260 f"</div>" | 406 "confusion_matrix", |
261 ) | 407 "roc_auc", |
262 | 408 "pr_auc", |
263 analyzer = FeatureImportanceAnalyzer( | 409 "lift_curve", |
410 "threshold", | |
411 "cumulative_precision", | |
412 ] | |
413 for key in test_order: | |
414 fig_or_fn = self.explainer_plots.pop(key, None) | |
415 if fig_or_fn is not None: | |
416 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
417 title = plot_title_map.get(key, key.replace("_", " ").title()) | |
418 test_html += ( | |
419 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | |
420 ) | |
421 # 5b) Remaining PyCaret test plots | |
422 for name, path in self.plots.items(): | |
423 # classification: include only the small extras, before skipping anything | |
424 if self.task_type == "classification" and name in { | |
425 "threshold", | |
426 "pr_auc", | |
427 "class_report", | |
428 }: | |
429 title = plot_title_map.get(name, name.replace("_", " ").title()) | |
430 b64 = encode_image_to_base64(path) | |
431 test_html += ( | |
432 f"<h2>{title}</h2>" | |
433 "<div class='plot'>" | |
434 f"<img src='data:image/png;base64,{b64}' " | |
435 "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" | |
436 "</div>" + add_hr_to_html() | |
437 ) | |
438 continue | |
439 | |
440 # regression: explicitly include the 'error' plot, before skipping | |
441 if self.task_type == "regression" and name == "error": | |
442 title = plot_title_map.get("error", "Prediction Error Distribution") | |
443 b64 = encode_image_to_base64(path) | |
444 test_html += ( | |
445 f"<h2>{title}</h2>" | |
446 "<div class='plot'>" | |
447 f"<img src='data:image/png;base64,{b64}' " | |
448 "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" | |
449 "</div>" + add_hr_to_html() | |
450 ) | |
451 continue | |
452 | |
453 # now skip any plots already rendered via test_order | |
454 if name in test_order: | |
455 continue | |
456 | |
457 # — Feature Importance — | |
458 feature_html = header | |
459 | |
460 # 6a) PyCaret’s default feature importances | |
461 feature_html += FeatureImportanceAnalyzer( | |
264 data=self.data, | 462 data=self.data, |
265 target_col=self.target_col, | 463 target_col=self.target_col, |
266 task_type=self.task_type, | 464 task_type=self.task_type, |
267 output_dir=self.output_dir, | 465 output_dir=self.output_dir, |
268 exp=self.exp, | 466 exp=self.exp, |
269 best_model=self.best_model, | 467 best_model=self.best_model, |
468 ).run() | |
469 | |
470 # 6b) Explainer SHAP importances | |
471 for key in ["shap_mean", "shap_perm"]: | |
472 fig_or_fn = self.explainer_plots.pop(key, None) | |
473 if fig_or_fn is not None: | |
474 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
475 # give SHAP plots explicit titles | |
476 title = ( | |
477 "Mean Absolute SHAP Value Impact" | |
478 if key == "shap_mean" | |
479 else "Permutation Feature Importance" | |
480 ) | |
481 feature_html += ( | |
482 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | |
483 ) | |
484 | |
485 # 6c) PDPs last | |
486 pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__")) | |
487 for k in pdp_keys: | |
488 fig_or_fn = self.explainer_plots[k] | |
489 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
490 # extract feature name | |
491 feature = k.split("__", 1)[1] | |
492 title = f"Partial Dependence for {feature}" | |
493 feature_html += ( | |
494 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | |
495 ) | |
496 # 7) Assemble final HTML (three tabs) | |
497 html = get_html_template() | |
498 html += "<h1>Tabular Learner Model Report</h1>" | |
499 html += build_tabbed_html(summary_html, test_html, feature_html) | |
500 html += get_feature_metrics_help_modal() | |
501 html += get_html_closing() | |
502 | |
503 # 8) Write out | |
504 (Path(self.output_dir) / "comparison_result.html").write_text( | |
505 html, encoding="utf-8" | |
270 ) | 506 ) |
271 feature_importance_html = analyzer.run() | 507 LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html") |
272 | |
273 # --- Feature Metrics Help Button --- | |
274 feature_metrics_button_html = ( | |
275 '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">' | |
276 "Help: Metrics Guide" | |
277 "</button>" | |
278 "<style>" | |
279 ".help-modal-btn {" | |
280 "background-color: #17623b;" | |
281 "color: #fff;" | |
282 "border: none;" | |
283 "border-radius: 24px;" | |
284 "padding: 10px 28px;" | |
285 "font-size: 1.1rem;" | |
286 "font-weight: bold;" | |
287 "letter-spacing: 0.03em;" | |
288 "cursor: pointer;" | |
289 "transition: background 0.2s, box-shadow 0.2s;" | |
290 "box-shadow: 0 2px 8px rgba(23,98,59,0.07);" | |
291 "}" | |
292 ".help-modal-btn:hover, .help-modal-btn:focus {" | |
293 "background-color: #21895e;" | |
294 "outline: none;" | |
295 "box-shadow: 0 4px 16px rgba(23,98,59,0.14);" | |
296 "}" | |
297 "</style>" | |
298 ) | |
299 | |
300 html_content = ( | |
301 f"{get_html_template()}" | |
302 "<h1>Tabular Learner Model Report</h1>" | |
303 f"{feature_metrics_button_html}" | |
304 '<div class="tabs">' | |
305 '<div class="tab" onclick="openTab(event, \'summary\')">' | |
306 "Validation Result Summary & Config</div>" | |
307 '<div class="tab" onclick="openTab(event, \'plots\')">' | |
308 "Test Results</div>" | |
309 '<div class="tab" onclick="openTab(event, \'feature\')">' | |
310 "Feature Importance</div>" | |
311 ) | |
312 if self.plots_explainer_html: | |
313 html_content += ( | |
314 '<div class="tab" onclick="openTab(event, \'explainer\')">' | |
315 "Explainer Plots</div>" | |
316 ) | |
317 html_content += ( | |
318 "</div>" | |
319 '<div id="summary" class="tab-content">' | |
320 f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>" | |
321 f"<h2>Best Model: {model_name}</h2>" | |
322 "<h5>The best model is selected by: Accuracy (Classification)" | |
323 " or R2 (Regression).</h5>" | |
324 f"{self.results.to_html(index=False, classes='table sortable')}" | |
325 "<h2>Best Model's Hyperparameters</h2>" | |
326 f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}" | |
327 "<h2>Setup Parameters</h2>" | |
328 f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}" | |
329 "<h5>If you want to know all the experiment setup parameters," | |
330 " please check the PyCaret documentation for" | |
331 " the classification/regression <code>exp</code> function.</h5>" | |
332 "</div>" | |
333 '<div id="plots" class="tab-content">' | |
334 f"<h2>Best Model: {model_name}</h2>" | |
335 "<h5>The best model is selected by: Accuracy (Classification)" | |
336 " or R2 (Regression).</h5>" | |
337 "<h2>Test Metrics</h2>" | |
338 f"{self.test_result_df.to_html(index=False)}" | |
339 "<h2>Test Results</h2>" | |
340 f"{plots_html}" | |
341 "</div>" | |
342 '<div id="feature" class="tab-content">' | |
343 f"{feature_importance_html}" | |
344 "</div>" | |
345 ) | |
346 if self.plots_explainer_html: | |
347 html_content += ( | |
348 '<div id="explainer" class="tab-content">' | |
349 f"{self.plots_explainer_html}" | |
350 f"{tree_plots}" | |
351 "</div>" | |
352 ) | |
353 html_content += ( | |
354 "<script>" | |
355 "document.addEventListener(\"DOMContentLoaded\", function() {" | |
356 "var tables = document.querySelectorAll(\"table.sortable\");" | |
357 "tables.forEach(function(table) {" | |
358 "var headers = table.querySelectorAll(\"th\");" | |
359 "headers.forEach(function(header, index) {" | |
360 "header.style.cursor = \"pointer\";" | |
361 "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)" | |
362 "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';" | |
363 "header.addEventListener(\"click\", function() {" | |
364 "var direction = this.getAttribute(" | |
365 "\"data-sort-direction\"" | |
366 ") || \"asc\";" | |
367 "// Reset arrows in all headers of this table" | |
368 "headers.forEach(function(h) {" | |
369 "var arrow = h.querySelector(\".sort-arrow\");" | |
370 "if (arrow) arrow.textContent = \" ↑\";" | |
371 "});" | |
372 "// Set arrow for clicked header" | |
373 "var arrow = this.querySelector(\".sort-arrow\");" | |
374 "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";" | |
375 "sortTable(table, index, direction);" | |
376 "this.setAttribute(\"data-sort-direction\"," | |
377 "direction === \"asc\" ? \"desc\" : \"asc\");" | |
378 "});" | |
379 "});" | |
380 "});" | |
381 "});" | |
382 "function sortTable(table, colNum, direction) {" | |
383 "var tb = table.tBodies[0];" | |
384 "var tr = Array.prototype.slice.call(tb.rows, 0);" | |
385 "var multiplier = direction === \"asc\" ? 1 : -1;" | |
386 "tr = tr.sort(function(a, b) {" | |
387 "var aText = a.cells[colNum].textContent.trim();" | |
388 "var bText = b.cells[colNum].textContent.trim();" | |
389 "// Remove arrow from text comparison" | |
390 "aText = aText.replace(/[↑↓]/g, '').trim();" | |
391 "bText = bText.replace(/[↑↓]/g, '').trim();" | |
392 "if (!isNaN(aText) && !isNaN(bText)) {" | |
393 "return multiplier * (" | |
394 "parseFloat(aText) - parseFloat(bText)" | |
395 ");" | |
396 "} else {" | |
397 "return multiplier * aText.localeCompare(bText);" | |
398 "}" | |
399 "});" | |
400 "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);" | |
401 "}" | |
402 "</script>" | |
403 ) | |
404 # --- Add the Feature Metrics Help Modal --- | |
405 html_content += get_feature_metrics_help_modal() | |
406 html_content += f"{get_html_closing()}" | |
407 with open( | |
408 os.path.join(self.output_dir, "comparison_result.html"), | |
409 "w", | |
410 encoding="utf-8", | |
411 ) as file: | |
412 file.write(html_content) | |
413 | 508 |
414 def save_dashboard(self): | 509 def save_dashboard(self): |
415 raise NotImplementedError("Subclasses should implement this method") | 510 raise NotImplementedError("Subclasses should implement this method") |
416 | 511 |
417 def generate_plots_explainer(self): | 512 def generate_plots_explainer(self): |
424 | 519 |
425 LOG.info("Generating tree plots") | 520 LOG.info("Generating tree plots") |
426 X_test = self.exp.X_test_transformed.copy() | 521 X_test = self.exp.X_test_transformed.copy() |
427 y_test = self.exp.y_test_transformed | 522 y_test = self.exp.y_test_transformed |
428 | 523 |
429 is_rf = isinstance( | 524 if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)): |
430 self.best_model, (RandomForestClassifier, RandomForestRegressor) | 525 n_trees = self.best_model.n_estimators |
431 ) | 526 elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)): |
432 is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor)) | 527 n_trees = len(self.best_model.get_booster().get_dump()) |
433 | |
434 num_trees = None | |
435 if is_rf: | |
436 num_trees = self.best_model.n_estimators | |
437 elif is_xgb: | |
438 num_trees = len(self.best_model.get_booster().get_dump()) | |
439 else: | 528 else: |
440 LOG.warning("Tree plots not supported for this model type.") | 529 LOG.warning("Tree plots not supported for this model type.") |
441 return | 530 return |
442 | 531 |
443 try: | 532 explainer = RandomForestExplainer(self.best_model, X_test, y_test) |
444 explainer = RandomForestExplainer(self.best_model, X_test, y_test) | 533 for i in range(n_trees): |
445 for i in range(num_trees): | 534 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) |
446 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) | 535 self.trees.append(fig) |
447 LOG.info(f"Tree {i + 1}") | |
448 LOG.info(fig) | |
449 self.trees.append(fig) | |
450 except Exception as e: | |
451 LOG.error(f"Error generating tree plots: {e}") | |
452 | 536 |
453 def run(self): | 537 def run(self): |
454 self.load_data() | 538 self.load_data() |
455 self.setup_pycaret() | 539 self.setup_pycaret() |
456 self.train_model() | 540 self.train_model() |
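Taken together, a hypothetical invocation of the trainer would look like the sketch below. `PycaretTabularTrainer` is a placeholder name; the concrete subclasses (which implement `generate_plots` and `save_dashboard`) live elsewhere in the repository, and the keyword options shown are forwarded through `**kwargs` as captured in `self.user_kwargs`:

```python
# Hypothetical usage sketch; the subclass name is a placeholder.
trainer = PycaretTabularTrainer(
    input_file="train.tsv",
    target_col="5",            # 1-based index of the target column
    output_dir="report_out",
    task_type="classification",
    random_seed=42,
    test_file=None,            # optional held-out test split
    cross_validation=True,     # forwarded via **kwargs
    cross_validation_folds=5,
)
trainer.run()
```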