diff base_model_trainer.py @ 12:15707141e7da draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:12 +0000
parents 49f73a3c12f3
children
line wrap: on
line diff
--- a/base_model_trainer.py	Fri Nov 28 15:46:05 2025 +0000
+++ b/base_model_trainer.py	Fri Nov 28 22:28:12 2025 +0000
@@ -46,6 +46,7 @@
         self.results = None
         self.tuning_results = None
         self.features_name = None
+        self.plot_feature_names = None
         self.plots = {}
         self.explainer_plots = {}
         self.plots_explainer_html = None
@@ -53,6 +54,24 @@
         self.user_kwargs = kwargs.copy()
         for key, value in self.user_kwargs.items():
             setattr(self, key, value)
+        if not hasattr(self, "plot_feature_limit"):
+            self.plot_feature_limit = 30
+        self._shap_row_cap = None
+        if getattr(self, "polynomial_features", False):
+            # Keep feature importance responsive by trimming plots/SHAP rows
+            try:
+                limit_val = int(self.plot_feature_limit)
+            except (TypeError, ValueError):
+                limit_val = 30
+            self.plot_feature_limit = min(limit_val, 15)
+            self._shap_row_cap = 200
+            LOG.info(
+                "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s",
+                self.plot_feature_limit,
+                self._shap_row_cap,
+            )
+        self.imputed_training_data = None
+        self._best_model_metric_used = None
         self.setup_params = {}
         self.test_file = test_file
         self.test_data = None
@@ -127,23 +146,7 @@
         LOG.info(f"Dataset columns after processing: {names}")
 
         self.features_name = [n for n in names if n != self.target]
-
-        if getattr(self, "missing_value_strategy", None):
-            strat = self.missing_value_strategy
-            if strat == "mean":
-                self.data = self.data.fillna(
-                    self.data.mean(numeric_only=True)
-                )
-            elif strat == "median":
-                self.data = self.data.fillna(
-                    self.data.median(numeric_only=True)
-                )
-            elif strat == "drop":
-                self.data = self.data.dropna()
-        else:
-            self.data = self.data.fillna(
-                self.data.median(numeric_only=True)
-            )
+        self.plot_feature_names = self._select_plot_features(self.features_name)
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
@@ -153,6 +156,52 @@
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
 
+    def _select_plot_features(self, all_features):  # returns the features to use for feature-level plots
+        limit = getattr(self, "plot_feature_limit", 30)  # 30 when the attribute is absent
+        if not isinstance(limit, int) or limit <= 0:  # non-int or non-positive: no trimming
+            LOG.info(
+                "Feature plotting limit disabled (plot_feature_limit=%s).", limit
+            )
+            return all_features  # same list object, not a copy
+        if len(all_features) <= limit:  # already within the limit: nothing to trim
+            LOG.info(
+                "Feature plotting limit not needed (%s features <= limit %s).",
+                len(all_features),
+                limit,
+            )
+            return all_features
+        df = self.data[all_features].copy()  # assumes self.data is a DataFrame with these columns - TODO confirm
+        numeric_cols = df.select_dtypes(include=["number"]).columns  # only numeric columns are ranked
+        ranked = []
+        if len(numeric_cols) > 0:
+            variances = (
+                df[numeric_cols]
+                .var()  # per-column variance
+                .fillna(0)  # undefined (NaN) variance is treated as 0
+                .abs()
+                .sort_values(ascending=False)  # highest variance first
+            )
+            ranked = variances.index.tolist()
+        selected = []
+        for col in ranked:  # take ranked numeric columns first, up to limit
+            if len(selected) >= limit:
+                break
+            selected.append(col)
+        if len(selected) < limit:  # top up with remaining features in their original order
+            for col in all_features:
+                if col in selected:  # skip columns already chosen
+                    continue
+                selected.append(col)
+                if len(selected) >= limit:
+                    break
+        LOG.info(
+            "Limiting feature-level plots to %s of %s available features (limit=%s).",
+            len(selected),
+            len(all_features),
+            limit,
+        )
+        return selected  # ranked numeric columns first, then fill-ins
+
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
         self.setup_params = {
@@ -198,29 +247,41 @@
             )
 
         self.exp.setup(self.data, **self.setup_params)
+        self._capture_imputed_training_data()
         self.setup_params.update(self.user_kwargs)
 
-    def _normalize_metric(self, m: str) -> str:
-        if not m:
-            return "R2" if self.task_type == "regression" else "Accuracy"
-        m_low = str(m).strip().lower()
-        alias = {
-            "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
-            "accuracy": "Accuracy",
-            "precision": "Precision",
-            "recall": "Recall",
-            "f1": "F1",
-            "kappa": "Kappa",
-            "logloss": "Log Loss", "log_loss": "Log Loss",
-            "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
-            "r2": "R2",
-            "mae": "MAE",
-            "mse": "MSE",
-            "rmse": "RMSE",
-            "rmsle": "RMSLE",
-            "mape": "MAPE",
-        }
-        return alias.get(m_low, m)
+    def _capture_imputed_training_data(self):
+        """
+        Store PyCaret's "X_transformed" config joined column-wise with its "y"
+        config (named self.target) in self.imputed_training_data. No-op when
+        self.exp is None; on any error, log a warning and store None instead.
+        """
+        if self.exp is None:  # no experiment object to read from
+            return
+        try:
+            X_processed = self.exp.get_config("X_transformed").copy()  # copy so the cached frame is independent
+            y_processed = self.exp.get_config("y")  # NOTE(review): "y" paired with "X_transformed" - confirm rows align
+            if isinstance(y_processed, pd.Series):
+                y_series = y_processed.reset_index(drop=True)
+            else:
+                y_series = pd.Series(y_processed)  # wrap non-Series targets
+            y_series.name = self.target  # target column name in the combined frame
+            X_processed = X_processed.reset_index(drop=True)  # both sides now share a 0..n-1 index
+            self.imputed_training_data = pd.concat(
+                [X_processed, y_series], axis=1
+            )
+            LOG.info(
+                "Captured imputed training dataset from PyCaret "
+                "(%s rows, %s features).",
+                self.imputed_training_data.shape[0],
+                self.imputed_training_data.shape[1] - 1,  # exclude the target column from the count
+            )
+        except Exception as exc:  # best-effort: a failure leaves the cache empty
+            LOG.warning(
+                "Unable to capture processed training data from PyCaret: %s",
+                exc,
+            )
+            self.imputed_training_data = None
 
     def train_model(self):
         LOG.info("Training and selecting the best model")
@@ -245,17 +306,16 @@
         if getattr(self, "cross_validation_folds", None) is not None:
             compare_kwargs["fold"] = self.cross_validation_folds
 
-        chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
-        if chosen_metric:
-            compare_kwargs["sort"] = chosen_metric
-            self.chosen_metric_label = chosen_metric
-            try:
-                setattr(self.exp, "_fold_metric", chosen_metric)
-            except Exception as e:
-                LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
+        best_metric = getattr(self, "best_model_metric", None)
+        if best_metric:
+            compare_kwargs["sort"] = best_metric
+            self._best_model_metric_used = best_metric
+            LOG.info(f"Ranking models using metric: {best_metric}")
 
         LOG.info(f"compare_models kwargs: {compare_kwargs}")
         self.best_model = self.exp.compare_models(**compare_kwargs)
+        if self._best_model_metric_used is None:
+            self._best_model_metric_used = getattr(self.exp, "_fold_metric", None)
         self.results = self.exp.pull()
         if getattr(self, "tune_model", False):
             LOG.info("Tuning hyperparameters of the best model")
@@ -327,6 +387,31 @@
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")
 
+    def _resolve_plot_callable(self, key, fig_or_fn, section):
+        """
+        Return fig_or_fn() if it is callable, else fig_or_fn itself; on any
+        exception log a warning (naming section and key) and return None.
+        """
+        if fig_or_fn is None:  # nothing stored for this key
+            return None
+        try:
+            return fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+        except Exception as exc:  # broad on purpose: one failing plot is skipped, not fatal
+            extra = ""
+            if isinstance(exc, ValueError) and "Input contains NaN" in str(exc):  # add a hint for this message
+                extra = (
+                    " (model returned NaN probabilities; "
+                    "consider checking data preprocessing)"
+                )
+            LOG.warning(
+                "Skipping %s plot '%s' due to error: %s%s",
+                section,
+                key,
+                exc,
+                extra,
+            )
+            return None  # callers in this diff treat None as "skip this plot"
+
     def save_html_report(self):
         LOG.info("Saving HTML report")
 
@@ -401,8 +486,11 @@
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
-        if getattr(self, "chosen_metric_label", None):
-            setup_rows.append(["Best Model Metric", self.chosen_metric_label])
+        metric_label = self._best_model_metric_used or getattr(
+            self.exp, "_fold_metric", None
+        )
+        if metric_label:
+            setup_rows.append(["Best Model Metric", metric_label])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
         df_setup.to_csv(
@@ -564,13 +652,16 @@
                 "roc_auc",
                 "pr_auc",
                 "lift_curve",
-                "threshold",
                 "cumulative_precision",
             ]
         for key in test_order:
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
-                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                fig = self._resolve_plot_callable(
+                    key, fig_or_fn, section="test/explainer"
+                )
+                if fig is None:
+                    continue
                 title = plot_title_map.get(
                     key, key.replace("_", " ").title()
                 )
@@ -584,7 +675,6 @@
             # skipping anything
             if self.task_type == "classification" and (
                 name in {
-                    "threshold",
                     "pr_auc",
                     "class_report",
                 }
@@ -630,20 +720,57 @@
         feature_html = header
 
         # 6a) PyCaret’s default feature importances
-        feature_html += FeatureImportanceAnalyzer(
-            data=self.data,
+        imputed_data = (
+            self.imputed_training_data
+            if self.imputed_training_data is not None
+            else self.data
+        )
+        fi_analyzer = FeatureImportanceAnalyzer(
+            data=imputed_data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
             exp=self.exp,
             best_model=self.best_model,
-        ).run()
+            max_plot_features=self.plot_feature_limit,
+            processed_data=self.imputed_training_data,
+            max_shap_rows=self._shap_row_cap,
+        )
+        fi_html = fi_analyzer.run()
+        # Add a small table to show SHAP feature caps near the Best Model header.
+        cap_rows = []
+        if fi_analyzer.shap_total_features is not None:
+            cap_rows.append(
+                ("Total transformed features", fi_analyzer.shap_total_features)
+            )
+        if fi_analyzer.shap_used_features is not None:
+            cap_rows.append(
+                ("Features used in SHAP", fi_analyzer.shap_used_features)
+            )
+        if cap_rows:
+            cap_table = (
+                "<div class='table-wrapper'>"
+                "<table class='table sortable'>"
+                "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>"
+                "<tbody>"
+                + "".join(
+                    f"<tr><td>{label}</td><td>{value}</td></tr>"
+                    for label, value in cap_rows
+                )
+                + "</tbody></table></div>"
+            )
+            feature_html += cap_table
+        feature_html += fi_html
 
         # 6b) Explainer SHAP importances
         for key in ["shap_mean", "shap_perm"]:
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
-                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                fig = self._resolve_plot_callable(
+                    key, fig_or_fn, section="feature importance"
+                )
+                if fig is None:
+                    continue
                 # give SHAP plots explicit titles
                 title = (
                     "Mean Absolute SHAP Value Impact"
@@ -661,7 +788,11 @@
         )
         for k in pdp_keys:
             fig_or_fn = self.explainer_plots[k]
-            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+            fig = self._resolve_plot_callable(
+                k, fig_or_fn, section="pdp"
+            )
+            if fig is None:
+                continue
             # extract feature name
             feature = k.split("__", 1)[1]
             title = f"Partial Dependence for {feature}"