Mercurial > repos > goeckslab > pycaret_predict
changeset 2:0314dad38aaa draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
author | goeckslab |
---|---|
date | Wed, 01 Jan 2025 03:19:27 +0000 |
parents | 4a7df9abe4c4 |
children | |
files | base_model_trainer.py pycaret_classification.py pycaret_regression.py utils.py |
diffstat | 4 files changed, 50 insertions(+), 13 deletions(-) [+] |
line wrap: on
line diff
--- a/base_model_trainer.py Sat Dec 14 23:17:48 2024 +0000 +++ b/base_model_trainer.py Wed Jan 01 03:19:27 2025 +0000 @@ -263,9 +263,13 @@ Best Model Plots</div> <div class="tab" onclick="openTab(event, 'feature')"> Feature Importance</div> - <div class="tab" onclick="openTab(event, 'explainer')"> - Explainer - </div> + """ + if self.plots_explainer_html: + html_content += """ + "<div class="tab" onclick="openTab(event, 'explainer')">" + Explainer Plots</div> + """ + html_content += f""" </div> <div id="summary" class="tab-content"> <h2>Setup Parameters</h2> @@ -299,13 +303,19 @@ <div id="feature" class="tab-content"> {feature_importance_html} </div> + """ + if self.plots_explainer_html: + html_content += f""" <div id="explainer" class="tab-content"> {self.plots_explainer_html} {tree_plots} </div> - {get_html_closing()} - """ - + {get_html_closing()} + """ + else: + html_content += f""" + {get_html_closing()} + """ with open(os.path.join( self.output_dir, "comparison_result.html"), "w") as file: file.write(html_content)
--- a/pycaret_classification.py Sat Dec 14 23:17:48 2024 +0000 +++ b/pycaret_classification.py Wed Jan 01 03:19:27 2025 +0000 @@ -6,7 +6,7 @@ from pycaret.classification import ClassificationExperiment -from utils import add_hr_to_html, add_plot_to_html +from utils import add_hr_to_html, add_plot_to_html, predict_proba LOG = logging.getLogger(__name__) @@ -39,6 +39,16 @@ def generate_plots(self): LOG.info("Generating and saving plots") + + if not hasattr(self.best_model, "predict_proba"): + import types + self.best_model.predict_proba = types.MethodType( + predict_proba, self.best_model) + LOG.warning( + f"The model {type(self.best_model).__name__}\ + does not support `predict_proba`. \ + Applying monkey patch.") + plots = ['confusion_matrix', 'auc', 'threshold', 'pr', 'error', 'class_report', 'learning', 'calibration', 'vc', 'dimension', 'manifold', 'rfe', 'feature', @@ -74,9 +84,14 @@ X_test = self.exp.X_test_transformed.copy() y_test = self.exp.y_test_transformed - explainer = ClassifierExplainer(self.best_model, X_test, y_test) - self.expaliner = explainer - plots_explainer_html = "" + try: + explainer = ClassifierExplainer(self.best_model, X_test, y_test) + self.expaliner = explainer + plots_explainer_html = "" + except Exception as e: + LOG.error(f"Error creating explainer: {e}") + self.plots_explainer_html = None + return try: fig_importance = explainer.plot_importances()
--- a/pycaret_regression.py Sat Dec 14 23:17:48 2024 +0000 +++ b/pycaret_regression.py Wed Jan 01 03:19:27 2025 +0000 @@ -59,9 +59,14 @@ X_test = self.exp.X_test_transformed.copy() y_test = self.exp.y_test_transformed - explainer = RegressionExplainer(self.best_model, X_test, y_test) - self.expaliner = explainer - plots_explainer_html = "" + try: + explainer = RegressionExplainer(self.best_model, X_test, y_test) + self.expaliner = explainer + plots_explainer_html = "" + except Exception as e: + LOG.error(f"Error creating explainer: {e}") + self.plots_explainer_html = None + return try: fig_importance = explainer.plot_importances()
--- a/utils.py Sat Dec 14 23:17:48 2024 +0000 +++ b/utils.py Wed Jan 01 03:19:27 2025 +0000 @@ -1,6 +1,8 @@ import base64 import logging +import numpy as np + logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger(__name__) @@ -155,3 +157,8 @@ """Convert an image file to a base64 encoded string.""" with open(image_path, "rb") as img_file: return base64.b64encode(img_file.read()).decode("utf-8") + + +def predict_proba(self, X): + pred = self.predict(X) + return np.array([1-pred, pred]).T