Mercurial > repos > goeckslab > pycaret_predict
view pycaret_classification.py @ 2:0314dad38aaa draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
author | goeckslab |
---|---|
date | Wed, 01 Jan 2025 03:19:27 +0000 |
parents | 1f20fe57fdee |
children |
line wrap: on
line source
import logging
import types

from base_model_trainer import BaseModelTrainer
from dashboard import generate_classifier_explainer_dashboard
from pycaret.classification import ClassificationExperiment
from utils import add_hr_to_html, add_plot_to_html, predict_proba

LOG = logging.getLogger(__name__)


def _render_explainer_plots(label, make_figures, add_hr=True):
    """Render one section of explainer figures to an HTML fragment.

    Args:
        label: Human-readable section name, used only in the error log
            message (``"Error generating plot <label>: ..."``).
        make_figures: Zero-argument callable returning an iterable of
            figures. It is called *inside* the ``try`` so that attribute
            lookups and figure construction errors are caught here.
        add_hr: When true, ``add_hr_to_html()`` is appended after each
            rendered figure.

    Returns:
        The HTML produced so far. On the first exception the section is
        abandoned, the error is logged, and whatever was rendered before
        the failure is still returned.
    """
    html = ""
    try:
        for figure in make_figures():
            html += add_plot_to_html(figure)
            if add_hr:
                html += add_hr_to_html()
    except Exception as e:
        # Best effort: a failing plot must not abort the whole report.
        LOG.error(f"Error generating plot {label}: {e}")
    return html


class ClassificationModelTrainer(BaseModelTrainer):
    """Model trainer that runs a PyCaret ``ClassificationExperiment``.

    Attributes such as ``best_model``, ``plots`` and ``features_name`` are
    read here but not defined in this file; presumably they are set up by
    ``BaseModelTrainer`` (verify against that class).
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Forward all arguments to the base trainer and create the
        classification experiment used by the other methods."""
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Build the explainer dashboard for the best model and save it.

        The dashboard is written to the relative path ``dashboard.html``.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Generate the PyCaret diagnostic plots for the best model.

        Each successfully generated plot path is stored in
        ``self.plots[plot_name]``. A failure for one plot is logged and
        does not prevent the remaining plots from being attempted.
        """
        LOG.info("Generating and saving plots")

        if not hasattr(self.best_model, "predict_proba"):
            # Some plots call `predict_proba`; bind the fallback from
            # `utils` onto this model instance so they can still run.
            self.best_model.predict_proba = types.MethodType(
                predict_proba, self.best_model)
            # Implicit concatenation instead of backslash continuations
            # inside the literal, which leaked source indentation into
            # the logged message.
            LOG.warning(
                f"The model {type(self.best_model).__name__} "
                "does not support `predict_proba`. "
                "Applying monkey patch.")

        plots = ('confusion_matrix', 'auc', 'threshold', 'pr',
                 'error', 'class_report', 'learning', 'calibration',
                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
                 'feature_all')

        for plot_name in plots:
            try:
                extra_args = {}
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    # Binary problem: draw only the binary ROC curve.
                    extra_args['plot_kwargs'] = {
                        'micro': False,
                        'macro': False,
                        'per_class': False,
                        'binary': True,
                    }
                plot_path = self.exp.plot_model(
                    self.best_model,
                    plot=plot_name,
                    save=True,
                    **extra_args)
                self.plots[plot_name] = plot_path
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")

    def generate_plots_explainer(self):
        """Render explainerdashboard figures into one HTML string.

        The result is stored in ``self.plots_explainer_html``. If the
        explainer itself cannot be created, the attribute is set to
        ``None`` and the method returns early. Individual plot sections
        are best effort: a failure is logged and the next section is
        attempted.
        """
        LOG.info("Generating and saving plots from explainer")

        # Imported lazily, as in the original code.
        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
            # NOTE(review): the attribute name is misspelled ("expaliner").
            # It is kept unchanged because code outside this file may read
            # it; rename only after checking all users.
            self.expaliner = explainer
            plots_explainer_html = ""
        except Exception as e:
            LOG.error(f"Error creating explainer: {e}")
            self.plots_explainer_html = None
            return

        # (label, figure factory, append <hr> after each figure).
        # Factories are lambdas so that `self.features_name` and the
        # explainer calls are evaluated inside the helper's try block.
        # Plots that were commented out in the previous revision (shap
        # summary, contributions for index 0, per-feature dependence and
        # per-feature detailed interactions) remain disabled.
        sections = (
            ("importance(mean shap)",
             lambda: [explainer.plot_importances()],
             True),
            ("importance(permutation)",
             lambda: [explainer.plot_importances(kind="permutation")],
             True),
            ("pdp",
             lambda: (explainer.plot_pdp(feature)
                      for feature in self.features_name),
             True),
            # NOTE(review): each feature is interacted with itself, and no
            # separator is added for this section; both kept as found.
            ("interactions",
             lambda: (explainer.plot_interaction(
                          col=feature, interact_col=feature)
                      for feature in self.features_name),
             False),
            ("interactions importance",
             lambda: (explainer.plot_interactions_importance(col=feature)
                      for feature in self.features_name),
             True),
            ("precision",
             lambda: [explainer.plot_precision()],
             True),
            ("cumulative precision",
             lambda: [explainer.plot_cumulative_precision()],
             True),
            ("classification",
             lambda: [explainer.plot_classification()],
             True),
            ("confusion matrix",
             lambda: [explainer.plot_confusion_matrix()],
             True),
            ("lift curve",
             lambda: [explainer.plot_lift_curve()],
             True),
            ("roc auc",
             lambda: [explainer.plot_roc_auc()],
             True),
            ("pr auc",
             lambda: [explainer.plot_pr_auc()],
             True),
        )

        for label, make_figures, add_hr in sections:
            plots_explainer_html += _render_explainer_plots(
                label, make_figures, add_hr)

        self.plots_explainer_html = plots_explainer_html