Mercurial > repos > goeckslab > pycaret_compare
comparison pycaret_classification.py @ 3:02f7746e7772 draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
| author | goeckslab |
|---|---|
| date | Wed, 01 Jan 2025 03:19:40 +0000 |
| parents | 915447b14520 |
| children | 4aa511539199 |
comparison
equal
deleted
inserted
replaced
| 2:009b18a75dc3 | 3:02f7746e7772 |
|---|---|
| 4 | 4 |
| 5 from dashboard import generate_classifier_explainer_dashboard | 5 from dashboard import generate_classifier_explainer_dashboard |
| 6 | 6 |
| 7 from pycaret.classification import ClassificationExperiment | 7 from pycaret.classification import ClassificationExperiment |
| 8 | 8 |
| 9 from utils import add_hr_to_html, add_plot_to_html | 9 from utils import add_hr_to_html, add_plot_to_html, predict_proba |
| 10 | 10 |
| 11 LOG = logging.getLogger(__name__) | 11 LOG = logging.getLogger(__name__) |
| 12 | 12 |
| 13 | 13 |
| 14 class ClassificationModelTrainer(BaseModelTrainer): | 14 class ClassificationModelTrainer(BaseModelTrainer): |
| 37 self.best_model) | 37 self.best_model) |
| 38 dashboard.save_html("dashboard.html") | 38 dashboard.save_html("dashboard.html") |
| 39 | 39 |
| 40 def generate_plots(self): | 40 def generate_plots(self): |
| 41 LOG.info("Generating and saving plots") | 41 LOG.info("Generating and saving plots") |
| 42 | |
| 43 if not hasattr(self.best_model, "predict_proba"): | |
| 44 import types | |
| 45 self.best_model.predict_proba = types.MethodType( | |
| 46 predict_proba, self.best_model) | |
| 47 LOG.warning( | |
| 48 f"The model {type(self.best_model).__name__}\ | |
| 49 does not support `predict_proba`. \ | |
| 50 Applying monkey patch.") | |
| 51 | |
| 42 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', | 52 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', |
| 43 'error', 'class_report', 'learning', 'calibration', | 53 'error', 'class_report', 'learning', 'calibration', |
| 44 'vc', 'dimension', 'manifold', 'rfe', 'feature', | 54 'vc', 'dimension', 'manifold', 'rfe', 'feature', |
| 45 'feature_all'] | 55 'feature_all'] |
| 46 for plot_name in plots: | 56 for plot_name in plots: |
| 72 from explainerdashboard import ClassifierExplainer | 82 from explainerdashboard import ClassifierExplainer |
| 73 | 83 |
| 74 X_test = self.exp.X_test_transformed.copy() | 84 X_test = self.exp.X_test_transformed.copy() |
| 75 y_test = self.exp.y_test_transformed | 85 y_test = self.exp.y_test_transformed |
| 76 | 86 |
| 77 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 87 try: |
| 78 self.expaliner = explainer | 88 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
| 79 plots_explainer_html = "" | 89 self.expaliner = explainer |
| 90 plots_explainer_html = "" | |
| 91 except Exception as e: | |
| 92 LOG.error(f"Error creating explainer: {e}") | |
| 93 self.plots_explainer_html = None | |
| 94 return | |
| 80 | 95 |
| 81 try: | 96 try: |
| 82 fig_importance = explainer.plot_importances() | 97 fig_importance = explainer.plot_importances() |
| 83 plots_explainer_html += add_plot_to_html(fig_importance) | 98 plots_explainer_html += add_plot_to_html(fig_importance) |
| 84 plots_explainer_html += add_hr_to_html() | 99 plots_explainer_html += add_hr_to_html() |
