diff pycaret_classification.py @ 0:1f20fe57fdee draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 04:59:43 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pycaret_classification.py	Wed Dec 11 04:59:43 2024 +0000
@@ -0,0 +1,204 @@
+import logging
+
+from base_model_trainer import BaseModelTrainer
+
+from dashboard import generate_classifier_explainer_dashboard
+
+from pycaret.classification import ClassificationExperiment
+
+from utils import add_hr_to_html, add_plot_to_html
+
+LOG = logging.getLogger(__name__)
+
+
+class ClassificationModelTrainer(BaseModelTrainer):
+    def __init__(
+            self,
+            input_file,
+            target_col,
+            output_dir,
+            task_type,
+            random_seed,
+            test_file=None,
+            **kwargs):
+        super().__init__(
+            input_file,
+            target_col,
+            output_dir,
+            task_type,
+            random_seed,
+            test_file,
+            **kwargs)
+        self.exp = ClassificationExperiment()
+
+    def save_dashboard(self):
+        LOG.info("Saving explainer dashboard")
+        dashboard = generate_classifier_explainer_dashboard(self.exp,
+                                                            self.best_model)
+        dashboard.save_html("dashboard.html")
+
+    def generate_plots(self):
+        LOG.info("Generating and saving plots")
+        plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
+                 'error', 'class_report', 'learning', 'calibration',
+                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
+                 'feature_all']
+        for plot_name in plots:
+            try:
+                if plot_name == 'auc' and not self.exp.is_multiclass:
+                    plot_path = self.exp.plot_model(self.best_model,
+                                                    plot=plot_name,
+                                                    save=True,
+                                                    plot_kwargs={
+                                                        'micro': False,
+                                                        'macro': False,
+                                                        'per_class': False,
+                                                        'binary': True
+                                                        }
+                                                    )
+                    self.plots[plot_name] = plot_path
+                    continue
+
+                plot_path = self.exp.plot_model(self.best_model,
+                                                plot=plot_name, save=True)
+                self.plots[plot_name] = plot_path
+            except Exception as e:
+                LOG.error(f"Error generating plot {plot_name}: {e}")
+                continue
+
+    def generate_plots_explainer(self):
+        LOG.info("Generating and saving plots from explainer")
+
+        from explainerdashboard import ClassifierExplainer
+
+        X_test = self.exp.X_test_transformed.copy()
+        y_test = self.exp.y_test_transformed
+
+        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
+        self.expaliner = explainer
+        plots_explainer_html = ""
+
+        try:
+            fig_importance = explainer.plot_importances()
+            plots_explainer_html += add_plot_to_html(fig_importance)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot importance(mean shap): {e}")
+
+        try:
+            fig_importance_perm = explainer.plot_importances(
+                kind="permutation")
+            plots_explainer_html += add_plot_to_html(fig_importance_perm)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot importance(permutation): {e}")
+
+        # try:
+        #     fig_shap = explainer.plot_shap_summary()
+        #     plots_explainer_html += add_plot_to_html(fig_shap,
+        #       include_plotlyjs=False)
+        # except Exception as e:
+        #     LOG.error(f"Error generating plot shap: {e}")
+
+        # try:
+        #     fig_contributions = explainer.plot_contributions(
+        #       index=0)
+        #     plots_explainer_html += add_plot_to_html(
+        #       fig_contributions, include_plotlyjs=False)
+        # except Exception as e:
+        #     LOG.error(f"Error generating plot contributions: {e}")
+
+        # try:
+        #     for feature in self.features_name:
+        #         fig_dependence = explainer.plot_dependence(col=feature)
+        #         plots_explainer_html += add_plot_to_html(fig_dependence)
+        # except Exception as e:
+        #     LOG.error(f"Error generating plot dependencies: {e}")
+
+        try:
+            for feature in self.features_name:
+                fig_pdp = explainer.plot_pdp(feature)
+                plots_explainer_html += add_plot_to_html(fig_pdp)
+                plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot pdp: {e}")
+
+        try:
+            for feature in self.features_name:
+                fig_interaction = explainer.plot_interaction(
+                    col=feature, interact_col=feature)
+                plots_explainer_html += add_plot_to_html(fig_interaction)
+        except Exception as e:
+            LOG.error(f"Error generating plot interactions: {e}")
+
+        try:
+            for feature in self.features_name:
+                fig_interactions_importance = \
+                    explainer.plot_interactions_importance(
+                        col=feature)
+                plots_explainer_html += add_plot_to_html(
+                    fig_interactions_importance)
+                plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot interactions importance: {e}")
+
+        # try:
+        #     for feature in self.features_name:
+        #         fig_interactions_detailed = \
+        #           explainer.plot_interactions_detailed(
+        #               col=feature)
+        #         plots_explainer_html += add_plot_to_html(
+        #           fig_interactions_detailed)
+        # except Exception as e:
+        #     LOG.error(f"Error generating plot interactions detailed: {e}")
+
+        try:
+            fig_precision = explainer.plot_precision()
+            plots_explainer_html += add_plot_to_html(fig_precision)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot precision: {e}")
+
+        try:
+            fig_cumulative_precision = explainer.plot_cumulative_precision()
+            plots_explainer_html += add_plot_to_html(fig_cumulative_precision)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot cumulative precision: {e}")
+
+        try:
+            fig_classification = explainer.plot_classification()
+            plots_explainer_html += add_plot_to_html(fig_classification)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot classification: {e}")
+
+        try:
+            fig_confusion_matrix = explainer.plot_confusion_matrix()
+            plots_explainer_html += add_plot_to_html(fig_confusion_matrix)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot confusion matrix: {e}")
+
+        try:
+            fig_lift_curve = explainer.plot_lift_curve()
+            plots_explainer_html += add_plot_to_html(fig_lift_curve)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot lift curve: {e}")
+
+        try:
+            fig_roc_auc = explainer.plot_roc_auc()
+            plots_explainer_html += add_plot_to_html(fig_roc_auc)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot roc auc: {e}")
+
+        try:
+            fig_pr_auc = explainer.plot_pr_auc()
+            plots_explainer_html += add_plot_to_html(fig_pr_auc)
+            plots_explainer_html += add_hr_to_html()
+        except Exception as e:
+            LOG.error(f"Error generating plot pr auc: {e}")
+
+        self.plots_explainer_html = plots_explainer_html