diff feature_importance.py @ 16:4fee4504646e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:26 +0000
parents e674b9e946fb
children
line wrap: on
line diff
--- a/feature_importance.py	Fri Nov 28 15:46:17 2025 +0000
+++ b/feature_importance.py	Fri Nov 28 22:28:26 2025 +0000
@@ -22,11 +22,23 @@
         target_col=None,
         exp=None,
         best_model=None,
+        max_plot_features=None,
+        processed_data=None,
+        max_shap_rows=None,
     ):
         self.task_type = task_type
         self.output_dir = output_dir
         self.exp = exp
         self.best_model = best_model
+        self._skip_messages = []
+        self.shap_total_features = None
+        self.shap_used_features = None
+        if isinstance(max_plot_features, int) and max_plot_features > 0:
+            self.max_plot_features = max_plot_features
+        elif max_plot_features is None:
+            self.max_plot_features = 30
+        else:
+            self.max_plot_features = None
 
         if exp is not None:
             # Assume all configs (data, target) are in exp
@@ -48,8 +60,55 @@
                 if task_type == "classification"
                 else RegressionExperiment()
             )
+        if processed_data is not None:
+            self.data = processed_data
 
         self.plots = {}
+        self.max_shap_rows = max_shap_rows
+
+    def _get_feature_names_from_model(self, model):
+        """Best-effort extraction of feature names seen by the estimator."""
+        if model is None:
+            return None
+
+        candidates = [model]
+        if hasattr(model, "named_steps"):
+            candidates.extend(model.named_steps.values())
+        elif hasattr(model, "steps"):
+            candidates.extend(step for _, step in model.steps)
+
+        for candidate in candidates:
+            names = getattr(candidate, "feature_names_in_", None)
+            if names is not None:
+                return list(names)
+        return None
+
+    def _get_transformed_frame(self, model=None, prefer_test=True):
+        """Return a DataFrame that mirrors the matrix fed to the estimator."""
+        key_order = ["X_test_transformed", "X_train_transformed"]
+        if not prefer_test:
+            key_order.reverse()
+        key_order.append("X_transformed")
+
+        feature_names = self._get_feature_names_from_model(model)
+        for key in key_order:
+            try:
+                frame = self.exp.get_config(key)
+            except KeyError:
+                continue
+            if frame is None:
+                continue
+            if isinstance(frame, pd.DataFrame):
+                return frame.copy()
+            try:
+                n_features = frame.shape[1]
+            except Exception:
+                continue
+            if feature_names and len(feature_names) == n_features:
+                return pd.DataFrame(frame, columns=feature_names)
+            # Fallback to positional names so downstream logic still works
+            return pd.DataFrame(frame, columns=[f"f{i}" for i in range(n_features)])
+        return None
 
     def setup_pycaret(self):
         if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
@@ -67,7 +126,14 @@
 
     def save_tree_importance(self):
         model = self.best_model or self.exp.get_config("best_model")
-        processed_features = self.exp.get_config("X_transformed").columns
+        processed_frame = self._get_transformed_frame(model, prefer_test=False)
+        if processed_frame is None:
+            LOG.warning(
+                "Unable to determine transformed feature names; skipping tree importance plot."
+            )
+            self.tree_model_name = None
+            return
+        processed_features = list(processed_frame.columns)
 
         importances = None
         model_type = model.__class__.__name__
@@ -85,20 +151,42 @@
             return
 
         if len(importances) != len(processed_features):
-            LOG.warning(
-                f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
-            )
-            self.tree_model_name = None
-            return
+            model_feature_names = self._get_feature_names_from_model(model)
+            if model_feature_names and len(model_feature_names) == len(importances):
+                processed_features = model_feature_names
+            else:
+                LOG.warning(
+                    "Importances (%s) != features (%s). Skipping tree importance.",
+                    len(importances),
+                    len(processed_features),
+                )
+                self.tree_model_name = None
+                return
 
         feature_importances = pd.DataFrame(
             {"Feature": processed_features, "Importance": importances}
         ).sort_values(by="Importance", ascending=False)
+        cap = (
+            min(self.max_plot_features, len(feature_importances))
+            if self.max_plot_features is not None
+            else len(feature_importances)
+        )
+        plot_importances = feature_importances.head(cap)
+        if cap < len(feature_importances):
+            LOG.info(
+                "Tree importance plot limited to top %s of %s features",
+                cap,
+                len(feature_importances),
+            )
         plt.figure(figsize=(10, 6))
-        plt.barh(feature_importances["Feature"], feature_importances["Importance"])
+        plt.barh(
+            plot_importances["Feature"],
+            plot_importances["Importance"],
+        )
         plt.xlabel("Importance")
-        plt.title(f"Feature Importance ({model_type})")
+        plt.title(f"Feature Importance ({model_type}) (top {cap})")
         plot_path = os.path.join(self.output_dir, "tree_importance.png")
+        plt.tight_layout()
         plt.savefig(plot_path, bbox_inches="tight")
         plt.close()
         self.plots["tree_importance"] = plot_path
@@ -106,23 +194,22 @@
     def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
         model = self.best_model or self.exp.get_config("best_model")
 
-        X_data = None
-        for key in ("X_test_transformed", "X_train_transformed"):
-            try:
-                X_data = self.exp.get_config(key)
-                break
-            except KeyError:
-                continue
+        X_data = self._get_transformed_frame(model)
         if X_data is None:
             raise RuntimeError("No transformed dataset found for SHAP.")
 
-        # --- Adaptive feature limiting (proportional cap) ---
         n_rows, n_features = X_data.shape
+        self.shap_total_features = n_features
+        feature_cap = (
+            min(self.max_plot_features, n_features)
+            if self.max_plot_features is not None
+            else n_features
+        )
         if max_features is None:
-            if n_features <= 200:
-                max_features = n_features
-            else:
-                max_features = min(200, max(20, int(n_features * 0.1)))
+            max_features = feature_cap
+        else:
+            max_features = min(max_features, feature_cap)
+        display_features = list(X_data.columns)
 
         try:
             if hasattr(model, "feature_importances_"):
@@ -138,15 +225,35 @@
                 variances = X_data.var()
                 top_features = variances.nlargest(max_features).index
 
-            if len(top_features) < n_features:
+            candidate_features = list(top_features)
+            missing = [f for f in candidate_features if f not in X_data.columns]
+            display_features = [f for f in candidate_features if f in X_data.columns]
+            if missing:
+                LOG.warning(
+                    "Dropping %s transformed feature(s) not present in SHAP frame: %s",
+                    len(missing),
+                    missing[:5],
+                )
+            if display_features and len(display_features) < n_features:
                 LOG.info(
-                    f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
+                    "Restricting SHAP display to top %s of %s features",
+                    len(display_features),
+                    n_features,
                 )
-            X_data = X_data[top_features]
+            elif not display_features:
+                display_features = list(X_data.columns)
         except Exception as e:
             LOG.warning(
                 f"Feature limiting failed: {e}. Using all {n_features} features."
             )
+            display_features = list(X_data.columns)
+
+        self.shap_used_features = len(display_features)
+
+        # Apply the column restriction so SHAP only runs on the selected features.
+        if display_features:
+            X_data = X_data[display_features]
+            n_rows, n_features = X_data.shape
 
         # --- Adaptive row subsampling ---
         if max_samples is None:
@@ -157,18 +264,26 @@
             else:
                 max_samples = min(1000, int(n_rows * 0.1))
 
+        if self.max_shap_rows is not None:
+            max_samples = min(max_samples, self.max_shap_rows)
+
         if n_rows > max_samples:
             LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
             X_data = X_data.sample(max_samples, random_state=42)
 
         # --- Adaptive feature display ---
+        display_cap = (
+            min(self.max_plot_features, len(display_features))
+            if self.max_plot_features is not None
+            else len(display_features)
+        )
         if max_display is None:
-            if X_data.shape[1] <= 20:
-                max_display = X_data.shape[1]
-            elif X_data.shape[1] <= 100:
-                max_display = 30
-            else:
-                max_display = 50
+            max_display = display_cap
+        else:
+            max_display = min(max_display, display_cap)
+        if not display_features:
+            display_features = list(X_data.columns)
+            max_display = len(display_features)
 
         # Background set
         bg = X_data.sample(min(len(X_data), 100), random_state=42)
@@ -177,37 +292,159 @@
         )
 
         # Optimized explainer
+        explainer = None
+        explainer_label = None
         if hasattr(model, "feature_importances_"):
             explainer = shap.TreeExplainer(
                 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
             )
+            explainer_label = "tree_path_dependent"
         elif hasattr(model, "coef_"):
             explainer = shap.LinearExplainer(model, bg)
+            explainer_label = "linear"
         else:
             explainer = shap.Explainer(predict_fn, bg)
+            explainer_label = explainer.__class__.__name__
 
         try:
             shap_values = explainer(X_data)
             self.shap_model_name = explainer.__class__.__name__
         except Exception as e:
-            LOG.error(f"SHAP computation failed: {e}")
+            error_message = str(e)
+            needs_tree_fallback = (
+                hasattr(model, "feature_importances_")
+                and "does not cover all the leaves" in error_message.lower()
+            )
+            feature_name_mismatch = "feature names should match" in error_message.lower()
+            if needs_tree_fallback:
+                LOG.warning(
+                    "SHAP computation failed using '%s' perturbation (%s). "
+                    "Retrying with interventional perturbation.",
+                    explainer_label,
+                    error_message,
+                )
+                try:
+                    explainer = shap.TreeExplainer(
+                        model,
+                        bg,
+                        feature_perturbation="interventional",
+                        n_jobs=-1,
+                    )
+                    shap_values = explainer(X_data)
+                    self.shap_model_name = (
+                        f"{explainer.__class__.__name__} (interventional)"
+                    )
+                except Exception as retry_exc:
+                    LOG.error(
+                        "SHAP computation failed even after fallback: %s",
+                        retry_exc,
+                    )
+                    self.shap_model_name = None
+                    return
+            elif feature_name_mismatch:
+                LOG.warning(
+                    "SHAP computation failed due to feature-name mismatch (%s). "
+                    "Falling back to model-agnostic SHAP explainer.",
+                    error_message,
+                )
+                try:
+                    agnostic_explainer = shap.Explainer(predict_fn, bg)
+                    shap_values = agnostic_explainer(X_data)
+                    self.shap_model_name = (
+                        f"{agnostic_explainer.__class__.__name__} (fallback)"
+                    )
+                except Exception as fallback_exc:
+                    LOG.error(
+                        "Model-agnostic SHAP fallback also failed: %s",
+                        fallback_exc,
+                    )
+                    self.shap_model_name = None
+                    return
+            else:
+                LOG.error(f"SHAP computation failed: {e}")
+                self.shap_model_name = None
+                return
+
+        def _limit_explanation_features(explanation):
+            if len(display_features) >= n_features:
+                return explanation
+            try:
+                limited = explanation[:, display_features]
+                LOG.info(
+                    "SHAP explanation trimmed to %s display features.",
+                    len(display_features),
+                )
+                return limited
+            except Exception as exc:
+                LOG.warning(
+                    "Failed to restrict SHAP explanation to top features "
+                    "(sample=%s); plot will include all features. Error: %s",
+                    display_features[:5],
+                    exc,
+                )
+                # Keep using full feature list if trimming fails
+                return explanation
+
+        shap_shape = getattr(shap_values, "shape", None)
+        class_labels = list(getattr(model, "classes_", []))
+        shap_outputs = []
+        if shap_shape is not None and len(shap_shape) == 3:
+            output_count = shap_shape[2]
+            LOG.info("Detected multi-output SHAP explanation with %s classes.", output_count)
+            for class_idx in range(output_count):
+                try:
+                    class_expl = shap_values[..., class_idx]
+                except Exception as exc:
+                    LOG.warning(
+                        "Failed to extract SHAP explanation for class index %s: %s",
+                        class_idx,
+                        exc,
+                    )
+                    continue
+                label = (
+                    class_labels[class_idx]
+                    if class_labels and class_idx < len(class_labels)
+                    else class_idx
+                )
+                shap_outputs.append((class_idx, label, class_expl))
+        else:
+            shap_outputs.append((None, None, shap_values))
+
+        if not shap_outputs:
+            LOG.error("No SHAP outputs available for plotting.")
             self.shap_model_name = None
             return
 
-        # --- Plot SHAP summary ---
-        out_path = os.path.join(self.output_dir, "shap_summary.png")
-        plt.figure()
-        shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
-        plt.title(
-            f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
-        )
-        plt.savefig(out_path, bbox_inches="tight")
-        plt.close()
-        self.plots["shap_summary"] = out_path
+        # --- Plot SHAP summary (one per class if needed) ---
+        for class_idx, class_label, class_expl in shap_outputs:
+            expl_to_plot = _limit_explanation_features(class_expl)
+            suffix = ""
+            plot_key = "shap_summary"
+            if class_idx is not None:
+                safe_label = str(class_label).replace("/", "_").replace(" ", "_")
+                suffix = f"_class_{safe_label}"
+                plot_key = f"shap_summary_class_{safe_label}"
+            out_filename = f"shap_summary{suffix}.png"
+            out_path = os.path.join(self.output_dir, out_filename)
+            plt.figure()
+            shap.plots.beeswarm(expl_to_plot, max_display=max_display, show=False)
+            title = f"SHAP Summary for {model.__class__.__name__}"
+            if class_idx is not None:
+                title += f" (class {class_label})"
+            plt.title(f"{title} (top {max_display} features)")
+            plt.tight_layout()
+            plt.savefig(out_path, bbox_inches="tight")
+            plt.close()
+            self.plots[plot_key] = out_path
 
         # --- Log summary ---
         LOG.info(
-            f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
+            "SHAP summary completed with %s rows and %s features "
+            "(displaying top %s) across %s output(s).",
+            X_data.shape[0],
+            X_data.shape[1],
+            max_display,
+            len(shap_outputs),
         )
 
     def generate_html_report(self):
@@ -227,12 +464,19 @@
                 section_title = (
                     f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
                 )
+            elif plot_name.startswith("shap_summary_class_"):
+                class_label = plot_name.replace("shap_summary_class_", "")
+                section_title = (
+                    f"SHAP Summary for class {class_label} "
+                    f"({getattr(self, 'shap_model_name', 'model')})"
+                )
             else:
                 section_title = plot_name
             plots_html += f"""
-            <div class="plot" id="{plot_name}">
+            <div class="plot" id="{plot_name}" style="text-align:center;margin-bottom:24px;">
                 <h2>{section_title}</h2>
-                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
+                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"
+                     style="max-width:95%;height:auto;display:block;margin:0 auto;border:1px solid #ddd;padding:8px;background:#fff;">
             </div>
             """
         return f"{plots_html}"