diff feature_importance.py @ 6:a32ff7201629 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 19:00:03 +0000
parents ccd798db5abb
children
line wrap: on
line diff
--- a/feature_importance.py	Sat Jun 21 15:07:04 2025 +0000
+++ b/feature_importance.py	Wed Jul 02 19:00:03 2025 +0000
@@ -4,6 +4,7 @@
 
 import matplotlib.pyplot as plt
 import pandas as pd
+import shap
 from pycaret.classification import ClassificationExperiment
 from pycaret.regression import RegressionExperiment
 
@@ -18,25 +19,38 @@
             output_dir,
             data_path=None,
             data=None,
-            target_col=None):
+            target_col=None,
+            exp=None,
+            best_model=None):
 
-        if data is not None:
-            self.data = data
-            LOG.info("Data loaded from memory")
+        self.task_type = task_type
+        self.output_dir = output_dir
+        self.exp = exp
+        self.best_model = best_model
+
+        if exp is not None:
+            # Assume all configs (data, target) are in exp
+            self.data = exp.dataset.copy()
+            self.target = exp.target_param
+            LOG.info("Using provided experiment object")
         else:
-            self.target_col = target_col
-            self.data = pd.read_csv(data_path, sep=None, engine='python')
-            self.data.columns = self.data.columns.str.replace('.', '_')
-            self.data = self.data.fillna(self.data.median(numeric_only=True))
-        self.task_type = task_type
-        self.target = self.data.columns[int(target_col) - 1]
-        self.exp = ClassificationExperiment() \
-            if task_type == 'classification' \
-            else RegressionExperiment()
+            if data is not None:
+                self.data = data
+                LOG.info("Data loaded from memory")
+            else:
+                self.target_col = target_col
+                self.data = pd.read_csv(data_path, sep=None, engine='python')
+                self.data.columns = self.data.columns.str.replace('.', '_')
+                self.data = self.data.fillna(self.data.median(numeric_only=True))
+            self.target = self.data.columns[int(target_col) - 1]
+            self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
+
         self.plots = {}
-        self.output_dir = output_dir
 
     def setup_pycaret(self):
+        if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
+            LOG.info("Experiment already set up. Skipping PyCaret setup.")
+            return
         LOG.info("Initializing PyCaret")
         setup_params = {
             'target': self.target,
@@ -45,25 +59,36 @@
             'log_experiment': False,
             'system_log': False
         }
-        LOG.info(self.task_type)
-        LOG.info(self.exp)
         self.exp.setup(self.data, **setup_params)
 
-    # def save_coefficients(self):
-    #     model = self.exp.create_model('lr')
-    #     coef_df = pd.DataFrame({
-    #         'Feature': self.data.columns.drop(self.target),
-    #         'Coefficient': model.coef_[0]
-    #     })
-    #     coef_html = coef_df.to_html(index=False)
-    #     return coef_html
+    def save_tree_importance(self):
+        model = self.best_model or self.exp.get_config('best_model')
+        processed_features = self.exp.get_config('X_transformed').columns
+
+        # Try feature_importances_ or coef_ if available
+        importances = None
+        model_type = model.__class__.__name__
+        self.tree_model_name = model_type  # Store the model name for reporting
 
-    def save_tree_importance(self):
-        model = self.exp.create_model('rf')
-        importances = model.feature_importances_
-        processed_features = self.exp.get_config('X_transformed').columns
-        LOG.debug(f"Feature importances: {importances}")
-        LOG.debug(f"Features: {processed_features}")
+        if hasattr(model, "feature_importances_"):
+            importances = model.feature_importances_
+        elif hasattr(model, "coef_"):
+            # For linear models, flatten coef_ and take abs (importance as magnitude)
+            importances = abs(model.coef_).flatten()
+        else:
+            # Neither attribute exists; skip the plot
+            LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.")
+            self.tree_model_name = None  # No plot generated
+            return
+
+        # Defensive: handle mismatch in number of features
+        if len(importances) != len(processed_features):
+            LOG.warning(
+                f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
+            )
+            self.tree_model_name = None
+            return
+
         feature_importances = pd.DataFrame({
             'Feature': processed_features,
             'Importance': importances
@@ -73,7 +98,7 @@
             feature_importances['Feature'],
             feature_importances['Importance'])
         plt.xlabel('Importance')
-        plt.title('Feature Importance (Random Forest)')
+        plt.title(f'Feature Importance ({model_type})')
         plot_path = os.path.join(
             self.output_dir,
             'tree_importance.png')
@@ -82,53 +107,64 @@
         self.plots['tree_importance'] = plot_path
 
     def save_shap_values(self):
-        model = self.exp.create_model('lightgbm')
-        import shap
-        explainer = shap.Explainer(model)
-        shap_values = explainer.shap_values(
-            self.exp.get_config('X_transformed'))
-        shap.summary_plot(shap_values,
-                          self.exp.get_config('X_transformed'), show=False)
-        plt.title('Shap (LightGBM)')
-        plot_path = os.path.join(
-            self.output_dir, 'shap_summary.png')
+        model = self.best_model or self.exp.get_config('best_model')
+        X_transformed = self.exp.get_config('X_transformed')
+        tree_classes = (
+            "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting"
+        )
+        model_class_name = model.__class__.__name__
+        self.shap_model_name = model_class_name
+
+        # Ensure feature alignment
+        if hasattr(model, "feature_name_"):
+            used_features = model.feature_name_
+        elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
+            used_features = model.booster_.feature_name()
+        else:
+            used_features = X_transformed.columns
+
+        if any(tc in model_class_name for tc in tree_classes):
+            explainer = shap.TreeExplainer(model)
+            X_shap = X_transformed[used_features]
+            shap_values = explainer.shap_values(X_shap)
+            plot_X = X_shap
+            plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
+        else:
+            sampled_X = X_transformed[used_features].sample(100, random_state=42)
+            explainer = shap.KernelExplainer(model.predict, sampled_X)
+            shap_values = explainer.shap_values(sampled_X)
+            plot_X = sampled_X
+            plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)"
+
+        shap.summary_plot(shap_values, plot_X, show=False)
+        plt.title(plot_title)
+        plot_path = os.path.join(self.output_dir, "shap_summary.png")
         plt.savefig(plot_path)
         plt.close()
-        self.plots['shap_summary'] = plot_path
-
-    def generate_feature_importance(self):
-        # coef_html = self.save_coefficients()
-        self.save_tree_importance()
-        self.save_shap_values()
-
-    def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')
+        self.plots["shap_summary"] = plot_path
 
     def generate_html_report(self):
         LOG.info("Generating HTML report")
 
-        # Read and encode plot images
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
+            # Special handling for tree importance: skip if no model name (not generated)
+            if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None):
+                continue
             encoded_image = self.encode_image_to_base64(plot_path)
+            if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None):
+                section_title = f"Feature importance analysis from a trained {self.tree_model_name}"
+            elif plot_name == 'shap_summary':
+                section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
+            else:
+                section_title = plot_name
             plots_html += f"""
             <div class="plot" id="{plot_name}">
-                <h2>{'Feature importance analysis from a'
-                    'trained Random Forest'
-                    if plot_name == 'tree_importance'
-                    else 'SHAP Summary from a trained lightgbm'}</h2>
-                <h3>{'Use gini impurity for'
-                    'calculating feature importance for classification'
-                    'and Variance Reduction for regression'
-                  if plot_name == 'tree_importance'
-                  else ''}</h3>
-                <img src="data:image/png;base64,
-                {encoded_image}" alt="{plot_name}">
+                <h2>{section_title}</h2>
+                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
             </div>
             """
 
-        # Generate HTML content with tabs
         html_content = f"""
             <h1>PyCaret Feature Importance Report</h1>
             {plots_html}
@@ -136,34 +172,14 @@
 
         return html_content
 
-    def run(self):
-        LOG.info("Running feature importance analysis")
-        self.setup_pycaret()
-        self.generate_feature_importance()
-        html_content = self.generate_html_report()
-        LOG.info("Feature importance analysis completed")
-        return html_content
-
+    def encode_image_to_base64(self, img_path):
+        with open(img_path, 'rb') as img_file:
+            return base64.b64encode(img_file.read()).decode('utf-8')
 
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
-    parser.add_argument(
-        "--data_path", type=str, help="Path to the dataset")
-    parser.add_argument(
-        "--target_col", type=int,
-        help="Index of the target column (1-based)")
-    parser.add_argument(
-        "--task_type", type=str,
-        choices=["classification", "regression"],
-        help="Task type: classification or regression")
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        help="Directory to save the outputs")
-    args = parser.parse_args()
-
-    analyzer = FeatureImportanceAnalyzer(
-        args.data_path, args.target_col,
-        args.task_type, args.output_dir)
-    analyzer.run()
+    def run(self):
+        if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup:
+            self.setup_pycaret()
+        self.save_tree_importance()
+        self.save_shap_values()
+        html_content = self.generate_html_report()
+        return html_content