changeset 12:15707141e7da draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:12 +0000
parents a76dfceb62e0
children
files base_model_trainer.py feature_importance.py pycaret_macros.xml
diffstat 3 files changed, 475 insertions(+), 100 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Fri Nov 28 15:46:05 2025 +0000
+++ b/base_model_trainer.py	Fri Nov 28 22:28:12 2025 +0000
@@ -46,6 +46,7 @@
         self.results = None
         self.tuning_results = None
         self.features_name = None
+        self.plot_feature_names = None
         self.plots = {}
         self.explainer_plots = {}
         self.plots_explainer_html = None
@@ -53,6 +54,24 @@
         self.user_kwargs = kwargs.copy()
         for key, value in self.user_kwargs.items():
             setattr(self, key, value)
+        if not hasattr(self, "plot_feature_limit"):
+            self.plot_feature_limit = 30
+        self._shap_row_cap = None
+        if getattr(self, "polynomial_features", False):
+            # Keep feature importance responsive by trimming plots/SHAP rows
+            try:
+                limit_val = int(self.plot_feature_limit)
+            except (TypeError, ValueError):
+                limit_val = 30
+            self.plot_feature_limit = min(limit_val, 15)
+            self._shap_row_cap = 200
+            LOG.info(
+                "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s",
+                self.plot_feature_limit,
+                self._shap_row_cap,
+            )
+        self.imputed_training_data = None
+        self._best_model_metric_used = None
         self.setup_params = {}
         self.test_file = test_file
         self.test_data = None
@@ -127,23 +146,7 @@
         LOG.info(f"Dataset columns after processing: {names}")
 
         self.features_name = [n for n in names if n != self.target]
-
-        if getattr(self, "missing_value_strategy", None):
-            strat = self.missing_value_strategy
-            if strat == "mean":
-                self.data = self.data.fillna(
-                    self.data.mean(numeric_only=True)
-                )
-            elif strat == "median":
-                self.data = self.data.fillna(
-                    self.data.median(numeric_only=True)
-                )
-            elif strat == "drop":
-                self.data = self.data.dropna()
-        else:
-            self.data = self.data.fillna(
-                self.data.median(numeric_only=True)
-            )
+        self.plot_feature_names = self._select_plot_features(self.features_name)
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
@@ -153,6 +156,52 @@
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
 
+    def _select_plot_features(self, all_features):
+        """
+        Choose which features get feature-level plots.
+
+        Returns ``all_features`` unchanged when ``plot_feature_limit`` is not
+        a positive int or already covers every feature. Otherwise returns at
+        most ``limit`` names: numeric columns of ``self.data`` ordered by
+        descending variance first, then remaining features in input order.
+        """
+        limit = getattr(self, "plot_feature_limit", 30)
+        if not isinstance(limit, int) or limit <= 0:
+            LOG.info(
+                "Feature plotting limit disabled (plot_feature_limit=%s).", limit
+            )
+            return all_features
+        total = len(all_features)
+        if total <= limit:
+            LOG.info(
+                "Feature plotting limit not needed (%s features <= limit %s).",
+                total,
+                limit,
+            )
+            return all_features
+        # Column selection already yields a new frame; no extra copy needed.
+        numeric = self.data[all_features].select_dtypes(include=["number"])
+        selected = []
+        if numeric.shape[1] > 0:
+            # NaN variance (e.g. all-missing column) ranks as zero variance.
+            variances = numeric.var().fillna(0).sort_values(ascending=False)
+            selected = variances.index.tolist()[:limit]
+        if len(selected) < limit:
+            # Top up with non-numeric features, preserving input order.
+            chosen = set(selected)
+            for col in all_features:
+                if col in chosen:
+                    continue
+                selected.append(col)
+                chosen.add(col)
+                if len(selected) >= limit:
+                    break
+        LOG.info(
+            "Limiting feature-level plots to %s of %s available features (limit=%s).",
+            len(selected), total, limit,
+        )
+        return selected
+
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
         self.setup_params = {
@@ -198,29 +247,41 @@
             )
 
         self.exp.setup(self.data, **self.setup_params)
+        self._capture_imputed_training_data()
         self.setup_params.update(self.user_kwargs)
 
-    def _normalize_metric(self, m: str) -> str:
-        if not m:
-            return "R2" if self.task_type == "regression" else "Accuracy"
-        m_low = str(m).strip().lower()
-        alias = {
-            "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
-            "accuracy": "Accuracy",
-            "precision": "Precision",
-            "recall": "Recall",
-            "f1": "F1",
-            "kappa": "Kappa",
-            "logloss": "Log Loss", "log_loss": "Log Loss",
-            "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
-            "r2": "R2",
-            "mae": "MAE",
-            "mse": "MSE",
-            "rmse": "RMSE",
-            "rmsle": "RMSLE",
-            "mape": "MAPE",
-        }
-        return alias.get(m_low, m)
+    def _capture_imputed_training_data(self):
+        """
+        Cache PyCaret's transformed feature matrix joined with the target.
+
+        Sets ``self.imputed_training_data`` to a positionally aligned frame,
+        or to ``None`` (with a warning) if the data cannot be captured or the
+        row counts differ, so callers fall back to the raw dataset.
+        """
+        if self.exp is None:
+            return
+        try:
+            X_processed = self.exp.get_config("X_transformed").reset_index(drop=True)
+            y_series = pd.Series(self.exp.get_config("y")).reset_index(drop=True)
+            y_series.name = self.target
+            # Rows are paired by position; unequal lengths would silently
+            # mislabel rows and pad the shorter side with NaN in concat.
+            if len(X_processed) != len(y_series):
+                raise ValueError(
+                    f"row count mismatch: X={len(X_processed)}, y={len(y_series)}"
+                )
+            self.imputed_training_data = pd.concat([X_processed, y_series], axis=1)
+            LOG.info(
+                "Captured imputed training dataset from PyCaret "
+                "(%s rows, %s features).",
+                self.imputed_training_data.shape[0],
+                self.imputed_training_data.shape[1] - 1,
+            )
+        except Exception as exc:
+            LOG.warning(
+                "Unable to capture processed training data from PyCaret: %s", exc
+            )
+            self.imputed_training_data = None
 
     def train_model(self):
         LOG.info("Training and selecting the best model")
@@ -245,17 +306,16 @@
         if getattr(self, "cross_validation_folds", None) is not None:
             compare_kwargs["fold"] = self.cross_validation_folds
 
-        chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
-        if chosen_metric:
-            compare_kwargs["sort"] = chosen_metric
-            self.chosen_metric_label = chosen_metric
-            try:
-                setattr(self.exp, "_fold_metric", chosen_metric)
-            except Exception as e:
-                LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
+        best_metric = getattr(self, "best_model_metric", None)
+        if best_metric:
+            compare_kwargs["sort"] = best_metric
+            self._best_model_metric_used = best_metric
+            LOG.info(f"Ranking models using metric: {best_metric}")
 
         LOG.info(f"compare_models kwargs: {compare_kwargs}")
         self.best_model = self.exp.compare_models(**compare_kwargs)
+        if self._best_model_metric_used is None:
+            self._best_model_metric_used = getattr(self.exp, "_fold_metric", None)
         self.results = self.exp.pull()
         if getattr(self, "tune_model", False):
             LOG.info("Tuning hyperparameters of the best model")
@@ -327,6 +387,31 @@
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")
 
+    def _resolve_plot_callable(self, key, fig_or_fn, section):
+        """
+        Return the figure for a stored plot entry, invoking it when callable.
+
+        Any exception is logged as a warning and turned into ``None`` so one
+        broken plot does not abort HTML report generation.
+        """
+        if fig_or_fn is None:
+            return None
+        if not callable(fig_or_fn):
+            return fig_or_fn
+        try:
+            return fig_or_fn()
+        except Exception as err:
+            nan_input = isinstance(err, ValueError) and "Input contains NaN" in str(err)
+            hint = (
+                " (model returned NaN probabilities; "
+                "consider checking data preprocessing)"
+            ) if nan_input else ""
+            LOG.warning(
+                "Skipping %s plot '%s' due to error: %s%s",
+                section, key, err, hint,
+            )
+            return None
+
     def save_html_report(self):
         LOG.info("Saving HTML report")
 
@@ -401,8 +486,11 @@
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
-        if getattr(self, "chosen_metric_label", None):
-            setup_rows.append(["Best Model Metric", self.chosen_metric_label])
+        metric_label = self._best_model_metric_used or getattr(
+            self.exp, "_fold_metric", None
+        )
+        if metric_label:
+            setup_rows.append(["Best Model Metric", metric_label])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
         df_setup.to_csv(
@@ -564,13 +652,16 @@
                 "roc_auc",
                 "pr_auc",
                 "lift_curve",
-                "threshold",
                 "cumulative_precision",
             ]
         for key in test_order:
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
-                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                fig = self._resolve_plot_callable(
+                    key, fig_or_fn, section="test/explainer"
+                )
+                if fig is None:
+                    continue
                 title = plot_title_map.get(
                     key, key.replace("_", " ").title()
                 )
@@ -584,7 +675,6 @@
             # skipping anything
             if self.task_type == "classification" and (
                 name in {
-                    "threshold",
                     "pr_auc",
                     "class_report",
                 }
@@ -630,20 +720,57 @@
         feature_html = header
 
         # 6a) PyCaret’s default feature importances
-        feature_html += FeatureImportanceAnalyzer(
-            data=self.data,
+        imputed_data = (
+            self.imputed_training_data
+            if self.imputed_training_data is not None
+            else self.data
+        )
+        fi_analyzer = FeatureImportanceAnalyzer(
+            data=imputed_data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
             exp=self.exp,
             best_model=self.best_model,
-        ).run()
+            max_plot_features=self.plot_feature_limit,
+            processed_data=self.imputed_training_data,
+            max_shap_rows=self._shap_row_cap,
+        )
+        fi_html = fi_analyzer.run()
+        # Add a small table to show SHAP feature caps near the Best Model header.
+        cap_rows = []
+        if fi_analyzer.shap_total_features is not None:
+            cap_rows.append(
+                ("Total transformed features", fi_analyzer.shap_total_features)
+            )
+        if fi_analyzer.shap_used_features is not None:
+            cap_rows.append(
+                ("Features used in SHAP", fi_analyzer.shap_used_features)
+            )
+        if cap_rows:
+            cap_table = (
+                "<div class='table-wrapper'>"
+                "<table class='table sortable'>"
+                "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>"
+                "<tbody>"
+                + "".join(
+                    f"<tr><td>{label}</td><td>{value}</td></tr>"
+                    for label, value in cap_rows
+                )
+                + "</tbody></table></div>"
+            )
+            feature_html += cap_table
+        feature_html += fi_html
 
         # 6b) Explainer SHAP importances
         for key in ["shap_mean", "shap_perm"]:
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
-                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                fig = self._resolve_plot_callable(
+                    key, fig_or_fn, section="feature importance"
+                )
+                if fig is None:
+                    continue
                 # give SHAP plots explicit titles
                 title = (
                     "Mean Absolute SHAP Value Impact"
@@ -661,7 +788,11 @@
         )
         for k in pdp_keys:
             fig_or_fn = self.explainer_plots[k]
-            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+            fig = self._resolve_plot_callable(
+                k, fig_or_fn, section="pdp"
+            )
+            if fig is None:
+                continue
             # extract feature name
             feature = k.split("__", 1)[1]
             title = f"Partial Dependence for {feature}"
--- a/feature_importance.py	Fri Nov 28 15:46:05 2025 +0000
+++ b/feature_importance.py	Fri Nov 28 22:28:12 2025 +0000
@@ -22,11 +22,23 @@
         target_col=None,
         exp=None,
         best_model=None,
+        max_plot_features=None,
+        processed_data=None,
+        max_shap_rows=None,
     ):
         self.task_type = task_type
         self.output_dir = output_dir
         self.exp = exp
         self.best_model = best_model
+        self._skip_messages = []
+        self.shap_total_features = None
+        self.shap_used_features = None
+        if isinstance(max_plot_features, int) and max_plot_features > 0:
+            self.max_plot_features = max_plot_features
+        elif max_plot_features is None:
+            self.max_plot_features = 30
+        else:
+            self.max_plot_features = None
 
         if exp is not None:
             # Assume all configs (data, target) are in exp
@@ -48,8 +60,55 @@
                 if task_type == "classification"
                 else RegressionExperiment()
             )
+        if processed_data is not None:
+            self.data = processed_data
 
         self.plots = {}
+        self.max_shap_rows = max_shap_rows
+
+    def _get_feature_names_from_model(self, model):
+        """Return the first ``feature_names_in_`` found on the model or its steps."""
+        if model is None:
+            return None
+
+        # Pipeline-like objects: inspect the wrapper first, then each step.
+        if hasattr(model, "named_steps"):
+            steps = list(model.named_steps.values())
+        elif hasattr(model, "steps"):
+            steps = [step for _, step in model.steps]
+        else:
+            steps = []
+
+        fitted = (getattr(obj, "feature_names_in_", None) for obj in [model, *steps])
+        names = next((n for n in fitted if n is not None), None)
+        return None if names is None else list(names)
+
+    def _get_transformed_frame(self, model=None, prefer_test=True):
+        """
+        Return a DataFrame mirroring the matrix fed to the estimator, or None.
+
+        Non-DataFrame matrices get the model's feature names, else ``f{i}``.
+        """
+        key_order = ["X_test_transformed", "X_train_transformed"]
+        if not prefer_test:
+            key_order.reverse()
+        key_order.append("X_transformed")
+        feature_names = self._get_feature_names_from_model(model)
+        for key in key_order:
+            try:
+                frame = self.exp.get_config(key)
+            except (KeyError, ValueError):  # NOTE(review): PyCaret 3.x appears to raise ValueError for unknown names - confirm
+                continue
+            if isinstance(frame, pd.DataFrame):
+                return frame.copy()
+            try:
+                n_features = frame.shape[1]
+            except Exception:
+                continue  # None, 1-D, or shapeless objects cannot be wrapped
+            if not feature_names or len(feature_names) != n_features:
+                feature_names = [f"f{i}" for i in range(n_features)]
+            return pd.DataFrame(frame, columns=feature_names)
+        return None
 
     def setup_pycaret(self):
         if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
@@ -67,7 +126,14 @@
 
     def save_tree_importance(self):
         model = self.best_model or self.exp.get_config("best_model")
-        processed_features = self.exp.get_config("X_transformed").columns
+        processed_frame = self._get_transformed_frame(model, prefer_test=False)
+        if processed_frame is None:
+            LOG.warning(
+                "Unable to determine transformed feature names; skipping tree importance plot."
+            )
+            self.tree_model_name = None
+            return
+        processed_features = list(processed_frame.columns)
 
         importances = None
         model_type = model.__class__.__name__
@@ -85,20 +151,42 @@
             return
 
         if len(importances) != len(processed_features):
-            LOG.warning(
-                f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
-            )
-            self.tree_model_name = None
-            return
+            model_feature_names = self._get_feature_names_from_model(model)
+            if model_feature_names and len(model_feature_names) == len(importances):
+                processed_features = model_feature_names
+            else:
+                LOG.warning(
+                    "Importances (%s) != features (%s). Skipping tree importance.",
+                    len(importances),
+                    len(processed_features),
+                )
+                self.tree_model_name = None
+                return
 
         feature_importances = pd.DataFrame(
             {"Feature": processed_features, "Importance": importances}
         ).sort_values(by="Importance", ascending=False)
+        cap = (
+            min(self.max_plot_features, len(feature_importances))
+            if self.max_plot_features is not None
+            else len(feature_importances)
+        )
+        plot_importances = feature_importances.head(cap)
+        if cap < len(feature_importances):
+            LOG.info(
+                "Tree importance plot limited to top %s of %s features",
+                cap,
+                len(feature_importances),
+            )
         plt.figure(figsize=(10, 6))
-        plt.barh(feature_importances["Feature"], feature_importances["Importance"])
+        plt.barh(
+            plot_importances["Feature"],
+            plot_importances["Importance"],
+        )
         plt.xlabel("Importance")
-        plt.title(f"Feature Importance ({model_type})")
+        plt.title(f"Feature Importance ({model_type}) (top {cap})")
         plot_path = os.path.join(self.output_dir, "tree_importance.png")
+        plt.tight_layout()
         plt.savefig(plot_path, bbox_inches="tight")
         plt.close()
         self.plots["tree_importance"] = plot_path
@@ -106,23 +194,22 @@
     def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
         model = self.best_model or self.exp.get_config("best_model")
 
-        X_data = None
-        for key in ("X_test_transformed", "X_train_transformed"):
-            try:
-                X_data = self.exp.get_config(key)
-                break
-            except KeyError:
-                continue
+        X_data = self._get_transformed_frame(model)
         if X_data is None:
             raise RuntimeError("No transformed dataset found for SHAP.")
 
-        # --- Adaptive feature limiting (proportional cap) ---
         n_rows, n_features = X_data.shape
+        self.shap_total_features = n_features
+        feature_cap = (
+            min(self.max_plot_features, n_features)
+            if self.max_plot_features is not None
+            else n_features
+        )
         if max_features is None:
-            if n_features <= 200:
-                max_features = n_features
-            else:
-                max_features = min(200, max(20, int(n_features * 0.1)))
+            max_features = feature_cap
+        else:
+            max_features = min(max_features, feature_cap)
+        display_features = list(X_data.columns)
 
         try:
             if hasattr(model, "feature_importances_"):
@@ -138,15 +225,35 @@
                 variances = X_data.var()
                 top_features = variances.nlargest(max_features).index
 
-            if len(top_features) < n_features:
+            candidate_features = list(top_features)
+            missing = [f for f in candidate_features if f not in X_data.columns]
+            display_features = [f for f in candidate_features if f in X_data.columns]
+            if missing:
+                LOG.warning(
+                    "Dropping %s transformed feature(s) not present in SHAP frame: %s",
+                    len(missing),
+                    missing[:5],
+                )
+            if display_features and len(display_features) < n_features:
                 LOG.info(
-                    f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
+                    "Restricting SHAP display to top %s of %s features",
+                    len(display_features),
+                    n_features,
                 )
-            X_data = X_data[top_features]
+            elif not display_features:
+                display_features = list(X_data.columns)
         except Exception as e:
             LOG.warning(
                 f"Feature limiting failed: {e}. Using all {n_features} features."
             )
+            display_features = list(X_data.columns)
+
+        self.shap_used_features = len(display_features)
+
+        # Apply the column restriction so SHAP only runs on the selected features.
+        if display_features:
+            X_data = X_data[display_features]
+            n_rows, n_features = X_data.shape
 
         # --- Adaptive row subsampling ---
         if max_samples is None:
@@ -157,18 +264,26 @@
             else:
                 max_samples = min(1000, int(n_rows * 0.1))
 
+        if self.max_shap_rows is not None:
+            max_samples = min(max_samples, self.max_shap_rows)
+
         if n_rows > max_samples:
             LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
             X_data = X_data.sample(max_samples, random_state=42)
 
         # --- Adaptive feature display ---
+        display_cap = (
+            min(self.max_plot_features, len(display_features))
+            if self.max_plot_features is not None
+            else len(display_features)
+        )
         if max_display is None:
-            if X_data.shape[1] <= 20:
-                max_display = X_data.shape[1]
-            elif X_data.shape[1] <= 100:
-                max_display = 30
-            else:
-                max_display = 50
+            max_display = display_cap
+        else:
+            max_display = min(max_display, display_cap)
+        if not display_features:
+            display_features = list(X_data.columns)
+            max_display = len(display_features)
 
         # Background set
         bg = X_data.sample(min(len(X_data), 100), random_state=42)
@@ -177,37 +292,159 @@
         )
 
         # Optimized explainer
+        explainer = None
+        explainer_label = None
         if hasattr(model, "feature_importances_"):
             explainer = shap.TreeExplainer(
                 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
             )
+            explainer_label = "tree_path_dependent"
         elif hasattr(model, "coef_"):
             explainer = shap.LinearExplainer(model, bg)
+            explainer_label = "linear"
         else:
             explainer = shap.Explainer(predict_fn, bg)
+            explainer_label = explainer.__class__.__name__
 
         try:
             shap_values = explainer(X_data)
             self.shap_model_name = explainer.__class__.__name__
         except Exception as e:
-            LOG.error(f"SHAP computation failed: {e}")
+            error_message = str(e)
+            needs_tree_fallback = (
+                hasattr(model, "feature_importances_")
+                and "does not cover all the leaves" in error_message.lower()
+            )
+            feature_name_mismatch = "feature names should match" in error_message.lower()
+            if needs_tree_fallback:
+                LOG.warning(
+                    "SHAP computation failed using '%s' perturbation (%s). "
+                    "Retrying with interventional perturbation.",
+                    explainer_label,
+                    error_message,
+                )
+                try:
+                    explainer = shap.TreeExplainer(
+                        model,
+                        bg,
+                        feature_perturbation="interventional",
+                        n_jobs=-1,
+                    )
+                    shap_values = explainer(X_data)
+                    self.shap_model_name = (
+                        f"{explainer.__class__.__name__} (interventional)"
+                    )
+                except Exception as retry_exc:
+                    LOG.error(
+                        "SHAP computation failed even after fallback: %s",
+                        retry_exc,
+                    )
+                    self.shap_model_name = None
+                    return
+            elif feature_name_mismatch:
+                LOG.warning(
+                    "SHAP computation failed due to feature-name mismatch (%s). "
+                    "Falling back to model-agnostic SHAP explainer.",
+                    error_message,
+                )
+                try:
+                    agnostic_explainer = shap.Explainer(predict_fn, bg)
+                    shap_values = agnostic_explainer(X_data)
+                    self.shap_model_name = (
+                        f"{agnostic_explainer.__class__.__name__} (fallback)"
+                    )
+                except Exception as fallback_exc:
+                    LOG.error(
+                        "Model-agnostic SHAP fallback also failed: %s",
+                        fallback_exc,
+                    )
+                    self.shap_model_name = None
+                    return
+            else:
+                LOG.error(f"SHAP computation failed: {e}")
+                self.shap_model_name = None
+                return
+
+        def _limit_explanation_features(explanation):
+            if len(display_features) >= n_features:
+                return explanation
+            try:
+                limited = explanation[:, display_features]
+                LOG.info(
+                    "SHAP explanation trimmed to %s display features.",
+                    len(display_features),
+                )
+                return limited
+            except Exception as exc:
+                LOG.warning(
+                    "Failed to restrict SHAP explanation to top features "
+                    "(sample=%s); plot will include all features. Error: %s",
+                    display_features[:5],
+                    exc,
+                )
+                # Keep using full feature list if trimming fails
+                return explanation
+
+        shap_shape = getattr(shap_values, "shape", None)
+        class_labels = list(getattr(model, "classes_", []))
+        shap_outputs = []
+        if shap_shape is not None and len(shap_shape) == 3:
+            output_count = shap_shape[2]
+            LOG.info("Detected multi-output SHAP explanation with %s classes.", output_count)
+            for class_idx in range(output_count):
+                try:
+                    class_expl = shap_values[..., class_idx]
+                except Exception as exc:
+                    LOG.warning(
+                        "Failed to extract SHAP explanation for class index %s: %s",
+                        class_idx,
+                        exc,
+                    )
+                    continue
+                label = (
+                    class_labels[class_idx]
+                    if class_labels and class_idx < len(class_labels)
+                    else class_idx
+                )
+                shap_outputs.append((class_idx, label, class_expl))
+        else:
+            shap_outputs.append((None, None, shap_values))
+
+        if not shap_outputs:
+            LOG.error("No SHAP outputs available for plotting.")
             self.shap_model_name = None
             return
 
-        # --- Plot SHAP summary ---
-        out_path = os.path.join(self.output_dir, "shap_summary.png")
-        plt.figure()
-        shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
-        plt.title(
-            f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
-        )
-        plt.savefig(out_path, bbox_inches="tight")
-        plt.close()
-        self.plots["shap_summary"] = out_path
+        # --- Plot SHAP summary (one per class if needed) ---
+        for class_idx, class_label, class_expl in shap_outputs:
+            expl_to_plot = _limit_explanation_features(class_expl)
+            suffix = ""
+            plot_key = "shap_summary"
+            if class_idx is not None:
+                safe_label = str(class_label).replace("/", "_").replace(" ", "_")
+                suffix = f"_class_{safe_label}"
+                plot_key = f"shap_summary_class_{safe_label}"
+            out_filename = f"shap_summary{suffix}.png"
+            out_path = os.path.join(self.output_dir, out_filename)
+            plt.figure()
+            shap.plots.beeswarm(expl_to_plot, max_display=max_display, show=False)
+            title = f"SHAP Summary for {model.__class__.__name__}"
+            if class_idx is not None:
+                title += f" (class {class_label})"
+            plt.title(f"{title} (top {max_display} features)")
+            plt.tight_layout()
+            plt.savefig(out_path, bbox_inches="tight")
+            plt.close()
+            self.plots[plot_key] = out_path
 
         # --- Log summary ---
         LOG.info(
-            f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
+            "SHAP summary completed with %s rows and %s features "
+            "(displaying top %s) across %s output(s).",
+            X_data.shape[0],
+            X_data.shape[1],
+            max_display,
+            len(shap_outputs),
         )
 
     def generate_html_report(self):
@@ -227,12 +464,19 @@
                 section_title = (
                     f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
                 )
+            elif plot_name.startswith("shap_summary_class_"):
+                class_label = plot_name.replace("shap_summary_class_", "")
+                section_title = (
+                    f"SHAP Summary for class {class_label} "
+                    f"({getattr(self, 'shap_model_name', 'model')})"
+                )
             else:
                 section_title = plot_name
             plots_html += f"""
-            <div class="plot" id="{plot_name}">
+            <div class="plot" id="{plot_name}" style="text-align:center;margin-bottom:24px;">
                 <h2>{section_title}</h2>
-                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
+                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"
+                     style="max-width:95%;height:auto;display:block;margin:0 auto;border:1px solid #ddd;padding:8px;background:#fff;">
             </div>
             """
         return f"{plots_html}"
--- a/pycaret_macros.xml	Fri Nov 28 15:46:05 2025 +0000
+++ b/pycaret_macros.xml	Fri Nov 28 22:28:12 2025 +0000
@@ -1,5 +1,5 @@
 <macros>
-    <token name="@TABULAR_LEARNER_VERSION@">0.1.2</token>
+    <token name="@TABULAR_LEARNER_VERSION@">0.1.3</token>
     <token name="@PYCARET_VERSION@">3.3.2</token>
     <token name="@SUFFIX@">2</token>
     <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>