Mercurial > repos > goeckslab > tabular_learner
comparison base_model_trainer.py @ 4:11fdac5affb3 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
| author | goeckslab |
|---|---|
| date | Fri, 25 Jul 2025 19:02:12 +0000 |
| parents | f6a65e05d6ec |
| children | 3d42f82b3c7f |
comparison
equal
deleted
inserted
replaced
| 3:f6a65e05d6ec | 4:11fdac5affb3 |
|---|---|
| 1 import base64 | 1 import base64 |
| 2 import logging | 2 import logging |
| 3 import os | |
| 4 import tempfile | 3 import tempfile |
| 4 from pathlib import Path | |
| 5 | 5 |
| 6 import h5py | 6 import h5py |
| 7 import joblib | 7 import joblib |
| 8 import numpy as np | 8 import numpy as np |
| 9 import pandas as pd | 9 import pandas as pd |
| 10 from feature_help_modal import get_feature_metrics_help_modal | 10 from feature_help_modal import get_feature_metrics_help_modal |
| 11 from feature_importance import FeatureImportanceAnalyzer | 11 from feature_importance import FeatureImportanceAnalyzer |
| 12 from sklearn.metrics import average_precision_score | 12 from sklearn.metrics import average_precision_score |
| 13 from utils import get_html_closing, get_html_template | 13 from utils import ( |
| 14 add_hr_to_html, | |
| 15 add_plot_to_html, | |
| 16 build_tabbed_html, | |
| 17 encode_image_to_base64, | |
| 18 get_html_closing, | |
| 19 get_html_template, | |
| 20 ) | |
| 14 | 21 |
| 15 logging.basicConfig(level=logging.DEBUG) | 22 logging.basicConfig(level=logging.DEBUG) |
| 16 LOG = logging.getLogger(__name__) | 23 LOG = logging.getLogger(__name__) |
| 17 | 24 |
| 18 | 25 |
| 25 task_type, | 32 task_type, |
| 26 random_seed, | 33 random_seed, |
| 27 test_file=None, | 34 test_file=None, |
| 28 **kwargs, | 35 **kwargs, |
| 29 ): | 36 ): |
| 30 self.exp = None # This will be set in the subclass | 37 self.exp = None |
| 31 self.input_file = input_file | 38 self.input_file = input_file |
| 32 self.target_col = target_col | 39 self.target_col = target_col |
| 33 self.output_dir = output_dir | 40 self.output_dir = output_dir |
| 34 self.task_type = task_type | 41 self.task_type = task_type |
| 35 self.random_seed = random_seed | 42 self.random_seed = random_seed |
| 37 self.target = None | 44 self.target = None |
| 38 self.best_model = None | 45 self.best_model = None |
| 39 self.results = None | 46 self.results = None |
| 40 self.features_name = None | 47 self.features_name = None |
| 41 self.plots = {} | 48 self.plots = {} |
| 42 self.expaliner = None | 49 self.explainer_plots = {} |
| 43 self.plots_explainer_html = None | 50 self.plots_explainer_html = None |
| 44 self.trees = [] | 51 self.trees = [] |
| 45 for key, value in kwargs.items(): | 52 self.user_kwargs = kwargs.copy() |
| 53 for key, value in self.user_kwargs.items(): | |
| 46 setattr(self, key, value) | 54 setattr(self, key, value) |
| 47 self.setup_params = {} | 55 self.setup_params = {} |
| 48 self.test_file = test_file | 56 self.test_file = test_file |
| 49 self.test_data = None | 57 self.test_data = None |
| 50 | 58 |
| 55 | 63 |
| 56 def load_data(self): | 64 def load_data(self): |
| 57 LOG.info(f"Loading data from {self.input_file}") | 65 LOG.info(f"Loading data from {self.input_file}") |
| 58 self.data = pd.read_csv(self.input_file, sep=None, engine="python") | 66 self.data = pd.read_csv(self.input_file, sep=None, engine="python") |
| 59 self.data.columns = self.data.columns.str.replace(".", "_") | 67 self.data.columns = self.data.columns.str.replace(".", "_") |
| 60 | |
| 61 # Remove prediction_label if present | |
| 62 if "prediction_label" in self.data.columns: | 68 if "prediction_label" in self.data.columns: |
| 63 self.data = self.data.drop(columns=["prediction_label"]) | 69 self.data = self.data.drop(columns=["prediction_label"]) |
| 64 | 70 |
| 65 numeric_cols = self.data.select_dtypes(include=["number"]).columns | 71 numeric_cols = self.data.select_dtypes(include=["number"]).columns |
| 66 non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns | 72 non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns |
| 67 | |
| 68 self.data[numeric_cols] = self.data[numeric_cols].apply( | 73 self.data[numeric_cols] = self.data[numeric_cols].apply( |
| 69 pd.to_numeric, errors="coerce" | 74 pd.to_numeric, errors="coerce" |
| 70 ) | 75 ) |
| 71 | |
| 72 if len(non_numeric_cols) > 0: | 76 if len(non_numeric_cols) > 0: |
| 73 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | 77 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") |
| 74 | 78 |
| 75 names = self.data.columns.to_list() | 79 names = self.data.columns.to_list() |
| 76 target_index = int(self.target_col) - 1 | 80 target_index = int(self.target_col) - 1 |
| 77 self.target = names[target_index] | 81 self.target = names[target_index] |
| 78 self.features_name = [name for i, name in enumerate(names) if i != target_index] | 82 self.features_name = [n for i, n in enumerate(names) if i != target_index] |
| 79 if hasattr(self, "missing_value_strategy"): | 83 |
| 80 if self.missing_value_strategy == "mean": | 84 if getattr(self, "missing_value_strategy", None): |
| 85 strat = self.missing_value_strategy | |
| 86 if strat == "mean": | |
| 81 self.data = self.data.fillna(self.data.mean(numeric_only=True)) | 87 self.data = self.data.fillna(self.data.mean(numeric_only=True)) |
| 82 elif self.missing_value_strategy == "median": | 88 elif strat == "median": |
| 83 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 89 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
| 84 elif self.missing_value_strategy == "drop": | 90 elif strat == "drop": |
| 85 self.data = self.data.dropna() | 91 self.data = self.data.dropna() |
| 86 else: | 92 else: |
| 87 # Default strategy if not specified | |
| 88 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 93 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
| 89 | 94 |
| 90 if self.test_file: | 95 if self.test_file: |
| 91 LOG.info(f"Loading test data from {self.test_file}") | 96 LOG.info(f"Loading test data from {self.test_file}") |
| 92 self.test_data = pd.read_csv(self.test_file, sep=None, engine="python") | 97 df_test = pd.read_csv(self.test_file, sep=None, engine="python") |
| 93 self.test_data = self.test_data[numeric_cols].apply( | 98 df_test.columns = df_test.columns.str.replace(".", "_") |
| 94 pd.to_numeric, errors="coerce" | 99 self.test_data = df_test |
| 95 ) | |
| 96 self.test_data.columns = self.test_data.columns.str.replace(".", "_") | |
| 97 | 100 |
| 98 def setup_pycaret(self): | 101 def setup_pycaret(self): |
| 99 LOG.info("Initializing PyCaret") | 102 LOG.info("Initializing PyCaret") |
| 100 self.setup_params = { | 103 self.setup_params = { |
| 101 "target": self.target, | 104 "target": self.target, |
| 103 "html": True, | 106 "html": True, |
| 104 "log_experiment": False, | 107 "log_experiment": False, |
| 105 "system_log": False, | 108 "system_log": False, |
| 106 "index": False, | 109 "index": False, |
| 107 } | 110 } |
| 108 | |
| 109 if self.test_data is not None: | 111 if self.test_data is not None: |
| 110 self.setup_params["test_data"] = self.test_data | 112 self.setup_params["test_data"] = self.test_data |
| 111 | 113 for attr in [ |
| 112 if ( | 114 "train_size", |
| 113 hasattr(self, "train_size") | 115 "normalize", |
| 114 and self.train_size is not None | 116 "feature_selection", |
| 115 and self.test_data is None | 117 "remove_outliers", |
| 116 ): | 118 "remove_multicollinearity", |
| 117 self.setup_params["train_size"] = self.train_size | 119 "polynomial_features", |
| 118 | 120 "feature_interaction", |
| 119 if hasattr(self, "normalize") and self.normalize is not None: | 121 "feature_ratio", |
| 120 self.setup_params["normalize"] = self.normalize | 122 "fix_imbalance", |
| 121 | 123 ]: |
| 122 if hasattr(self, "feature_selection") and self.feature_selection is not None: | 124 val = getattr(self, attr, None) |
| 123 self.setup_params["feature_selection"] = self.feature_selection | 125 if val is not None: |
| 124 | 126 self.setup_params[attr] = val |
| 125 if ( | 127 if getattr(self, "cross_validation_folds", None) is not None: |
| 126 hasattr(self, "cross_validation") | 128 self.setup_params["fold"] = self.cross_validation_folds |
| 127 and self.cross_validation is not None | |
| 128 and self.cross_validation is False | |
| 129 ): | |
| 130 logging.info( | |
| 131 "cross_validation is set to False. This will disable cross-validation." | |
| 132 ) | |
| 133 | |
| 134 if hasattr(self, "cross_validation") and self.cross_validation: | |
| 135 if hasattr(self, "cross_validation_folds"): | |
| 136 self.setup_params["fold"] = self.cross_validation_folds | |
| 137 | |
| 138 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: | |
| 139 self.setup_params["remove_outliers"] = self.remove_outliers | |
| 140 | |
| 141 if ( | |
| 142 hasattr(self, "remove_multicollinearity") | |
| 143 and self.remove_multicollinearity is not None | |
| 144 ): | |
| 145 self.setup_params["remove_multicollinearity"] = ( | |
| 146 self.remove_multicollinearity | |
| 147 ) | |
| 148 | |
| 149 if ( | |
| 150 hasattr(self, "polynomial_features") | |
| 151 and self.polynomial_features is not None | |
| 152 ): | |
| 153 self.setup_params["polynomial_features"] = self.polynomial_features | |
| 154 | |
| 155 if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None: | |
| 156 self.setup_params["fix_imbalance"] = self.fix_imbalance | |
| 157 | |
| 158 LOG.info(self.setup_params) | 129 LOG.info(self.setup_params) |
| 159 | 130 |
| 160 # Solution: instantiate the correct PyCaret experiment based on task_type | |
| 161 if self.task_type == "classification": | 131 if self.task_type == "classification": |
| 162 from pycaret.classification import ClassificationExperiment | 132 from pycaret.classification import ClassificationExperiment |
| 163 | 133 |
| 164 self.exp = ClassificationExperiment() | 134 self.exp = ClassificationExperiment() |
| 165 elif self.task_type == "regression": | 135 elif self.task_type == "regression": |
| 168 self.exp = RegressionExperiment() | 138 self.exp = RegressionExperiment() |
| 169 else: | 139 else: |
| 170 raise ValueError("task_type must be 'classification' or 'regression'") | 140 raise ValueError("task_type must be 'classification' or 'regression'") |
| 171 | 141 |
| 172 self.exp.setup(self.data, **self.setup_params) | 142 self.exp.setup(self.data, **self.setup_params) |
| 143 self.setup_params.update(self.user_kwargs) | |
| 173 | 144 |
| 174 def train_model(self): | 145 def train_model(self): |
| 175 LOG.info("Training and selecting the best model") | 146 LOG.info("Training and selecting the best model") |
| 176 if self.task_type == "classification": | 147 if self.task_type == "classification": |
| 177 average_displayed = "Weighted" | |
| 178 self.exp.add_metric( | 148 self.exp.add_metric( |
| 179 id=f"PR-AUC-{average_displayed}", | 149 id="PR-AUC-Weighted", |
| 180 name=f"PR-AUC-{average_displayed}", | 150 name="PR-AUC-Weighted", |
| 181 target="pred_proba", | 151 target="pred_proba", |
| 182 score_func=average_precision_score, | 152 score_func=average_precision_score, |
| 183 average="weighted", | 153 average="weighted", |
| 184 ) | 154 ) |
| 185 | 155 # Build arguments for compare_models() |
| 186 if hasattr(self, "models") and self.models is not None: | 156 compare_kwargs = {} |
| 187 self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation) | 157 if getattr(self, "models", None): |
| 188 else: | 158 compare_kwargs["include"] = self.models |
| 189 self.best_model = self.exp.compare_models(cross_validation=self.cross_validation) | 159 |
| 160 # Respect explicit cross-validation flag | |
| 161 if getattr(self, "cross_validation", None) is not None: | |
| 162 compare_kwargs["cross_validation"] = self.cross_validation | |
| 163 | |
| 164 # Respect explicit fold count | |
| 165 if getattr(self, "cross_validation_folds", None) is not None: | |
| 166 compare_kwargs["fold"] = self.cross_validation_folds | |
| 167 | |
| 168 LOG.info(f"compare_models kwargs: {compare_kwargs}") | |
| 169 self.best_model = self.exp.compare_models(**compare_kwargs) | |
| 190 self.results = self.exp.pull() | 170 self.results = self.exp.pull() |
| 171 if getattr(self, "tune_model", False): | |
| 172 LOG.info("Tuning hyperparameters of the best model") | |
| 173 self.best_model = self.exp.tune_model(self.best_model) | |
| 174 self.results = self.exp.pull() | |
| 191 | 175 |
| 192 if self.task_type == "classification": | 176 if self.task_type == "classification": |
| 193 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 177 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 194 | |
| 195 _ = self.exp.predict_model(self.best_model) | 178 _ = self.exp.predict_model(self.best_model) |
| 196 self.test_result_df = self.exp.pull() | 179 self.test_result_df = self.exp.pull() |
| 197 if self.task_type == "classification": | 180 if self.task_type == "classification": |
| 198 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 181 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 199 | 182 |
| 200 def save_model(self): | 183 def save_model(self): |
| 201 hdf5_model_path = "pycaret_model.h5" | 184 hdf5_path = Path(self.output_dir) / "pycaret_model.h5" |
| 202 with h5py.File(hdf5_model_path, "w") as f: | 185 with h5py.File(hdf5_path, "w") as f: |
| 203 with tempfile.NamedTemporaryFile(delete=False) as temp_file: | 186 with tempfile.NamedTemporaryFile(delete=False) as tmp: |
| 204 joblib.dump(self.best_model, temp_file.name) | 187 joblib.dump(self.best_model, tmp.name) |
| 205 temp_file.seek(0) | 188 tmp.seek(0) |
| 206 model_bytes = temp_file.read() | 189 model_bytes = tmp.read() |
| 207 f.create_dataset("model", data=np.void(model_bytes)) | 190 f.create_dataset("model", data=np.void(model_bytes)) |
| 208 | 191 |
| 209 def generate_plots(self): | 192 def generate_plots(self): |
| 210 raise NotImplementedError("Subclasses should implement this method") | 193 LOG.info("Generating PyCaret diagnostic pltos") |
| 211 | 194 |
| 212 def encode_image_to_base64(self, img_path): | 195 # choose the right plots based on task |
| 196 if self.task_type == "classification": | |
| 197 plot_names = [ | |
| 198 "learning", | |
| 199 "vc", | |
| 200 "calibration", | |
| 201 "dimension", | |
| 202 "manifold", | |
| 203 "rfe", | |
| 204 "threshold", | |
| 205 "percentage_above_below", | |
| 206 "class_report", | |
| 207 "pr_auc", | |
| 208 "roc_auc", | |
| 209 ] | |
| 210 else: | |
| 211 plot_names = ["residuals", "vc", "parameter", "error", "learning"] | |
| 212 for name in plot_names: | |
| 213 try: | |
| 214 ax = self.exp.plot_model(self.best_model, plot=name, save=False) | |
| 215 out_path = Path(self.output_dir) / f"plot_{name}.png" | |
| 216 fig = ax.get_figure() | |
| 217 fig.savefig(out_path, bbox_inches="tight") | |
| 218 self.plots[name] = str(out_path) | |
| 219 except Exception as e: | |
| 220 LOG.warning(f"Could not generate {name} plot: {e}") | |
| 221 | |
| 222 def encode_image_to_base64(self, img_path: str) -> str: | |
| 213 with open(img_path, "rb") as img_file: | 223 with open(img_path, "rb") as img_file: |
| 214 return base64.b64encode(img_file.read()).decode("utf-8") | 224 return base64.b64encode(img_file.read()).decode("utf-8") |
| 215 | 225 |
| 216 def save_html_report(self): | 226 def save_html_report(self): |
| 217 LOG.info("Saving HTML report") | 227 LOG.info("Saving HTML report") |
| 218 | 228 |
| 219 if not self.output_dir: | 229 # 1) Determine best model name |
| 220 raise ValueError("output_dir must be specified and not None") | 230 try: |
| 221 | 231 best_model_name = str(self.results.iloc[0]["Model"]) |
| 222 model_name = type(self.best_model).__name__ | 232 except Exception: |
| 223 excluded_params = ["html", "log_experiment", "system_log", "test_data"] | 233 best_model_name = type(self.best_model).__name__ |
| 224 filtered_setup_params = { | 234 LOG.info(f"Best model determined as: {best_model_name}") |
| 225 k: v for k, v in self.setup_params.items() if k not in excluded_params | 235 |
| 236 # 2) Compute training sample count | |
| 237 try: | |
| 238 n_train = self.exp.X_train.shape[0] | |
| 239 except Exception: | |
| 240 n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0] | |
| 241 total_rows = self.data.shape[0] | |
| 242 | |
| 243 # 3) Build setup parameters table | |
| 244 all_params = self.setup_params | |
| 245 display_keys = [ | |
| 246 "Target", | |
| 247 "Session ID", | |
| 248 "Train Size", | |
| 249 "Normalize", | |
| 250 "Feature Selection", | |
| 251 "Cross Validation", | |
| 252 "Cross Validation Folds", | |
| 253 "Remove Outliers", | |
| 254 "Remove Multicollinearity", | |
| 255 "Polynomial Features", | |
| 256 "Fix Imbalance", | |
| 257 "Models", | |
| 258 ] | |
| 259 setup_rows = [] | |
| 260 for key in display_keys: | |
| 261 pk = key.lower().replace(" ", "_") | |
| 262 v = all_params.get(pk) | |
| 263 if key == "Train Size": | |
| 264 frac = ( | |
| 265 float(v) | |
| 266 if v is not None | |
| 267 else (n_train / total_rows if total_rows else 0) | |
| 268 ) | |
| 269 dv = f"{frac:.2f} ({n_train} rows)" | |
| 270 elif key in { | |
| 271 "Normalize", | |
| 272 "Feature Selection", | |
| 273 "Cross Validation", | |
| 274 "Remove Outliers", | |
| 275 "Remove Multicollinearity", | |
| 276 "Polynomial Features", | |
| 277 "Fix Imbalance", | |
| 278 }: | |
| 279 dv = bool(v) | |
| 280 elif key == "Cross Validation Folds": | |
| 281 dv = v if v is not None else "None" | |
| 282 elif key == "Models": | |
| 283 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None" | |
| 284 else: | |
| 285 dv = v if v is not None else "None" | |
| 286 setup_rows.append([key, dv]) | |
| 287 if hasattr(self.exp, "_fold_metric"): | |
| 288 setup_rows.append(["best_model_metric", self.exp._fold_metric]) | |
| 289 | |
| 290 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) | |
| 291 df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False) | |
| 292 | |
| 293 # 4) Persist CSVs | |
| 294 self.results.to_csv( | |
| 295 Path(self.output_dir) / "comparison_results.csv", index=False | |
| 296 ) | |
| 297 self.test_result_df.to_csv( | |
| 298 Path(self.output_dir) / "test_results.csv", index=False | |
| 299 ) | |
| 300 pd.DataFrame( | |
| 301 self.best_model.get_params().items(), columns=["Parameter", "Value"] | |
| 302 ).to_csv(Path(self.output_dir) / "best_model.csv", index=False) | |
| 303 | |
| 304 # 5) Header | |
| 305 header = f"<h2>Best Model: {best_model_name}</h2>" | |
| 306 | |
| 307 # — Validation Summary & Configuration — | |
| 308 val_df = self.results.copy() | |
| 309 # mapping raw plot keys to user-friendly titles | |
| 310 plot_title_map = { | |
| 311 "learning": "Learning Curve", | |
| 312 "vc": "Validation Curve", | |
| 313 "calibration": "Calibration Curve", | |
| 314 "dimension": "Dimensionality Reduction", | |
| 315 "manifold": "Manifold Learning", | |
| 316 "rfe": "Recursive Feature Elimination", | |
| 317 "threshold": "Threshold Plot", | |
| 318 "percentage_above_below": "Percentage Above vs. Below Cutoff", | |
| 319 "class_report": "Classification Report", | |
| 320 "pr_auc": "Precision-Recall AUC", | |
| 321 "roc_auc": "Receiver Operating Characteristic AUC", | |
| 322 "residuals": "Residuals Distribution", | |
| 323 "error": "Prediction Error Distribution", | |
| 226 } | 324 } |
| 227 setup_params_table = pd.DataFrame( | 325 val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True) |
| 228 list(filtered_setup_params.items()), columns=["Parameter", "Value"] | 326 summary_html = ( |
| 327 header | |
| 328 + "<h2>Train & Validation Summary</h2>" | |
| 329 + '<div class="table-wrapper">' | |
| 330 + val_df.to_html(index=False, classes="table sortable") | |
| 331 + "</div>" | |
| 332 + "<h2>Setup Parameters</h2>" | |
| 333 + '<div class="table-wrapper">' | |
| 334 + df_setup.to_html(index=False, classes="table sortable") | |
| 335 + "</div>" | |
| 336 # — Hyperparameters | |
| 337 + "<h2>Best Model Hyperparameters</h2>" | |
| 338 + '<div class="table-wrapper">' | |
| 339 + pd.DataFrame( | |
| 340 self.best_model.get_params().items(), columns=["Parameter", "Value"] | |
| 341 ).to_html(index=False, classes="table sortable") | |
| 342 + "</div>" | |
| 229 ) | 343 ) |
| 230 | 344 |
| 231 best_model_params = pd.DataFrame( | 345 # choose summary plots based on task type |
| 232 self.best_model.get_params().items(), columns=["Parameter", "Value"] | 346 if self.task_type == "classification": |
| 347 summary_plots = [ | |
| 348 "learning", | |
| 349 "vc", | |
| 350 "calibration", | |
| 351 "dimension", | |
| 352 "manifold", | |
| 353 "rfe", | |
| 354 "threshold", | |
| 355 "percentage_above_below", | |
| 356 ] | |
| 357 else: | |
| 358 summary_plots = ["learning", "vc", "parameter", "residuals"] | |
| 359 | |
| 360 for name in summary_plots: | |
| 361 if name in self.plots: | |
| 362 summary_html += "<hr>" | |
| 363 b64 = encode_image_to_base64(self.plots[name]) | |
| 364 title = plot_title_map.get(name, name.replace("_", " ").title()) | |
| 365 summary_html += ( | |
| 366 '<div class="plot">' | |
| 367 f"<h2>{title}</h2>" | |
| 368 f'<img src="data:image/png;base64,{b64}" ' | |
| 369 'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>' | |
| 370 "</div>" | |
| 371 ) | |
| 372 | |
| 373 # — Test Summary — | |
| 374 test_html = ( | |
| 375 header | |
| 376 + '<div class="table-wrapper">' | |
| 377 + self.test_result_df.to_html(index=False, classes="table sortable") | |
| 378 + "</div>" | |
| 233 ) | 379 ) |
| 234 best_model_params.to_csv( | 380 if self.task_type == "regression": |
| 235 os.path.join(self.output_dir, "best_model.csv"), index=False | 381 try: |
| 236 ) | 382 y_true = ( |
| 237 self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv")) | 383 pd.Series(self.exp.y_test_transformed) |
| 238 self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv")) | 384 .reset_index(drop=True) |
| 239 | 385 .rename("True") |
| 240 plots_html = "" | 386 ) |
| 241 length = len(self.plots) | 387 y_pred = pd.Series( |
| 242 for i, (plot_name, plot_path) in enumerate(self.plots.items()): | 388 self.best_model.predict(self.exp.X_test_transformed) |
| 243 encoded_image = self.encode_image_to_base64(plot_path) | 389 ).rename("Predicted") |
| 244 plots_html += ( | 390 df_tp = pd.concat([y_true, y_pred], axis=1) |
| 245 f'<div class="plot">' | 391 test_html += "<h2>True vs Predicted Values</h2>" |
| 246 f"<h3>{plot_name.capitalize()}</h3>" | 392 test_html += ( |
| 247 f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">' | 393 '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">' |
| 248 f"</div>" | 394 + df_tp.head(50).to_html(index=False, classes="table sortable") |
| 249 ) | 395 + "</div>" |
| 250 if i < length - 1: | 396 + add_hr_to_html() |
| 251 plots_html += "<hr>" | 397 ) |
| 252 | 398 except Exception as e: |
| 253 tree_plots = "" | 399 LOG.warning(f"Could not generate True vs Predicted table: {e}") |
| 254 for i, tree in enumerate(self.trees): | 400 |
| 255 if tree: | 401 # 5a) Explainer-substituted plots in order |
| 256 tree_plots += ( | 402 if self.task_type == "regression": |
| 257 f'<div class="plot">' | 403 test_order = ["residuals"] |
| 258 f"<h3>Tree {i + 1}</h3>" | 404 else: |
| 259 f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">' | 405 test_order = [ |
| 260 f"</div>" | 406 "confusion_matrix", |
| 261 ) | 407 "roc_auc", |
| 262 | 408 "pr_auc", |
| 263 analyzer = FeatureImportanceAnalyzer( | 409 "lift_curve", |
| 410 "threshold", | |
| 411 "cumulative_precision", | |
| 412 ] | |
| 413 for key in test_order: | |
| 414 fig_or_fn = self.explainer_plots.pop(key, None) | |
| 415 if fig_or_fn is not None: | |
| 416 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
| 417 title = plot_title_map.get(key, key.replace("_", " ").title()) | |
| 418 test_html += ( | |
| 419 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | |
| 420 ) | |
| 421 # 5b) Remaining PyCaret test plots | |
| 422 for name, path in self.plots.items(): | |
| 423 # classification: include only the small extras, before skipping anything | |
| 424 if self.task_type == "classification" and name in { | |
| 425 "threshold", | |
| 426 "pr_auc", | |
| 427 "class_report", | |
| 428 }: | |
| 429 title = plot_title_map.get(name, name.replace("_", " ").title()) | |
| 430 b64 = encode_image_to_base64(path) | |
| 431 test_html += ( | |
| 432 f"<h2>{title}</h2>" | |
| 433 "<div class='plot'>" | |
| 434 f"<img src='data:image/png;base64,{b64}' " | |
| 435 "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" | |
| 436 "</div>" + add_hr_to_html() | |
| 437 ) | |
| 438 continue | |
| 439 | |
| 440 # regression: explicitly include the 'error' plot, before skipping | |
| 441 if self.task_type == "regression" and name == "error": | |
| 442 title = plot_title_map.get("error", "Prediction Error Distribution") | |
| 443 b64 = encode_image_to_base64(path) | |
| 444 test_html += ( | |
| 445 f"<h2>{title}</h2>" | |
| 446 "<div class='plot'>" | |
| 447 f"<img src='data:image/png;base64,{b64}' " | |
| 448 "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" | |
| 449 "</div>" + add_hr_to_html() | |
| 450 ) | |
| 451 continue | |
| 452 | |
| 453 # now skip any plots already rendered via test_order | |
| 454 if name in test_order: | |
| 455 continue | |
| 456 | |
| 457 # — Feature Importance — | |
| 458 feature_html = header | |
| 459 | |
| 460 # 6a) PyCaret’s default feature importances | |
| 461 feature_html += FeatureImportanceAnalyzer( | |
| 264 data=self.data, | 462 data=self.data, |
| 265 target_col=self.target_col, | 463 target_col=self.target_col, |
| 266 task_type=self.task_type, | 464 task_type=self.task_type, |
| 267 output_dir=self.output_dir, | 465 output_dir=self.output_dir, |
| 268 exp=self.exp, | 466 exp=self.exp, |
| 269 best_model=self.best_model, | 467 best_model=self.best_model, |
| 468 ).run() | |
| 469 | |
| 470 # 6b) Explainer SHAP importances | |
| 471 for key in ["shap_mean", "shap_perm"]: | |
| 472 fig_or_fn = self.explainer_plots.pop(key, None) | |
| 473 if fig_or_fn is not None: | |
| 474 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
| 475 # give SHAP plots explicit titles | |
| 476 title = ( | |
| 477 "Mean Absolute SHAP Value Impact" | |
| 478 if key == "shap_mean" | |
| 479 else "Permutation Feature Importance" | |
| 480 ) | |
| 481 feature_html += ( | |
| 482 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | |
| 483 ) | |
| 484 | |
| 485 # 6c) PDPs last | |
| 486 pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__")) | |
| 487 for k in pdp_keys: | |
| 488 fig_or_fn = self.explainer_plots[k] | |
| 489 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | |
| 490 # extract feature name | |
| 491 feature = k.split("__", 1)[1] | |
| 492 title = f"Partial Dependence for {feature}" | |
| 493 feature_html += ( | |
| 494 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | |
| 495 ) | |
| 496 # 7) Assemble final HTML (three tabs) | |
| 497 html = get_html_template() | |
| 498 html += "<h1>Tabular Learner Model Report</h1>" | |
| 499 html += build_tabbed_html(summary_html, test_html, feature_html) | |
| 500 html += get_feature_metrics_help_modal() | |
| 501 html += get_html_closing() | |
| 502 | |
| 503 # 8) Write out | |
| 504 (Path(self.output_dir) / "comparison_result.html").write_text( | |
| 505 html, encoding="utf-8" | |
| 270 ) | 506 ) |
| 271 feature_importance_html = analyzer.run() | 507 LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html") |
| 272 | |
| 273 # --- Feature Metrics Help Button --- | |
| 274 feature_metrics_button_html = ( | |
| 275 '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">' | |
| 276 "Help: Metrics Guide" | |
| 277 "</button>" | |
| 278 "<style>" | |
| 279 ".help-modal-btn {" | |
| 280 "background-color: #17623b;" | |
| 281 "color: #fff;" | |
| 282 "border: none;" | |
| 283 "border-radius: 24px;" | |
| 284 "padding: 10px 28px;" | |
| 285 "font-size: 1.1rem;" | |
| 286 "font-weight: bold;" | |
| 287 "letter-spacing: 0.03em;" | |
| 288 "cursor: pointer;" | |
| 289 "transition: background 0.2s, box-shadow 0.2s;" | |
| 290 "box-shadow: 0 2px 8px rgba(23,98,59,0.07);" | |
| 291 "}" | |
| 292 ".help-modal-btn:hover, .help-modal-btn:focus {" | |
| 293 "background-color: #21895e;" | |
| 294 "outline: none;" | |
| 295 "box-shadow: 0 4px 16px rgba(23,98,59,0.14);" | |
| 296 "}" | |
| 297 "</style>" | |
| 298 ) | |
| 299 | |
| 300 html_content = ( | |
| 301 f"{get_html_template()}" | |
| 302 "<h1>Tabular Learner Model Report</h1>" | |
| 303 f"{feature_metrics_button_html}" | |
| 304 '<div class="tabs">' | |
| 305 '<div class="tab" onclick="openTab(event, \'summary\')">' | |
| 306 "Validation Result Summary & Config</div>" | |
| 307 '<div class="tab" onclick="openTab(event, \'plots\')">' | |
| 308 "Test Results</div>" | |
| 309 '<div class="tab" onclick="openTab(event, \'feature\')">' | |
| 310 "Feature Importance</div>" | |
| 311 ) | |
| 312 if self.plots_explainer_html: | |
| 313 html_content += ( | |
| 314 '<div class="tab" onclick="openTab(event, \'explainer\')">' | |
| 315 "Explainer Plots</div>" | |
| 316 ) | |
| 317 html_content += ( | |
| 318 "</div>" | |
| 319 '<div id="summary" class="tab-content">' | |
| 320 f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>" | |
| 321 f"<h2>Best Model: {model_name}</h2>" | |
| 322 "<h5>The best model is selected by: Accuracy (Classification)" | |
| 323 " or R2 (Regression).</h5>" | |
| 324 f"{self.results.to_html(index=False, classes='table sortable')}" | |
| 325 "<h2>Best Model's Hyperparameters</h2>" | |
| 326 f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}" | |
| 327 "<h2>Setup Parameters</h2>" | |
| 328 f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}" | |
| 329 "<h5>If you want to know all the experiment setup parameters," | |
| 330 " please check the PyCaret documentation for" | |
| 331 " the classification/regression <code>exp</code> function.</h5>" | |
| 332 "</div>" | |
| 333 '<div id="plots" class="tab-content">' | |
| 334 f"<h2>Best Model: {model_name}</h2>" | |
| 335 "<h5>The best model is selected by: Accuracy (Classification)" | |
| 336 " or R2 (Regression).</h5>" | |
| 337 "<h2>Test Metrics</h2>" | |
| 338 f"{self.test_result_df.to_html(index=False)}" | |
| 339 "<h2>Test Results</h2>" | |
| 340 f"{plots_html}" | |
| 341 "</div>" | |
| 342 '<div id="feature" class="tab-content">' | |
| 343 f"{feature_importance_html}" | |
| 344 "</div>" | |
| 345 ) | |
| 346 if self.plots_explainer_html: | |
| 347 html_content += ( | |
| 348 '<div id="explainer" class="tab-content">' | |
| 349 f"{self.plots_explainer_html}" | |
| 350 f"{tree_plots}" | |
| 351 "</div>" | |
| 352 ) | |
| 353 html_content += ( | |
| 354 "<script>" | |
| 355 "document.addEventListener(\"DOMContentLoaded\", function() {" | |
| 356 "var tables = document.querySelectorAll(\"table.sortable\");" | |
| 357 "tables.forEach(function(table) {" | |
| 358 "var headers = table.querySelectorAll(\"th\");" | |
| 359 "headers.forEach(function(header, index) {" | |
| 360 "header.style.cursor = \"pointer\";" | |
| 361 "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)" | |
| 362 "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';" | |
| 363 "header.addEventListener(\"click\", function() {" | |
| 364 "var direction = this.getAttribute(" | |
| 365 "\"data-sort-direction\"" | |
| 366 ") || \"asc\";" | |
| 367 "// Reset arrows in all headers of this table" | |
| 368 "headers.forEach(function(h) {" | |
| 369 "var arrow = h.querySelector(\".sort-arrow\");" | |
| 370 "if (arrow) arrow.textContent = \" ↑\";" | |
| 371 "});" | |
| 372 "// Set arrow for clicked header" | |
| 373 "var arrow = this.querySelector(\".sort-arrow\");" | |
| 374 "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";" | |
| 375 "sortTable(table, index, direction);" | |
| 376 "this.setAttribute(\"data-sort-direction\"," | |
| 377 "direction === \"asc\" ? \"desc\" : \"asc\");" | |
| 378 "});" | |
| 379 "});" | |
| 380 "});" | |
| 381 "});" | |
| 382 "function sortTable(table, colNum, direction) {" | |
| 383 "var tb = table.tBodies[0];" | |
| 384 "var tr = Array.prototype.slice.call(tb.rows, 0);" | |
| 385 "var multiplier = direction === \"asc\" ? 1 : -1;" | |
| 386 "tr = tr.sort(function(a, b) {" | |
| 387 "var aText = a.cells[colNum].textContent.trim();" | |
| 388 "var bText = b.cells[colNum].textContent.trim();" | |
| 389 "// Remove arrow from text comparison" | |
| 390 "aText = aText.replace(/[↑↓]/g, '').trim();" | |
| 391 "bText = bText.replace(/[↑↓]/g, '').trim();" | |
| 392 "if (!isNaN(aText) && !isNaN(bText)) {" | |
| 393 "return multiplier * (" | |
| 394 "parseFloat(aText) - parseFloat(bText)" | |
| 395 ");" | |
| 396 "} else {" | |
| 397 "return multiplier * aText.localeCompare(bText);" | |
| 398 "}" | |
| 399 "});" | |
| 400 "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);" | |
| 401 "}" | |
| 402 "</script>" | |
| 403 ) | |
| 404 # --- Add the Feature Metrics Help Modal --- | |
| 405 html_content += get_feature_metrics_help_modal() | |
| 406 html_content += f"{get_html_closing()}" | |
| 407 with open( | |
| 408 os.path.join(self.output_dir, "comparison_result.html"), | |
| 409 "w", | |
| 410 encoding="utf-8", | |
| 411 ) as file: | |
| 412 file.write(html_content) | |
| 413 | 508 |
| 414 def save_dashboard(self): | 509 def save_dashboard(self): |
| 415 raise NotImplementedError("Subclasses should implement this method") | 510 raise NotImplementedError("Subclasses should implement this method") |
| 416 | 511 |
| 417 def generate_plots_explainer(self): | 512 def generate_plots_explainer(self): |
| 424 | 519 |
| 425 LOG.info("Generating tree plots") | 520 LOG.info("Generating tree plots") |
| 426 X_test = self.exp.X_test_transformed.copy() | 521 X_test = self.exp.X_test_transformed.copy() |
| 427 y_test = self.exp.y_test_transformed | 522 y_test = self.exp.y_test_transformed |
| 428 | 523 |
| 429 is_rf = isinstance( | 524 if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)): |
| 430 self.best_model, (RandomForestClassifier, RandomForestRegressor) | 525 n_trees = self.best_model.n_estimators |
| 431 ) | 526 elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)): |
| 432 is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor)) | 527 n_trees = len(self.best_model.get_booster().get_dump()) |
| 433 | |
| 434 num_trees = None | |
| 435 if is_rf: | |
| 436 num_trees = self.best_model.n_estimators | |
| 437 elif is_xgb: | |
| 438 num_trees = len(self.best_model.get_booster().get_dump()) | |
| 439 else: | 528 else: |
| 440 LOG.warning("Tree plots not supported for this model type.") | 529 LOG.warning("Tree plots not supported for this model type.") |
| 441 return | 530 return |
| 442 | 531 |
| 443 try: | 532 explainer = RandomForestExplainer(self.best_model, X_test, y_test) |
| 444 explainer = RandomForestExplainer(self.best_model, X_test, y_test) | 533 for i in range(n_trees): |
| 445 for i in range(num_trees): | 534 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) |
| 446 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) | 535 self.trees.append(fig) |
| 447 LOG.info(f"Tree {i + 1}") | |
| 448 LOG.info(fig) | |
| 449 self.trees.append(fig) | |
| 450 except Exception as e: | |
| 451 LOG.error(f"Error generating tree plots: {e}") | |
| 452 | 536 |
| 453 def run(self): | 537 def run(self): |
| 454 self.load_data() | 538 self.load_data() |
| 455 self.setup_pycaret() | 539 self.setup_pycaret() |
| 456 self.train_model() | 540 self.train_model() |
