Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_classification.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | ccd798db5abb |
children |
comparison
equal
deleted
inserted
replaced
7:f4cb41f458fd | 8:1aed7d47c5ec |
---|---|
1 import logging | 1 import logging |
2 import types | |
3 from typing import Dict | |
2 | 4 |
3 from base_model_trainer import BaseModelTrainer | 5 from base_model_trainer import BaseModelTrainer |
4 from dashboard import generate_classifier_explainer_dashboard | 6 from dashboard import generate_classifier_explainer_dashboard |
7 from plotly.graph_objects import Figure | |
5 from pycaret.classification import ClassificationExperiment | 8 from pycaret.classification import ClassificationExperiment |
6 from utils import add_hr_to_html, add_plot_to_html, predict_proba | 9 from utils import predict_proba |
7 | 10 |
8 LOG = logging.getLogger(__name__) | 11 LOG = logging.getLogger(__name__) |
9 | 12 |
10 | 13 |
11 class ClassificationModelTrainer(BaseModelTrainer): | 14 class ClassificationModelTrainer(BaseModelTrainer): |
12 def __init__( | 15 def __init__( |
13 self, | 16 self, |
14 input_file, | 17 input_file, |
15 target_col, | 18 target_col, |
16 output_dir, | 19 output_dir, |
17 task_type, | 20 task_type, |
18 random_seed, | 21 random_seed, |
19 test_file=None, | 22 test_file=None, |
20 **kwargs): | 23 **kwargs, |
24 ): | |
21 super().__init__( | 25 super().__init__( |
22 input_file, | 26 input_file, |
23 target_col, | 27 target_col, |
24 output_dir, | 28 output_dir, |
25 task_type, | 29 task_type, |
26 random_seed, | 30 random_seed, |
27 test_file, | 31 test_file, |
28 **kwargs) | 32 **kwargs, |
33 ) | |
29 self.exp = ClassificationExperiment() | 34 self.exp = ClassificationExperiment() |
30 | 35 |
31 def save_dashboard(self): | 36 def save_dashboard(self): |
32 LOG.info("Saving explainer dashboard") | 37 LOG.info("Saving explainer dashboard") |
33 dashboard = generate_classifier_explainer_dashboard(self.exp, | 38 dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model) |
34 self.best_model) | |
35 dashboard.save_html("dashboard.html") | 39 dashboard.save_html("dashboard.html") |
36 | 40 |
37 def generate_plots(self): | 41 def generate_plots(self): |
38 LOG.info("Generating and saving plots") | 42 LOG.info("Generating and saving plots") |
39 | 43 |
40 if not hasattr(self.best_model, "predict_proba"): | 44 if not hasattr(self.best_model, "predict_proba"): |
41 import types | |
42 self.best_model.predict_proba = types.MethodType( | 45 self.best_model.predict_proba = types.MethodType( |
43 predict_proba, self.best_model) | 46 predict_proba, self.best_model |
47 ) | |
44 LOG.warning( | 48 LOG.warning( |
45 f"The model {type(self.best_model).__name__}\ | 49 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
46 does not support `predict_proba`. \ | 50 ) |
47 Applying monkey patch.") | |
48 | 51 |
49 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', | 52 plots = [ |
50 'error', 'class_report', 'learning', 'calibration', | 53 'confusion_matrix', |
51 'vc', 'dimension', 'manifold', 'rfe', 'feature', | 54 'auc', |
52 'feature_all'] | 55 'threshold', |
56 'pr', | |
57 'error', | |
58 'class_report', | |
59 'learning', | |
60 'calibration', | |
61 'vc', | |
62 'dimension', | |
63 'manifold', | |
64 'rfe', | |
65 'feature', | |
66 'feature_all', | |
67 ] | |
53 for plot_name in plots: | 68 for plot_name in plots: |
54 try: | 69 try: |
55 if plot_name == 'auc' and not self.exp.is_multiclass: | 70 if plot_name == "threshold": |
56 plot_path = self.exp.plot_model(self.best_model, | 71 plot_path = self.exp.plot_model( |
57 plot=plot_name, | 72 self.best_model, |
58 save=True, | 73 plot=plot_name, |
59 plot_kwargs={ | 74 save=True, |
60 'micro': False, | 75 plot_kwargs={"binary": True, "percentage": True}, |
61 'macro': False, | 76 ) |
62 'per_class': False, | |
63 'binary': True | |
64 }) | |
65 self.plots[plot_name] = plot_path | 77 self.plots[plot_name] = plot_path |
66 continue | 78 elif plot_name == "auc" and not self.exp.is_multiclass: |
67 | 79 plot_path = self.exp.plot_model( |
68 plot_path = self.exp.plot_model(self.best_model, | 80 self.best_model, |
69 plot=plot_name, save=True) | 81 plot=plot_name, |
70 self.plots[plot_name] = plot_path | 82 save=True, |
83 plot_kwargs={ | |
84 "micro": False, | |
85 "macro": False, | |
86 "per_class": False, | |
87 "binary": True, | |
88 }, | |
89 ) | |
90 self.plots[plot_name] = plot_path | |
91 else: | |
92 plot_path = self.exp.plot_model( | |
93 self.best_model, plot=plot_name, save=True | |
94 ) | |
95 self.plots[plot_name] = plot_path | |
71 except Exception as e: | 96 except Exception as e: |
72 LOG.error(f"Error generating plot {plot_name}: {e}") | 97 LOG.error(f"Error generating plot {plot_name}: {e}") |
73 continue | 98 continue |
74 | 99 |
75 def generate_plots_explainer(self): | 100 def generate_plots_explainer(self): |
76 LOG.info("Generating and saving plots from explainer") | 101 from explainerdashboard import ClassifierExplainer |
77 | 102 |
78 from explainerdashboard import ClassifierExplainer | 103 LOG.info("Generating explainer plots") |
79 | 104 |
80 X_test = self.exp.X_test_transformed.copy() | 105 X_test = self.exp.X_test_transformed.copy() |
81 y_test = self.exp.y_test_transformed | 106 y_test = self.exp.y_test_transformed |
107 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | |
82 | 108 |
109 # a dict to hold the raw Figure objects or callables | |
110 self.explainer_plots: Dict[str, Figure] = {} | |
111 | |
112 # these go into the Test tab | |
113 for key, fn in [ | |
114 ("roc_auc", explainer.plot_roc_auc), | |
115 ("pr_auc", explainer.plot_pr_auc), | |
116 ("lift_curve", explainer.plot_lift_curve), | |
117 ("confusion_matrix", explainer.plot_confusion_matrix), | |
118 ("threshold", explainer.plot_precision), # Percentage vs probability | |
119 ("cumulative_precision", explainer.plot_cumulative_precision), | |
120 ]: | |
121 try: | |
122 self.explainer_plots[key] = fn() | |
123 except Exception as e: | |
124 LOG.error(f"Error generating explainer plot {key}: {e}") | |
125 | |
126 # mean SHAP importances | |
83 try: | 127 try: |
84 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 128 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
85 self.expaliner = explainer | |
86 plots_explainer_html = "" | |
87 except Exception as e: | 129 except Exception as e: |
88 LOG.error(f"Error creating explainer: {e}") | 130 LOG.warning(f"Could not generate shap_mean: {e}") |
89 self.plots_explainer_html = None | |
90 return | |
91 | 131 |
132 # permutation importances | |
92 try: | 133 try: |
93 fig_importance = explainer.plot_importances() | 134 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( |
94 plots_explainer_html += add_plot_to_html(fig_importance) | 135 kind="permutation" |
95 plots_explainer_html += add_hr_to_html() | 136 ) |
96 except Exception as e: | 137 except Exception as e: |
97 LOG.error(f"Error generating plot importance(mean shap): {e}") | 138 LOG.warning(f"Could not generate shap_perm: {e}") |
98 | 139 |
99 try: | 140 # PDPs for each feature (appended last) |
100 fig_importance_perm = explainer.plot_importances( | 141 valid_feats = [] |
101 kind="permutation") | 142 for feat in self.features_name: |
102 plots_explainer_html += add_plot_to_html(fig_importance_perm) | 143 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
103 plots_explainer_html += add_hr_to_html() | 144 valid_feats.append(feat) |
104 except Exception as e: | 145 else: |
105 LOG.error(f"Error generating plot importance(permutation): {e}") | 146 LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data") |
106 | 147 |
107 # try: | 148 for feat in valid_feats: |
108 # fig_shap = explainer.plot_shap_summary() | 149 # wrap each PDP call to catch any unexpected AssertionErrors |
109 # plots_explainer_html += add_plot_to_html(fig_shap, | 150 def make_pdp_plotter(f): |
110 # include_plotlyjs=False) | 151 def _plot(): |
111 # except Exception as e: | 152 try: |
112 # LOG.error(f"Error generating plot shap: {e}") | 153 return explainer.plot_pdp(f) |
154 except AssertionError as ae: | |
155 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") | |
156 return None | |
157 except Exception as e: | |
158 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") | |
159 return None | |
160 return _plot | |
113 | 161 |
114 # try: | 162 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
115 # fig_contributions = explainer.plot_contributions( | |
116 # index=0) | |
117 # plots_explainer_html += add_plot_to_html( | |
118 # fig_contributions, include_plotlyjs=False) | |
119 # except Exception as e: | |
120 # LOG.error(f"Error generating plot contributions: {e}") | |
121 | |
122 # try: | |
123 # for feature in self.features_name: | |
124 # fig_dependence = explainer.plot_dependence(col=feature) | |
125 # plots_explainer_html += add_plot_to_html(fig_dependence) | |
126 # except Exception as e: | |
127 # LOG.error(f"Error generating plot dependencies: {e}") | |
128 | |
129 try: | |
130 for feature in self.features_name: | |
131 fig_pdp = explainer.plot_pdp(feature) | |
132 plots_explainer_html += add_plot_to_html(fig_pdp) | |
133 plots_explainer_html += add_hr_to_html() | |
134 except Exception as e: | |
135 LOG.error(f"Error generating plot pdp: {e}") | |
136 | |
137 try: | |
138 for feature in self.features_name: | |
139 fig_interaction = explainer.plot_interaction( | |
140 col=feature, interact_col=feature) | |
141 plots_explainer_html += add_plot_to_html(fig_interaction) | |
142 except Exception as e: | |
143 LOG.error(f"Error generating plot interactions: {e}") | |
144 | |
145 try: | |
146 for feature in self.features_name: | |
147 fig_interactions_importance = \ | |
148 explainer.plot_interactions_importance( | |
149 col=feature) | |
150 plots_explainer_html += add_plot_to_html( | |
151 fig_interactions_importance) | |
152 plots_explainer_html += add_hr_to_html() | |
153 except Exception as e: | |
154 LOG.error(f"Error generating plot interactions importance: {e}") | |
155 | |
156 # try: | |
157 # for feature in self.features_name: | |
158 # fig_interactions_detailed = \ | |
159 # explainer.plot_interactions_detailed( | |
160 # col=feature) | |
161 # plots_explainer_html += add_plot_to_html( | |
162 # fig_interactions_detailed) | |
163 # except Exception as e: | |
164 # LOG.error(f"Error generating plot interactions detailed: {e}") | |
165 | |
166 try: | |
167 fig_precision = explainer.plot_precision() | |
168 plots_explainer_html += add_plot_to_html(fig_precision) | |
169 plots_explainer_html += add_hr_to_html() | |
170 except Exception as e: | |
171 LOG.error(f"Error generating plot precision: {e}") | |
172 | |
173 try: | |
174 fig_cumulative_precision = explainer.plot_cumulative_precision() | |
175 plots_explainer_html += add_plot_to_html(fig_cumulative_precision) | |
176 plots_explainer_html += add_hr_to_html() | |
177 except Exception as e: | |
178 LOG.error(f"Error generating plot cumulative precision: {e}") | |
179 | |
180 try: | |
181 fig_classification = explainer.plot_classification() | |
182 plots_explainer_html += add_plot_to_html(fig_classification) | |
183 plots_explainer_html += add_hr_to_html() | |
184 except Exception as e: | |
185 LOG.error(f"Error generating plot classification: {e}") | |
186 | |
187 try: | |
188 fig_confusion_matrix = explainer.plot_confusion_matrix() | |
189 plots_explainer_html += add_plot_to_html(fig_confusion_matrix) | |
190 plots_explainer_html += add_hr_to_html() | |
191 except Exception as e: | |
192 LOG.error(f"Error generating plot confusion matrix: {e}") | |
193 | |
194 try: | |
195 fig_lift_curve = explainer.plot_lift_curve() | |
196 plots_explainer_html += add_plot_to_html(fig_lift_curve) | |
197 plots_explainer_html += add_hr_to_html() | |
198 except Exception as e: | |
199 LOG.error(f"Error generating plot lift curve: {e}") | |
200 | |
201 try: | |
202 fig_roc_auc = explainer.plot_roc_auc() | |
203 plots_explainer_html += add_plot_to_html(fig_roc_auc) | |
204 plots_explainer_html += add_hr_to_html() | |
205 except Exception as e: | |
206 LOG.error(f"Error generating plot roc auc: {e}") | |
207 | |
208 try: | |
209 fig_pr_auc = explainer.plot_pr_auc() | |
210 plots_explainer_html += add_plot_to_html(fig_pr_auc) | |
211 plots_explainer_html += add_hr_to_html() | |
212 except Exception as e: | |
213 LOG.error(f"Error generating plot pr auc: {e}") | |
214 | |
215 self.plots_explainer_html = plots_explainer_html |