Mercurial > repos > goeckslab > tabular_learner
diff base_model_trainer.py @ 12:15707141e7da draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 22:28:12 +0000 |
| parents | 49f73a3c12f3 |
| children |
line wrap: on
line diff
--- a/base_model_trainer.py Fri Nov 28 15:46:05 2025 +0000 +++ b/base_model_trainer.py Fri Nov 28 22:28:12 2025 +0000 @@ -46,6 +46,7 @@ self.results = None self.tuning_results = None self.features_name = None + self.plot_feature_names = None self.plots = {} self.explainer_plots = {} self.plots_explainer_html = None @@ -53,6 +54,24 @@ self.user_kwargs = kwargs.copy() for key, value in self.user_kwargs.items(): setattr(self, key, value) + if not hasattr(self, "plot_feature_limit"): + self.plot_feature_limit = 30 + self._shap_row_cap = None + if getattr(self, "polynomial_features", False): + # Keep feature importance responsive by trimming plots/SHAP rows + try: + limit_val = int(self.plot_feature_limit) + except (TypeError, ValueError): + limit_val = 30 + self.plot_feature_limit = min(limit_val, 15) + self._shap_row_cap = 200 + LOG.info( + "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s", + self.plot_feature_limit, + self._shap_row_cap, + ) + self.imputed_training_data = None + self._best_model_metric_used = None self.setup_params = {} self.test_file = test_file self.test_data = None @@ -127,23 +146,7 @@ LOG.info(f"Dataset columns after processing: {names}") self.features_name = [n for n in names if n != self.target] - - if getattr(self, "missing_value_strategy", None): - strat = self.missing_value_strategy - if strat == "mean": - self.data = self.data.fillna( - self.data.mean(numeric_only=True) - ) - elif strat == "median": - self.data = self.data.fillna( - self.data.median(numeric_only=True) - ) - elif strat == "drop": - self.data = self.data.dropna() - else: - self.data = self.data.fillna( - self.data.median(numeric_only=True) - ) + self.plot_feature_names = self._select_plot_features(self.features_name) if self.test_file: LOG.info(f"Loading test data from {self.test_file}") @@ -153,6 +156,52 @@ df_test.columns = df_test.columns.str.replace(".", "_") self.test_data = df_test + def _select_plot_features(self, all_features): + limit = getattr(self, "plot_feature_limit", 30) + if not isinstance(limit, int) or limit <= 0: + LOG.info( + "Feature plotting limit disabled (plot_feature_limit=%s).", limit + ) + return all_features + if len(all_features) <= limit: + LOG.info( + "Feature plotting limit not needed (%s features <= limit %s).", + len(all_features), + limit, + ) + return all_features + df = self.data[all_features].copy() + numeric_cols = df.select_dtypes(include=["number"]).columns + ranked = [] + if len(numeric_cols) > 0: + variances = ( + df[numeric_cols] + .var() + .fillna(0) + .abs() + .sort_values(ascending=False) + ) + ranked = variances.index.tolist() + selected = [] + for col in ranked: + if len(selected) >= limit: + break + selected.append(col) + if len(selected) < limit: + for col in all_features: + if col in selected: + continue + selected.append(col) + if len(selected) >= limit: + break + LOG.info( + "Limiting feature-level plots to %s of %s available features (limit=%s).", + len(selected), + len(all_features), + limit, + ) + return selected + def setup_pycaret(self): LOG.info("Initializing PyCaret") self.setup_params = { @@ -198,29 +247,41 @@ ) self.exp.setup(self.data, **self.setup_params) + self._capture_imputed_training_data() self.setup_params.update(self.user_kwargs) - def _normalize_metric(self, m: str) -> str: - if not m: - return "R2" if self.task_type == "regression" else "Accuracy" - m_low = str(m).strip().lower() - alias = { - "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC", - "accuracy": "Accuracy", - "precision": "Precision", - "recall": "Recall", - "f1": "F1", - "kappa": "Kappa", - "logloss": "Log Loss", "log_loss": "Log Loss", - "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted", - "r2": "R2", - "mae": "MAE", - "mse": "MSE", - "rmse": "RMSE", - "rmsle": "RMSLE", - "mape": "MAPE", - } - return alias.get(m_low, m) + def _capture_imputed_training_data(self): + """ + Cache the dataset as transformed/imputed by PyCaret so downstream + components (e.g., feature importance) can operate on the exact data + used for training. + """ + if self.exp is None: + return + try: + X_processed = self.exp.get_config("X_transformed").copy() + y_processed = self.exp.get_config("y") + if isinstance(y_processed, pd.Series): + y_series = y_processed.reset_index(drop=True) + else: + y_series = pd.Series(y_processed) + y_series.name = self.target + X_processed = X_processed.reset_index(drop=True) + self.imputed_training_data = pd.concat( + [X_processed, y_series], axis=1 + ) + LOG.info( + "Captured imputed training dataset from PyCaret " + "(%s rows, %s features).", + self.imputed_training_data.shape[0], + self.imputed_training_data.shape[1] - 1, + ) + except Exception as exc: + LOG.warning( + "Unable to capture processed training data from PyCaret: %s", + exc, + ) + self.imputed_training_data = None def train_model(self): LOG.info("Training and selecting the best model") @@ -245,17 +306,16 @@ if getattr(self, "cross_validation_folds", None) is not None: compare_kwargs["fold"] = self.cross_validation_folds - chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None)) - if chosen_metric: - compare_kwargs["sort"] = chosen_metric - self.chosen_metric_label = chosen_metric - try: - setattr(self.exp, "_fold_metric", chosen_metric) - except Exception as e: - LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True) + best_metric = getattr(self, "best_model_metric", None) + if best_metric: + compare_kwargs["sort"] = best_metric + self._best_model_metric_used = best_metric + LOG.info(f"Ranking models using metric: {best_metric}") LOG.info(f"compare_models kwargs: {compare_kwargs}") self.best_model = self.exp.compare_models(**compare_kwargs) + if self._best_model_metric_used is None: + self._best_model_metric_used = getattr(self.exp, "_fold_metric", None) self.results = self.exp.pull() if getattr(self, "tune_model", False): LOG.info("Tuning hyperparameters of the best model") @@ -327,6 +387,31 @@ with open(img_path, "rb") as img_file: return base64.b64encode(img_file.read()).decode("utf-8") + def _resolve_plot_callable(self, key, fig_or_fn, section): + """ + Safely execute stored plot callables so a single failure does not + abort the entire HTML report generation. + """ + if fig_or_fn is None: + return None + try: + return fig_or_fn() if callable(fig_or_fn) else fig_or_fn + except Exception as exc: + extra = "" + if isinstance(exc, ValueError) and "Input contains NaN" in str(exc): + extra = ( + " (model returned NaN probabilities; " + "consider checking data preprocessing)" + ) + LOG.warning( + "Skipping %s plot '%s' due to error: %s%s", + section, + key, + exc, + extra, + ) + return None + def save_html_report(self): LOG.info("Saving HTML report") @@ -401,8 +486,11 @@ else: dv = v if v is not None else "None" setup_rows.append([key, dv]) - if getattr(self, "chosen_metric_label", None): - setup_rows.append(["Best Model Metric", self.chosen_metric_label]) + metric_label = self._best_model_metric_used or getattr( + self.exp, "_fold_metric", None + ) + if metric_label: + setup_rows.append(["Best Model Metric", metric_label]) df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) df_setup.to_csv( @@ -564,13 +652,16 @@ "roc_auc", "pr_auc", "lift_curve", - "threshold", "cumulative_precision", ] for key in test_order: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: - fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + fig = self._resolve_plot_callable( + key, fig_or_fn, section="test/explainer" + ) + if fig is None: + continue title = plot_title_map.get( key, key.replace("_", " ").title() ) @@ -584,7 +675,6 @@ # skipping anything if self.task_type == "classification" and ( name in { - "threshold", "pr_auc", "class_report", } @@ -630,20 +720,57 @@ feature_html = header # 6a) PyCaret’s default feature importances - feature_html += FeatureImportanceAnalyzer( - data=self.data, + imputed_data = ( + self.imputed_training_data + if self.imputed_training_data is not None + else self.data + ) + fi_analyzer = FeatureImportanceAnalyzer( + data=imputed_data, target_col=self.target_col, task_type=self.task_type, output_dir=self.output_dir, exp=self.exp, best_model=self.best_model, - ).run() + max_plot_features=self.plot_feature_limit, + processed_data=self.imputed_training_data, + max_shap_rows=self._shap_row_cap, + ) + fi_html = fi_analyzer.run() + # Add a small table to show SHAP feature caps near the Best Model header. + cap_rows = [] + if fi_analyzer.shap_total_features is not None: + cap_rows.append( + ("Total transformed features", fi_analyzer.shap_total_features) + ) + if fi_analyzer.shap_used_features is not None: + cap_rows.append( + ("Features used in SHAP", fi_analyzer.shap_used_features) + ) + if cap_rows: + cap_table = ( + "<div class='table-wrapper'>" + "<table class='table sortable'>" + "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" + "<tbody>" + + "".join( + f"<tr><td>{label}</td><td>{value}</td></tr>" + for label, value in cap_rows + ) + + "</tbody></table></div>" + ) + feature_html += cap_table + feature_html += fi_html # 6b) Explainer SHAP importances for key in ["shap_mean", "shap_perm"]: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: - fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + fig = self._resolve_plot_callable( + key, fig_or_fn, section="feature importance" + ) + if fig is None: + continue # give SHAP plots explicit titles title = ( "Mean Absolute SHAP Value Impact" @@ -661,7 +788,11 @@ ) for k in pdp_keys: fig_or_fn = self.explainer_plots[k] - fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + fig = self._resolve_plot_callable( + k, fig_or_fn, section="pdp" + ) + if fig is None: + continue # extract feature name feature = k.split("__", 1)[1] title = f"Partial Dependence for {feature}"
