Mercurial > repos > goeckslab > tabular_learner
comparison pycaret_classification.py @ 4:11fdac5affb3 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
| author | goeckslab |
|---|---|
| date | Fri, 25 Jul 2025 19:02:12 +0000 |
| parents | 209b663a4f62 |
| children | ba45bc057d70 |
comparison
equal
deleted
inserted
replaced
| 3:f6a65e05d6ec | 4:11fdac5affb3 |
|---|---|
| 1 import logging | 1 import logging |
| | 2 import types |
| | 3 from typing import Dict |
| 2 | 4 |
| 3 from base_model_trainer import BaseModelTrainer | 5 from base_model_trainer import BaseModelTrainer |
| 4 from dashboard import generate_classifier_explainer_dashboard | 6 from dashboard import generate_classifier_explainer_dashboard |
| | 7 from plotly.graph_objects import Figure |
| 5 from pycaret.classification import ClassificationExperiment | 8 from pycaret.classification import ClassificationExperiment |
| 6 from utils import add_hr_to_html, add_plot_to_html, predict_proba | 9 from utils import predict_proba |
| 7 | 10 |
| 8 LOG = logging.getLogger(__name__) | 11 LOG = logging.getLogger(__name__) |
| 9 | 12 |
| 10 | 13 |
| 11 class ClassificationModelTrainer(BaseModelTrainer): | 14 class ClassificationModelTrainer(BaseModelTrainer): |
| 12 def __init__( | 15 def __init__( |
| 13 self, | 16 self, |
| 14 input_file, | 17 input_file, |
| 15 target_col, | 18 target_col, |
| 16 output_dir, | 19 output_dir, |
| 17 task_type, | 20 task_type, |
| 18 random_seed, | 21 random_seed, |
| 19 test_file=None, | 22 test_file=None, |
| 20 **kwargs): | 23 **kwargs, |
| | 24 ): |
| 21 super().__init__( | 25 super().__init__( |
| 22 input_file, | 26 input_file, |
| 23 target_col, | 27 target_col, |
| 24 output_dir, | 28 output_dir, |
| 25 task_type, | 29 task_type, |
| 26 random_seed, | 30 random_seed, |
| 27 test_file, | 31 test_file, |
| 28 **kwargs) | 32 **kwargs, |
| | 33 ) |
| 29 self.exp = ClassificationExperiment() | 34 self.exp = ClassificationExperiment() |
| 30 | 35 |
| 31 def save_dashboard(self): | 36 def save_dashboard(self): |
| 32 LOG.info("Saving explainer dashboard") | 37 LOG.info("Saving explainer dashboard") |
| 33 dashboard = generate_classifier_explainer_dashboard(self.exp, | 38 dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model) |
| 34 self.best_model) | |
| 35 dashboard.save_html("dashboard.html") | 39 dashboard.save_html("dashboard.html") |
| 36 | 40 |
| 37 def generate_plots(self): | 41 def generate_plots(self): |
| 38 LOG.info("Generating and saving plots") | 42 LOG.info("Generating and saving plots") |
| 39 | 43 |
| 40 if not hasattr(self.best_model, "predict_proba"): | 44 if not hasattr(self.best_model, "predict_proba"): |
| 41 import types | |
| 42 self.best_model.predict_proba = types.MethodType( | 45 self.best_model.predict_proba = types.MethodType( |
| 43 predict_proba, self.best_model) | 46 predict_proba, self.best_model |
| | 47 ) |
| 44 LOG.warning( | 48 LOG.warning( |
| 45 f"The model {type(self.best_model).__name__}\ | 49 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
| 46 does not support `predict_proba`. \ | 50 ) |
| 47 Applying monkey patch.") | |
| 48 | 51 |
| 49 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', | 52 plots = [ |
| 50 'error', 'class_report', 'learning', 'calibration', | 53 'confusion_matrix', |
| 51 'vc', 'dimension', 'manifold', 'rfe', 'feature', | 54 'auc', |
| 52 'feature_all'] | 55 'threshold', |
| | 56 'pr', |
| | 57 'error', |
| | 58 'class_report', |
| | 59 'learning', |
| | 60 'calibration', |
| | 61 'vc', |
| | 62 'dimension', |
| | 63 'manifold', |
| | 64 'rfe', |
| | 65 'feature', |
| | 66 'feature_all', |
| | 67 ] |
| 53 for plot_name in plots: | 68 for plot_name in plots: |
| 54 try: | 69 try: |
| 55 if plot_name == 'auc' and not self.exp.is_multiclass: | 70 if plot_name == "threshold": |
| 56 plot_path = self.exp.plot_model(self.best_model, | 71 plot_path = self.exp.plot_model( |
| 57 plot=plot_name, | 72 self.best_model, |
| 58 save=True, | 73 plot=plot_name, |
| 59 plot_kwargs={ | 74 save=True, |
| 60 'micro': False, | 75 plot_kwargs={"binary": True, "percentage": True}, |
| 61 'macro': False, | 76 ) |
| 62 'per_class': False, | |
| 63 'binary': True | |
| 64 }) | |
| 65 self.plots[plot_name] = plot_path | 77 self.plots[plot_name] = plot_path |
| 66 continue | 78 elif plot_name == "auc" and not self.exp.is_multiclass: |
| 67 | 79 plot_path = self.exp.plot_model( |
| 68 plot_path = self.exp.plot_model(self.best_model, | 80 self.best_model, |
| 69 plot=plot_name, save=True) | 81 plot=plot_name, |
| 70 self.plots[plot_name] = plot_path | 82 save=True, |
| | 83 plot_kwargs={ |
| | 84 "micro": False, |
| | 85 "macro": False, |
| | 86 "per_class": False, |
| | 87 "binary": True, |
| | 88 }, |
| | 89 ) |
| | 90 self.plots[plot_name] = plot_path |
| | 91 else: |
| | 92 plot_path = self.exp.plot_model( |
| | 93 self.best_model, plot=plot_name, save=True |
| | 94 ) |
| | 95 self.plots[plot_name] = plot_path |
| 71 except Exception as e: | 96 except Exception as e: |
| 72 LOG.error(f"Error generating plot {plot_name}: {e}") | 97 LOG.error(f"Error generating plot {plot_name}: {e}") |
| 73 continue | 98 continue |
| 74 | 99 |
| 75 def generate_plots_explainer(self): | 100 def generate_plots_explainer(self): |
| 76 LOG.info("Generating and saving plots from explainer") | 101 from explainerdashboard import ClassifierExplainer |
| 77 | 102 |
| 78 from explainerdashboard import ClassifierExplainer | 103 LOG.info("Generating explainer plots") |
| 79 | 104 |
| 80 X_test = self.exp.X_test_transformed.copy() | 105 X_test = self.exp.X_test_transformed.copy() |
| 81 y_test = self.exp.y_test_transformed | 106 y_test = self.exp.y_test_transformed |
| | 107 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
| 82 | 108 |
| | 109 # a dict to hold the raw Figure objects or callables |
| | 110 self.explainer_plots: Dict[str, Figure] = {} |
| | 111 |
| | 112 # these go into the Test tab |
| | 113 for key, fn in [ |
| | 114 ("roc_auc", explainer.plot_roc_auc), |
| | 115 ("pr_auc", explainer.plot_pr_auc), |
| | 116 ("lift_curve", explainer.plot_lift_curve), |
| | 117 ("confusion_matrix", explainer.plot_confusion_matrix), |
| | 118 ("threshold", explainer.plot_precision), # Percentage vs probability |
| | 119 ("cumulative_precision", explainer.plot_cumulative_precision), |
| | 120 ]: |
| | 121 try: |
| | 122 self.explainer_plots[key] = fn() |
| | 123 except Exception as e: |
| | 124 LOG.error(f"Error generating explainer plot {key}: {e}") |
| | 125 |
| | 126 # mean SHAP importances |
| 83 try: | 127 try: |
| 84 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 128 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
| 85 self.expaliner = explainer | |
| 86 plots_explainer_html = "" | |
| 87 except Exception as e: | 129 except Exception as e: |
| 88 LOG.error(f"Error creating explainer: {e}") | 130 LOG.warning(f"Could not generate shap_mean: {e}") |
| 89 self.plots_explainer_html = None | |
| 90 return | |
| 91 | 131 |
| | 132 # permutation importances |
| 92 try: | 133 try: |
| 93 fig_importance = explainer.plot_importances() | 134 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( |
| 94 plots_explainer_html += add_plot_to_html(fig_importance) | 135 kind="permutation" |
| 95 plots_explainer_html += add_hr_to_html() | 136 ) |
| 96 except Exception as e: | 137 except Exception as e: |
| 97 LOG.error(f"Error generating plot importance(mean shap): {e}") | 138 LOG.warning(f"Could not generate shap_perm: {e}") |
| 98 | 139 |
| 99 try: | 140 # PDPs for each feature (appended last) |
| 100 fig_importance_perm = explainer.plot_importances( | 141 valid_feats = [] |
| 101 kind="permutation") | 142 for feat in self.features_name: |
| 102 plots_explainer_html += add_plot_to_html(fig_importance_perm) | 143 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
| 103 plots_explainer_html += add_hr_to_html() | 144 valid_feats.append(feat) |
| 104 except Exception as e: | 145 else: |
| 105 LOG.error(f"Error generating plot importance(permutation): {e}") | 146 LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data") |
| 106 | 147 |
| 107 # try: | 148 for feat in valid_feats: |
| 108 # fig_shap = explainer.plot_shap_summary() | 149 # wrap each PDP call to catch any unexpected AssertionErrors |
| 109 # plots_explainer_html += add_plot_to_html(fig_shap, | 150 def make_pdp_plotter(f): |
| 110 # include_plotlyjs=False) | 151 def _plot(): |
| 111 # except Exception as e: | 152 try: |
| 112 # LOG.error(f"Error generating plot shap: {e}") | 153 return explainer.plot_pdp(f) |
| | 154 except AssertionError as ae: |
| | 155 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") |
| | 156 return None |
| | 157 except Exception as e: |
| | 158 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") |
| | 159 return None |
| | 160 return _plot |
| 113 | 161 |
| 114 # try: | 162 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
| 115 # fig_contributions = explainer.plot_contributions( | |
| 116 # index=0) | |
| 117 # plots_explainer_html += add_plot_to_html( | |
| 118 # fig_contributions, include_plotlyjs=False) | |
| 119 # except Exception as e: | |
| 120 # LOG.error(f"Error generating plot contributions: {e}") | |
| 121 | |
| 122 # try: | |
| 123 # for feature in self.features_name: | |
| 124 # fig_dependence = explainer.plot_dependence(col=feature) | |
| 125 # plots_explainer_html += add_plot_to_html(fig_dependence) | |
| 126 # except Exception as e: | |
| 127 # LOG.error(f"Error generating plot dependencies: {e}") | |
| 128 | |
| 129 try: | |
| 130 for feature in self.features_name: | |
| 131 fig_pdp = explainer.plot_pdp(feature) | |
| 132 plots_explainer_html += add_plot_to_html(fig_pdp) | |
| 133 plots_explainer_html += add_hr_to_html() | |
| 134 except Exception as e: | |
| 135 LOG.error(f"Error generating plot pdp: {e}") | |
| 136 | |
| 137 try: | |
| 138 for feature in self.features_name: | |
| 139 fig_interaction = explainer.plot_interaction( | |
| 140 col=feature, interact_col=feature) | |
| 141 plots_explainer_html += add_plot_to_html(fig_interaction) | |
| 142 except Exception as e: | |
| 143 LOG.error(f"Error generating plot interactions: {e}") | |
| 144 | |
| 145 try: | |
| 146 for feature in self.features_name: | |
| 147 fig_interactions_importance = \ | |
| 148 explainer.plot_interactions_importance( | |
| 149 col=feature) | |
| 150 plots_explainer_html += add_plot_to_html( | |
| 151 fig_interactions_importance) | |
| 152 plots_explainer_html += add_hr_to_html() | |
| 153 except Exception as e: | |
| 154 LOG.error(f"Error generating plot interactions importance: {e}") | |
| 155 | |
| 156 # try: | |
| 157 # for feature in self.features_name: | |
| 158 # fig_interactions_detailed = \ | |
| 159 # explainer.plot_interactions_detailed( | |
| 160 # col=feature) | |
| 161 # plots_explainer_html += add_plot_to_html( | |
| 162 # fig_interactions_detailed) | |
| 163 # except Exception as e: | |
| 164 # LOG.error(f"Error generating plot interactions detailed: {e}") | |
| 165 | |
| 166 try: | |
| 167 fig_precision = explainer.plot_precision() | |
| 168 plots_explainer_html += add_plot_to_html(fig_precision) | |
| 169 plots_explainer_html += add_hr_to_html() | |
| 170 except Exception as e: | |
| 171 LOG.error(f"Error generating plot precision: {e}") | |
| 172 | |
| 173 try: | |
| 174 fig_cumulative_precision = explainer.plot_cumulative_precision() | |
| 175 plots_explainer_html += add_plot_to_html(fig_cumulative_precision) | |
| 176 plots_explainer_html += add_hr_to_html() | |
| 177 except Exception as e: | |
| 178 LOG.error(f"Error generating plot cumulative precision: {e}") | |
| 179 | |
| 180 try: | |
| 181 fig_classification = explainer.plot_classification() | |
| 182 plots_explainer_html += add_plot_to_html(fig_classification) | |
| 183 plots_explainer_html += add_hr_to_html() | |
| 184 except Exception as e: | |
| 185 LOG.error(f"Error generating plot classification: {e}") | |
| 186 | |
| 187 try: | |
| 188 fig_confusion_matrix = explainer.plot_confusion_matrix() | |
| 189 plots_explainer_html += add_plot_to_html(fig_confusion_matrix) | |
| 190 plots_explainer_html += add_hr_to_html() | |
| 191 except Exception as e: | |
| 192 LOG.error(f"Error generating plot confusion matrix: {e}") | |
| 193 | |
| 194 try: | |
| 195 fig_lift_curve = explainer.plot_lift_curve() | |
| 196 plots_explainer_html += add_plot_to_html(fig_lift_curve) | |
| 197 plots_explainer_html += add_hr_to_html() | |
| 198 except Exception as e: | |
| 199 LOG.error(f"Error generating plot lift curve: {e}") | |
| 200 | |
| 201 try: | |
| 202 fig_roc_auc = explainer.plot_roc_auc() | |
| 203 plots_explainer_html += add_plot_to_html(fig_roc_auc) | |
| 204 plots_explainer_html += add_hr_to_html() | |
| 205 except Exception as e: | |
| 206 LOG.error(f"Error generating plot roc auc: {e}") | |
| 207 | |
| 208 try: | |
| 209 fig_pr_auc = explainer.plot_pr_auc() | |
| 210 plots_explainer_html += add_plot_to_html(fig_pr_auc) | |
| 211 plots_explainer_html += add_hr_to_html() | |
| 212 except Exception as e: | |
| 213 LOG.error(f"Error generating plot pr auc: {e}") | |
| 214 | |
| 215 self.plots_explainer_html = plots_explainer_html |
