Mercurial > repos > goeckslab > pycaret_predict
diff pycaret_classification.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | ccd798db5abb |
children |
line wrap: on
line diff
--- a/pycaret_classification.py Wed Jul 09 01:13:01 2025 +0000 +++ b/pycaret_classification.py Fri Jul 25 19:02:32 2025 +0000 @@ -1,23 +1,27 @@ import logging +import types +from typing import Dict from base_model_trainer import BaseModelTrainer from dashboard import generate_classifier_explainer_dashboard +from plotly.graph_objects import Figure from pycaret.classification import ClassificationExperiment -from utils import add_hr_to_html, add_plot_to_html, predict_proba +from utils import predict_proba LOG = logging.getLogger(__name__) class ClassificationModelTrainer(BaseModelTrainer): def __init__( - self, - input_file, - target_col, - output_dir, - task_type, - random_seed, - test_file=None, - **kwargs): + self, + input_file, + target_col, + output_dir, + task_type, + random_seed, + test_file=None, + **kwargs, + ): super().__init__( input_file, target_col, @@ -25,191 +29,134 @@ task_type, random_seed, test_file, - **kwargs) + **kwargs, + ) self.exp = ClassificationExperiment() def save_dashboard(self): LOG.info("Saving explainer dashboard") - dashboard = generate_classifier_explainer_dashboard(self.exp, - self.best_model) + dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model) dashboard.save_html("dashboard.html") def generate_plots(self): LOG.info("Generating and saving plots") if not hasattr(self.best_model, "predict_proba"): - import types self.best_model.predict_proba = types.MethodType( - predict_proba, self.best_model) + predict_proba, self.best_model + ) LOG.warning( - f"The model {type(self.best_model).__name__}\ - does not support `predict_proba`. \ - Applying monkey patch.") + f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." + ) - plots = ['confusion_matrix', 'auc', 'threshold', 'pr', - 'error', 'class_report', 'learning', 'calibration', - 'vc', 'dimension', 'manifold', 'rfe', 'feature', - 'feature_all'] + plots = [ + 'confusion_matrix', + 'auc', + 'threshold', + 'pr', + 'error', + 'class_report', + 'learning', + 'calibration', + 'vc', + 'dimension', + 'manifold', + 'rfe', + 'feature', + 'feature_all', + ] for plot_name in plots: try: - if plot_name == 'auc' and not self.exp.is_multiclass: - plot_path = self.exp.plot_model(self.best_model, - plot=plot_name, - save=True, - plot_kwargs={ - 'micro': False, - 'macro': False, - 'per_class': False, - 'binary': True - }) + if plot_name == "threshold": + plot_path = self.exp.plot_model( + self.best_model, + plot=plot_name, + save=True, + plot_kwargs={"binary": True, "percentage": True}, + ) self.plots[plot_name] = plot_path - continue - - plot_path = self.exp.plot_model(self.best_model, - plot=plot_name, save=True) - self.plots[plot_name] = plot_path + elif plot_name == "auc" and not self.exp.is_multiclass: + plot_path = self.exp.plot_model( + self.best_model, + plot=plot_name, + save=True, + plot_kwargs={ + "micro": False, + "macro": False, + "per_class": False, + "binary": True, + }, + ) + self.plots[plot_name] = plot_path + else: + plot_path = self.exp.plot_model( + self.best_model, plot=plot_name, save=True + ) + self.plots[plot_name] = plot_path except Exception as e: LOG.error(f"Error generating plot {plot_name}: {e}") continue def generate_plots_explainer(self): - LOG.info("Generating and saving plots from explainer") + from explainerdashboard import ClassifierExplainer - from explainerdashboard import ClassifierExplainer + LOG.info("Generating explainer plots") X_test = self.exp.X_test_transformed.copy() y_test = self.exp.y_test_transformed - - try: - explainer = ClassifierExplainer(self.best_model, X_test, y_test) - self.expaliner = explainer - plots_explainer_html = "" - except Exception as e: - LOG.error(f"Error creating explainer: {e}") - self.plots_explainer_html = None - return + explainer = ClassifierExplainer(self.best_model, X_test, y_test) - try: - fig_importance = explainer.plot_importances() - plots_explainer_html += add_plot_to_html(fig_importance) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot importance(mean shap): {e}") - - try: - fig_importance_perm = explainer.plot_importances( - kind="permutation") - plots_explainer_html += add_plot_to_html(fig_importance_perm) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot importance(permutation): {e}") - - # try: - # fig_shap = explainer.plot_shap_summary() - # plots_explainer_html += add_plot_to_html(fig_shap, - # include_plotlyjs=False) - # except Exception as e: - # LOG.error(f"Error generating plot shap: {e}") + # a dict to hold the raw Figure objects or callables + self.explainer_plots: Dict[str, Figure] = {} - # try: - # fig_contributions = explainer.plot_contributions( - # index=0) - # plots_explainer_html += add_plot_to_html( - # fig_contributions, include_plotlyjs=False) - # except Exception as e: - # LOG.error(f"Error generating plot contributions: {e}") + # these go into the Test tab + for key, fn in [ + ("roc_auc", explainer.plot_roc_auc), + ("pr_auc", explainer.plot_pr_auc), + ("lift_curve", explainer.plot_lift_curve), + ("confusion_matrix", explainer.plot_confusion_matrix), + ("threshold", explainer.plot_precision), # Percentage vs probability + ("cumulative_precision", explainer.plot_cumulative_precision), + ]: + try: + self.explainer_plots[key] = fn() + except Exception as e: + LOG.error(f"Error generating explainer plot {key}: {e}") - # try: - # for feature in self.features_name: - # fig_dependence = explainer.plot_dependence(col=feature) - # plots_explainer_html += add_plot_to_html(fig_dependence) - # except Exception as e: - # LOG.error(f"Error generating plot dependencies: {e}") - + # mean SHAP importances try: - for feature in self.features_name: - fig_pdp = explainer.plot_pdp(feature) - plots_explainer_html += add_plot_to_html(fig_pdp) - plots_explainer_html += add_hr_to_html() + self.explainer_plots["shap_mean"] = explainer.plot_importances() except Exception as e: - LOG.error(f"Error generating plot pdp: {e}") + LOG.warning(f"Could not generate shap_mean: {e}") - try: - for feature in self.features_name: - fig_interaction = explainer.plot_interaction( - col=feature, interact_col=feature) - plots_explainer_html += add_plot_to_html(fig_interaction) - except Exception as e: - LOG.error(f"Error generating plot interactions: {e}") - + # permutation importances try: - for feature in self.features_name: - fig_interactions_importance = \ - explainer.plot_interactions_importance( - col=feature) - plots_explainer_html += add_plot_to_html( - fig_interactions_importance) - plots_explainer_html += add_hr_to_html() + self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( + kind="permutation" + ) except Exception as e: - LOG.error(f"Error generating plot interactions importance: {e}") + LOG.warning(f"Could not generate shap_perm: {e}") - # try: - # for feature in self.features_name: - # fig_interactions_detailed = \ - # explainer.plot_interactions_detailed( - # col=feature) - # plots_explainer_html += add_plot_to_html( - # fig_interactions_detailed) - # except Exception as e: - # LOG.error(f"Error generating plot interactions detailed: {e}") - - try: - fig_precision = explainer.plot_precision() - plots_explainer_html += add_plot_to_html(fig_precision) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot precision: {e}") - - try: - fig_cumulative_precision = explainer.plot_cumulative_precision() - plots_explainer_html += add_plot_to_html(fig_cumulative_precision) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot cumulative precision: {e}") + # PDPs for each feature (appended last) + valid_feats = [] + for feat in self.features_name: + if feat in explainer.X.columns or feat in explainer.onehot_cols: + valid_feats.append(feat) + else: + LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data") - try: - fig_classification = explainer.plot_classification() - plots_explainer_html += add_plot_to_html(fig_classification) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot classification: {e}") - - try: - fig_confusion_matrix = explainer.plot_confusion_matrix() - plots_explainer_html += add_plot_to_html(fig_confusion_matrix) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot confusion matrix: {e}") + for feat in valid_feats: + # wrap each PDP call to catch any unexpected AssertionErrors + def make_pdp_plotter(f): + def _plot(): + try: + return explainer.plot_pdp(f) + except AssertionError as ae: + LOG.warning(f"PDP AssertionError for {f!r}: {ae}") + return None + except Exception as e: + LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") + return None + return _plot - try: - fig_lift_curve = explainer.plot_lift_curve() - plots_explainer_html += add_plot_to_html(fig_lift_curve) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot lift curve: {e}") - - try: - fig_roc_auc = explainer.plot_roc_auc() - plots_explainer_html += add_plot_to_html(fig_roc_auc) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot roc auc: {e}") - - try: - fig_pr_auc = explainer.plot_pr_auc() - plots_explainer_html += add_plot_to_html(fig_pr_auc) - plots_explainer_html += add_hr_to_html() - except Exception as e: - LOG.error(f"Error generating plot pr auc: {e}") - - self.plots_explainer_html = plots_explainer_html + self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)