changeset 2:0314dad38aaa draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
author goeckslab
date Wed, 01 Jan 2025 03:19:27 +0000
parents 4a7df9abe4c4
children
files base_model_trainer.py pycaret_classification.py pycaret_regression.py utils.py
diffstat 4 files changed, 50 insertions(+), 13 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Sat Dec 14 23:17:48 2024 +0000
+++ b/base_model_trainer.py	Wed Jan 01 03:19:27 2025 +0000
@@ -263,9 +263,13 @@
                 Best Model Plots</div>
                 <div class="tab" onclick="openTab(event, 'feature')">
                 Feature Importance</div>
-                <div class="tab" onclick="openTab(event, 'explainer')">
-                Explainer
-                </div>
+            """
+        if self.plots_explainer_html:
+            html_content += """
+                "<div class="tab" onclick="openTab(event, 'explainer')">"
+                Explainer Plots</div>
+            """
+        html_content += f"""
             </div>
             <div id="summary" class="tab-content">
                 <h2>Setup Parameters</h2>
@@ -299,13 +303,19 @@
             <div id="feature" class="tab-content">
                 {feature_importance_html}
             </div>
+        """
+        if self.plots_explainer_html:
+            html_content += f"""
             <div id="explainer" class="tab-content">
                 {self.plots_explainer_html}
                 {tree_plots}
             </div>
-        {get_html_closing()}
-        """
-
+            {get_html_closing()}
+            """
+        else:
+            html_content += f"""
+            {get_html_closing()}
+            """
         with open(os.path.join(
                 self.output_dir, "comparison_result.html"), "w") as file:
             file.write(html_content)
--- a/pycaret_classification.py	Sat Dec 14 23:17:48 2024 +0000
+++ b/pycaret_classification.py	Wed Jan 01 03:19:27 2025 +0000
@@ -6,7 +6,7 @@
 
 from pycaret.classification import ClassificationExperiment
 
-from utils import add_hr_to_html, add_plot_to_html
+from utils import add_hr_to_html, add_plot_to_html, predict_proba
 
 LOG = logging.getLogger(__name__)
 
@@ -39,6 +39,16 @@
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
+
+        if not hasattr(self.best_model, "predict_proba"):
+            import types
+            self.best_model.predict_proba = types.MethodType(
+                predict_proba, self.best_model)
+            LOG.warning(
+                f"The model {type(self.best_model).__name__}\
+                    does not support `predict_proba`. \
+                    Applying monkey patch.")
+
         plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                  'error', 'class_report', 'learning', 'calibration',
                  'vc', 'dimension', 'manifold', 'rfe', 'feature',
@@ -74,9 +84,14 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
-        self.expaliner = explainer
-        plots_explainer_html = ""
+        try:
+            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
+            self.expaliner = explainer
+            plots_explainer_html = ""
+        except Exception as e:
+            LOG.error(f"Error creating explainer: {e}")
+            self.plots_explainer_html = None
+            return
 
         try:
             fig_importance = explainer.plot_importances()
--- a/pycaret_regression.py	Sat Dec 14 23:17:48 2024 +0000
+++ b/pycaret_regression.py	Wed Jan 01 03:19:27 2025 +0000
@@ -59,9 +59,14 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        explainer = RegressionExplainer(self.best_model, X_test, y_test)
-        self.expaliner = explainer
-        plots_explainer_html = ""
+        try:
+            explainer = RegressionExplainer(self.best_model, X_test, y_test)
+            self.expaliner = explainer
+            plots_explainer_html = ""
+        except Exception as e:
+            LOG.error(f"Error creating explainer: {e}")
+            self.plots_explainer_html = None
+            return
 
         try:
             fig_importance = explainer.plot_importances()
--- a/utils.py	Sat Dec 14 23:17:48 2024 +0000
+++ b/utils.py	Wed Jan 01 03:19:27 2025 +0000
@@ -1,6 +1,8 @@
 import base64
 import logging
 
+import numpy as np
+
 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger(__name__)
 
@@ -155,3 +157,8 @@
     """Convert an image file to a base64 encoded string."""
     with open(image_path, "rb") as img_file:
         return base64.b64encode(img_file.read()).decode("utf-8")
+
+
+def predict_proba(self, X):
+    pred = self.predict(X)
+    return np.array([1-pred, pred]).T