diff feature_importance.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents f4cb41f458fd
children
line wrap: on
line diff
--- a/feature_importance.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/feature_importance.py	Fri Jul 25 19:02:32 2025 +0000
@@ -14,14 +14,15 @@
 
 class FeatureImportanceAnalyzer:
     def __init__(
-            self,
-            task_type,
-            output_dir,
-            data_path=None,
-            data=None,
-            target_col=None,
-            exp=None,
-            best_model=None):
+        self,
+        task_type,
+        output_dir,
+        data_path=None,
+        data=None,
+        target_col=None,
+        exp=None,
+        best_model=None,
+    ):
 
         self.task_type = task_type
         self.output_dir = output_dir
@@ -43,7 +44,11 @@
                 self.data.columns = self.data.columns.str.replace('.', '_')
                 self.data = self.data.fillna(self.data.median(numeric_only=True))
             self.target = self.data.columns[int(target_col) - 1]
-            self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
+            self.exp = (
+                ClassificationExperiment()
+                if task_type == "classification"
+                else RegressionExperiment()
+            )
 
         self.plots = {}
 
@@ -57,7 +62,7 @@
             'session_id': 123,
             'html': True,
             'log_experiment': False,
-            'system_log': False
+            'system_log': False,
         }
         self.exp.setup(self.data, **setup_params)
 
@@ -70,14 +75,16 @@
         model_type = model.__class__.__name__
         self.tree_model_name = model_type  # Store the model name for reporting
 
-        if hasattr(model, "feature_importances_"):
+        if hasattr(model, 'feature_importances_'):
             importances = model.feature_importances_
-        elif hasattr(model, "coef_"):
+        elif hasattr(model, 'coef_'):
             # For linear models, flatten coef_ and take abs (importance as magnitude)
             importances = abs(model.coef_).flatten()
         else:
             # Neither attribute exists; skip the plot
-            LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.")
+            LOG.warning(
+                f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot."
+            )
             self.tree_model_name = None  # No plot generated
             return
 
@@ -89,69 +96,77 @@
             self.tree_model_name = None
             return
 
-        feature_importances = pd.DataFrame({
-            'Feature': processed_features,
-            'Importance': importances
-        }).sort_values(by='Importance', ascending=False)
+        feature_importances = pd.DataFrame(
+            {'Feature': processed_features, 'Importance': importances}
+        ).sort_values(by='Importance', ascending=False)
         plt.figure(figsize=(10, 6))
-        plt.barh(
-            feature_importances['Feature'],
-            feature_importances['Importance'])
+        plt.barh(feature_importances['Feature'], feature_importances['Importance'])
         plt.xlabel('Importance')
         plt.title(f'Feature Importance ({model_type})')
-        plot_path = os.path.join(
-            self.output_dir,
-            'tree_importance.png')
+        plot_path = os.path.join(self.output_dir, 'tree_importance.png')
         plt.savefig(plot_path)
         plt.close()
         self.plots['tree_importance'] = plot_path
 
     def save_shap_values(self):
-        model = self.best_model or self.exp.get_config('best_model')
-        X_transformed = self.exp.get_config('X_transformed')
-        tree_classes = (
-            "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting"
-        )
-        model_class_name = model.__class__.__name__
-        self.shap_model_name = model_class_name
+
+        model = self.best_model or self.exp.get_config("best_model")
 
-        # Ensure feature alignment
-        if hasattr(model, "feature_name_"):
-            used_features = model.feature_name_
-        elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
+        X_data = None
+        for key in ("X_test_transformed", "X_train_transformed"):
+            try:
+                X_data = self.exp.get_config(key)
+                break
+            except KeyError:
+                continue
+        if X_data is None:
+            raise RuntimeError(
+                "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. "
+                "Make sure PyCaret setup/compare_models was run with feature_selection=True."
+            )
+
+        try:
             used_features = model.booster_.feature_name()
-        elif hasattr(model, "feature_names_in_"):
-            # scikit‐learn's standard attribute for the names of features used during fit
-            used_features = list(model.feature_names_in_)
-        else:
-            used_features = X_transformed.columns
+        except Exception:
+            used_features = getattr(model, "feature_names_in_", X_data.columns.tolist())
+        X_data = X_data[used_features]
+
+        max_bg = min(len(X_data), 100)
+        bg = X_data.sample(max_bg, random_state=42)
+
+        predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict
 
-        if any(tc in model_class_name for tc in tree_classes):
-            explainer = shap.TreeExplainer(model)
-            X_shap = X_transformed[used_features]
-            shap_values = explainer.shap_values(X_shap)
-            plot_X = X_shap
-            plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
+        explainer = shap.Explainer(predict_fn, bg)
+        self.shap_model_name = explainer.__class__.__name__
+
+        shap_values = explainer(X_data)
+
+        output_names = getattr(shap_values, "output_names", None)
+        if output_names is None and hasattr(model, "classes_"):
+            output_names = list(model.classes_)
+        if output_names is None:
+            n_out = shap_values.values.shape[-1]
+            output_names = list(map(str, range(n_out)))
+
+        values = shap_values.values
+        if values.ndim == 3:
+            for j, name in enumerate(output_names):
+                safe = name.replace(" ", "_").replace("/", "_")
+                out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png")
+                plt.figure()
+                shap.plots.beeswarm(shap_values[..., j], show=False)
+                plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}")
+                plt.savefig(out_path)
+                plt.close()
+                self.plots[f"shap_summary_{safe}"] = out_path
         else:
-            logging.warning(f"len(X_transformed) = {len(X_transformed)}")
-            max_samples = 100
-            n_samples = min(max_samples, len(X_transformed))
-            sampled_X = X_transformed[used_features].sample(
-                n=n_samples,
-                replace=False,
-                random_state=42
-            )
-            explainer = shap.KernelExplainer(model.predict, sampled_X)
-            shap_values = explainer.shap_values(sampled_X)
-            plot_X = sampled_X
-            plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)"
-
-        shap.summary_plot(shap_values, plot_X, show=False)
-        plt.title(plot_title)
-        plot_path = os.path.join(self.output_dir, "shap_summary.png")
-        plt.savefig(plot_path)
-        plt.close()
-        self.plots["shap_summary"] = plot_path
+            plt.figure()
+            shap.plots.beeswarm(shap_values, show=False)
+            plt.title(f"SHAP Summary for {model.__class__.__name__}")
+            out_path = os.path.join(self.output_dir, "shap_summary.png")
+            plt.savefig(out_path)
+            plt.close()
+            self.plots["shap_summary"] = out_path
 
     def generate_html_report(self):
         LOG.info("Generating HTML report")
@@ -159,11 +174,17 @@
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
             # Special handling for tree importance: skip if no model name (not generated)
-            if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None):
+            if plot_name == 'tree_importance' and not getattr(
+                self, 'tree_model_name', None
+            ):
                 continue
             encoded_image = self.encode_image_to_base64(plot_path)
-            if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None):
-                section_title = f"Feature importance analysis from a trained {self.tree_model_name}"
+            if plot_name == 'tree_importance' and getattr(
+                self, 'tree_model_name', None
+            ):
+                section_title = (
+                    f"Feature importance analysis from a trained {self.tree_model_name}"
+                )
             elif plot_name == 'shap_summary':
                 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
             else:
@@ -176,7 +197,6 @@
             """
 
         html_content = f"""
-            <h1>PyCaret Feature Importance Report</h1>
             {plots_html}
         """
 
@@ -187,7 +207,11 @@
             return base64.b64encode(img_file.read()).decode('utf-8')
 
     def run(self):
-        if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup:
+        if (
+            self.exp is None
+            or not hasattr(self.exp, 'is_setup')
+            or not self.exp.is_setup
+        ):
             self.setup_pycaret()
         self.save_tree_importance()
         self.save_shap_values()