diff pycaret_regression.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents ccd798db5abb
children
line wrap: on
line diff
--- a/pycaret_regression.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_regression.py	Fri Jul 25 19:02:32 2025 +0000
@@ -3,21 +3,21 @@
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_regression_explainer_dashboard
 from pycaret.regression import RegressionExperiment
-from utils import add_hr_to_html, add_plot_to_html
 
 LOG = logging.getLogger(__name__)
 
 
 class RegressionModelTrainer(BaseModelTrainer):
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         super().__init__(
             input_file,
             target_col,
@@ -25,24 +25,35 @@
             task_type,
             random_seed,
             test_file,
-            **kwargs)
+            **kwargs,
+        )
+        # The BaseModelTrainer.setup_pycaret will set self.exp appropriately
+        # But we reassign here for clarity
         self.exp = RegressionExperiment()
 
     def save_dashboard(self):
         LOG.info("Saving explainer dashboard")
-        dashboard = generate_regression_explainer_dashboard(self.exp,
-                                                            self.best_model)
+        dashboard = generate_regression_explainer_dashboard(self.exp, self.best_model)
         dashboard.save_html("dashboard.html")
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
-        plots = ['residuals', 'error', 'cooks',
-                 'learning', 'vc', 'manifold',
-                 'rfe', 'feature', 'feature_all']
+        plots = [
+            "residuals",
+            "error",
+            "cooks",
+            "learning",
+            "vc",
+            "manifold",
+            "rfe",
+            "feature",
+            "feature_all",
+        ]
         for plot_name in plots:
             try:
-                plot_path = self.exp.plot_model(self.best_model,
-                                                plot=plot_name, save=True)
+                plot_path = self.exp.plot_model(
+                    self.best_model, plot=plot_name, save=True
+                )
                 self.plots[plot_name] = plot_path
             except Exception as e:
                 LOG.error(f"Error generating plot {plot_name}: {e}")
@@ -58,79 +69,60 @@
 
         try:
             explainer = RegressionExplainer(self.best_model, X_test, y_test)
-            self.expaliner = explainer
-            plots_explainer_html = ""
         except Exception as e:
             LOG.error(f"Error creating explainer: {e}")
-            self.plots_explainer_html = None
             return
 
+        # --- 1) SHAP mean impact (average absolute SHAP values) ---
         try:
-            fig_importance = explainer.plot_importances()
-            plots_explainer_html += add_plot_to_html(fig_importance)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance: {e}")
-
-        try:
-            fig_importance_permutation = \
-                explainer.plot_importances_permutation(
-                    kind="permutation")
-            plots_explainer_html += add_plot_to_html(
-                fig_importance_permutation)
-            plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_mean"] = explainer.plot_importances()
         except Exception as e:
-            LOG.error(f"Error generating plot importance permutation: {e}")
+            LOG.error(f"Error generating SHAP mean importance: {e}")
 
-        try:
-            for feature in self.features_name:
-                fig_shap = explainer.plot_pdp(feature)
-                plots_explainer_html += add_plot_to_html(fig_shap)
-                plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot shap dependence: {e}")
-
-        # try:
-        #     for feature in self.features_name:
-        #         fig_interaction = explainer.plot_interaction(col=feature)
-        #         plots_explainer_html += add_plot_to_html(fig_interaction)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot shap interaction: {e}")
-
+        # --- 2) SHAP permutation importance ---
         try:
-            for feature in self.features_name:
-                fig_interactions_importance = \
-                    explainer.plot_interactions_importance(
-                        col=feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_interactions_importance)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_perm"] = explainer.plot_importances_permutation(
+                kind="permutation"
+            )
         except Exception as e:
-            LOG.error(f"Error generating plot shap summary: {e}")
+            LOG.error(f"Error generating SHAP permutation importance: {e}")
 
-        # Regression specific plots
-        try:
-            fig_pred_actual = explainer.plot_predicted_vs_actual()
-            plots_explainer_html += add_plot_to_html(fig_pred_actual)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot prediction vs actual: {e}")
+        # Pre-filter features so we never call PDP or residual-vs-feature on missing cols
+        valid_feats = []
+        for feat in self.features_name:
+            if feat in explainer.X.columns or feat in explainer.onehot_cols:
+                valid_feats.append(feat)
+            else:
+                LOG.warning(f"Skipping feature {feat!r}: not found in explainer data")
 
-        try:
-            fig_residuals = explainer.plot_residuals()
-            plots_explainer_html += add_plot_to_html(fig_residuals)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot residuals: {e}")
+        # --- 3) Partial Dependence Plots (PDPs) per feature ---
+        for feature in valid_feats:
+            try:
+                fig_pdp = explainer.plot_pdp(feature)
+                self.explainer_plots[f"pdp__{feature}"] = fig_pdp
+            except AssertionError as ae:
+                LOG.warning(f"PDP AssertionError for {feature!r}: {ae}")
+            except Exception as e:
+                LOG.error(f"Error generating PDP for {feature}: {e}")
 
+        # --- 4) Predicted vs Actual plot ---
         try:
-            for feature in self.features_name:
-                fig_residuals_vs_feature = \
-                    explainer.plot_residuals_vs_feature(feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_residuals_vs_feature)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["predicted_vs_actual"] = explainer.plot_predicted_vs_actual()
+        except Exception as e:
+            LOG.error(f"Error generating Predicted vs Actual plot: {e}")
+
+        # --- 5) Global residuals distribution ---
+        try:
+            self.explainer_plots["residuals"] = explainer.plot_residuals()
         except Exception as e:
-            LOG.error(f"Error generating plot residuals vs feature: {e}")
+            LOG.error(f"Error generating Residuals plot: {e}")
 
-        self.plots_explainer_html = plots_explainer_html
+        # --- 6) Residuals vs each feature ---
+        for feature in valid_feats:
+            try:
+                fig_res_vs_feat = explainer.plot_residuals_vs_feature(feature)
+                self.explainer_plots[f"residuals_vs_feature__{feature}"] = fig_res_vs_feat
+            except AssertionError as ae:
+                LOG.warning(f"Residuals-vs-feature AssertionError for {feature!r}: {ae}")
+            except Exception as e:
+                LOG.error(f"Error generating Residuals vs {feature}: {e}")