diff pycaret_classification.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents ccd798db5abb
children
line wrap: on
line diff
--- a/pycaret_classification.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_classification.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,23 +1,27 @@
 import logging
+import types
+from typing import Dict
 
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
+from plotly.graph_objects import Figure
 from pycaret.classification import ClassificationExperiment
-from utils import add_hr_to_html, add_plot_to_html, predict_proba
+from utils import predict_proba
 
 LOG = logging.getLogger(__name__)
 
 
 class ClassificationModelTrainer(BaseModelTrainer):
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         super().__init__(
             input_file,
             target_col,
@@ -25,191 +29,134 @@
             task_type,
             random_seed,
             test_file,
-            **kwargs)
+            **kwargs,
+        )
         self.exp = ClassificationExperiment()
 
     def save_dashboard(self):
         LOG.info("Saving explainer dashboard")
-        dashboard = generate_classifier_explainer_dashboard(self.exp,
-                                                            self.best_model)
+        dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model)
         dashboard.save_html("dashboard.html")
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
 
         if not hasattr(self.best_model, "predict_proba"):
-            import types
             self.best_model.predict_proba = types.MethodType(
-                predict_proba, self.best_model)
+                predict_proba, self.best_model
+            )
             LOG.warning(
-                f"The model {type(self.best_model).__name__}\
-                    does not support `predict_proba`. \
-                    Applying monkey patch.")
+                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
+            )
 
-        plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
-                 'error', 'class_report', 'learning', 'calibration',
-                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
-                 'feature_all']
+        plots = [
+            'confusion_matrix',
+            'auc',
+            'threshold',
+            'pr',
+            'error',
+            'class_report',
+            'learning',
+            'calibration',
+            'vc',
+            'dimension',
+            'manifold',
+            'rfe',
+            'feature',
+            'feature_all',
+        ]
         for plot_name in plots:
             try:
-                if plot_name == 'auc' and not self.exp.is_multiclass:
-                    plot_path = self.exp.plot_model(self.best_model,
-                                                    plot=plot_name,
-                                                    save=True,
-                                                    plot_kwargs={
-                                                        'micro': False,
-                                                        'macro': False,
-                                                        'per_class': False,
-                                                        'binary': True
-                                                    })
+                if plot_name == "threshold":
+                    plot_path = self.exp.plot_model(
+                        self.best_model,
+                        plot=plot_name,
+                        save=True,
+                        plot_kwargs={"binary": True, "percentage": True},
+                    )
                     self.plots[plot_name] = plot_path
-                    continue
-
-                plot_path = self.exp.plot_model(self.best_model,
-                                                plot=plot_name, save=True)
-                self.plots[plot_name] = plot_path
+                elif plot_name == "auc" and not self.exp.is_multiclass:
+                    plot_path = self.exp.plot_model(
+                        self.best_model,
+                        plot=plot_name,
+                        save=True,
+                        plot_kwargs={
+                            "micro": False,
+                            "macro": False,
+                            "per_class": False,
+                            "binary": True,
+                        },
+                    )
+                    self.plots[plot_name] = plot_path
+                else:
+                    plot_path = self.exp.plot_model(
+                        self.best_model, plot=plot_name, save=True
+                    )
+                    self.plots[plot_name] = plot_path
             except Exception as e:
                 LOG.error(f"Error generating plot {plot_name}: {e}")
                 continue
 
     def generate_plots_explainer(self):
-        LOG.info("Generating and saving plots from explainer")
+        from explainerdashboard import ClassifierExplainer
 
-        from explainerdashboard import ClassifierExplainer
+        LOG.info("Generating explainer plots")
 
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
-
-        try:
-            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
-            self.expaliner = explainer
-            plots_explainer_html = ""
-        except Exception as e:
-            LOG.error(f"Error creating explainer: {e}")
-            self.plots_explainer_html = None
-            return
+        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
 
-        try:
-            fig_importance = explainer.plot_importances()
-            plots_explainer_html += add_plot_to_html(fig_importance)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance(mean shap): {e}")
-
-        try:
-            fig_importance_perm = explainer.plot_importances(
-                kind="permutation")
-            plots_explainer_html += add_plot_to_html(fig_importance_perm)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance(permutation): {e}")
-
-        # try:
-        #     fig_shap = explainer.plot_shap_summary()
-        #     plots_explainer_html += add_plot_to_html(fig_shap,
-        #       include_plotlyjs=False)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot shap: {e}")
+        # Maps plot keys to Plotly Figure objects, or to zero-argument callables that build one on demand.
+        self.explainer_plots: Dict[str, Figure] = {}
 
-        # try:
-        #     fig_contributions = explainer.plot_contributions(
-        #       index=0)
-        #     plots_explainer_html += add_plot_to_html(
-        #       fig_contributions, include_plotlyjs=False)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot contributions: {e}")
+        # Evaluation-curve plots destined for the Test tab; each is built eagerly below.
+        for key, fn in [
+            ("roc_auc", explainer.plot_roc_auc),
+            ("pr_auc", explainer.plot_pr_auc),
+            ("lift_curve", explainer.plot_lift_curve),
+            ("confusion_matrix", explainer.plot_confusion_matrix),
+            ("threshold", explainer.plot_precision),  # precision as a function of the classification threshold
+            ("cumulative_precision", explainer.plot_cumulative_precision),
+        ]:
+            try:
+                self.explainer_plots[key] = fn()
+            except Exception as e:
+                LOG.error(f"Error generating explainer plot {key}: {e}")
 
-        # try:
-        #     for feature in self.features_name:
-        #         fig_dependence = explainer.plot_dependence(col=feature)
-        #         plots_explainer_html += add_plot_to_html(fig_dependence)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot dependencies: {e}")
-
+        # mean SHAP importances
         try:
-            for feature in self.features_name:
-                fig_pdp = explainer.plot_pdp(feature)
-                plots_explainer_html += add_plot_to_html(fig_pdp)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_mean"] = explainer.plot_importances()
         except Exception as e:
-            LOG.error(f"Error generating plot pdp: {e}")
+            LOG.warning(f"Could not generate shap_mean: {e}")
 
-        try:
-            for feature in self.features_name:
-                fig_interaction = explainer.plot_interaction(
-                    col=feature, interact_col=feature)
-                plots_explainer_html += add_plot_to_html(fig_interaction)
-        except Exception as e:
-            LOG.error(f"Error generating plot interactions: {e}")
-
+        # permutation importances
         try:
-            for feature in self.features_name:
-                fig_interactions_importance = \
-                    explainer.plot_interactions_importance(
-                        col=feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_interactions_importance)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances(
+                kind="permutation"
+            )
         except Exception as e:
-            LOG.error(f"Error generating plot interactions importance: {e}")
+            LOG.warning(f"Could not generate shap_perm: {e}")
 
-        # try:
-        #     for feature in self.features_name:
-        #         fig_interactions_detailed = \
-        #           explainer.plot_interactions_detailed(
-        #               col=feature)
-        #         plots_explainer_html += add_plot_to_html(
-        #           fig_interactions_detailed)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot interactions detailed: {e}")
-
-        try:
-            fig_precision = explainer.plot_precision()
-            plots_explainer_html += add_plot_to_html(fig_precision)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot precision: {e}")
-
-        try:
-            fig_cumulative_precision = explainer.plot_cumulative_precision()
-            plots_explainer_html += add_plot_to_html(fig_cumulative_precision)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot cumulative precision: {e}")
+        # PDPs for each feature (appended last)
+        valid_feats = []
+        for feat in self.features_name:
+            if feat in explainer.X.columns or feat in explainer.onehot_cols:
+                valid_feats.append(feat)
+            else:
+                LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
 
-        try:
-            fig_classification = explainer.plot_classification()
-            plots_explainer_html += add_plot_to_html(fig_classification)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot classification: {e}")
-
-        try:
-            fig_confusion_matrix = explainer.plot_confusion_matrix()
-            plots_explainer_html += add_plot_to_html(fig_confusion_matrix)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot confusion matrix: {e}")
+        for feat in valid_feats:
+            # Wrap each PDP call so AssertionErrors (raised by explainerdashboard for unsupported features) are logged rather than propagated.
+            def make_pdp_plotter(f):
+                def _plot():
+                    try:
+                        return explainer.plot_pdp(f)
+                    except AssertionError as ae:
+                        LOG.warning(f"PDP AssertionError for {f!r}: {ae}")
+                        return None
+                    except Exception as e:
+                        LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
+                        return None
+                return _plot
 
-        try:
-            fig_lift_curve = explainer.plot_lift_curve()
-            plots_explainer_html += add_plot_to_html(fig_lift_curve)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot lift curve: {e}")
-
-        try:
-            fig_roc_auc = explainer.plot_roc_auc()
-            plots_explainer_html += add_plot_to_html(fig_roc_auc)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot roc auc: {e}")
-
-        try:
-            fig_pr_auc = explainer.plot_pr_auc()
-            plots_explainer_html += add_plot_to_html(fig_pr_auc)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot pr auc: {e}")
-
-        self.plots_explainer_html = plots_explainer_html
+            self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)