diff pycaret_classification.py @ 2:0314dad38aaa draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
author goeckslab
date Wed, 01 Jan 2025 03:19:27 +0000
parents 1f20fe57fdee
children
line wrap: on
line diff
--- a/pycaret_classification.py	Sat Dec 14 23:17:48 2024 +0000
+++ b/pycaret_classification.py	Wed Jan 01 03:19:27 2025 +0000
@@ -6,7 +6,7 @@
 
 from pycaret.classification import ClassificationExperiment
 
-from utils import add_hr_to_html, add_plot_to_html
+from utils import add_hr_to_html, add_plot_to_html, predict_proba
 
 LOG = logging.getLogger(__name__)
 
@@ -39,6 +39,16 @@
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
+
+        if not hasattr(self.best_model, "predict_proba"):
+            import types
+            self.best_model.predict_proba = types.MethodType(
+                predict_proba, self.best_model)
+            LOG.warning(
+                f"The model {type(self.best_model).__name__}\
+                    does not support `predict_proba`. \
+                    Applying monkey patch.")
+
         plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                  'error', 'class_report', 'learning', 'calibration',
                  'vc', 'dimension', 'manifold', 'rfe', 'feature',
@@ -74,9 +84,14 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
-        self.expaliner = explainer
-        plots_explainer_html = ""
+        try:
+            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
+            self.expaliner = explainer
+            plots_explainer_html = ""
+        except Exception as e:
+            LOG.error(f"Error creating explainer: {e}")
+            self.plots_explainer_html = None
+            return
 
         try:
             fig_importance = explainer.plot_importances()