comparison pycaret_classification.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents ccd798db5abb
children
comparison
equal deleted inserted replaced
7:f4cb41f458fd 8:1aed7d47c5ec
1 import logging 1 import logging
2 import types
3 from typing import Dict
2 4
3 from base_model_trainer import BaseModelTrainer 5 from base_model_trainer import BaseModelTrainer
4 from dashboard import generate_classifier_explainer_dashboard 6 from dashboard import generate_classifier_explainer_dashboard
7 from plotly.graph_objects import Figure
5 from pycaret.classification import ClassificationExperiment 8 from pycaret.classification import ClassificationExperiment
6 from utils import add_hr_to_html, add_plot_to_html, predict_proba 9 from utils import predict_proba
7 10
8 LOG = logging.getLogger(__name__) 11 LOG = logging.getLogger(__name__)
9 12
10 13
11 class ClassificationModelTrainer(BaseModelTrainer): 14 class ClassificationModelTrainer(BaseModelTrainer):
12 def __init__( 15 def __init__(
13 self, 16 self,
14 input_file, 17 input_file,
15 target_col, 18 target_col,
16 output_dir, 19 output_dir,
17 task_type, 20 task_type,
18 random_seed, 21 random_seed,
19 test_file=None, 22 test_file=None,
20 **kwargs): 23 **kwargs,
24 ):
21 super().__init__( 25 super().__init__(
22 input_file, 26 input_file,
23 target_col, 27 target_col,
24 output_dir, 28 output_dir,
25 task_type, 29 task_type,
26 random_seed, 30 random_seed,
27 test_file, 31 test_file,
28 **kwargs) 32 **kwargs,
33 )
29 self.exp = ClassificationExperiment() 34 self.exp = ClassificationExperiment()
30 35
31 def save_dashboard(self): 36 def save_dashboard(self):
32 LOG.info("Saving explainer dashboard") 37 LOG.info("Saving explainer dashboard")
33 dashboard = generate_classifier_explainer_dashboard(self.exp, 38 dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model)
34 self.best_model)
35 dashboard.save_html("dashboard.html") 39 dashboard.save_html("dashboard.html")
36 40
37 def generate_plots(self): 41 def generate_plots(self):
38 LOG.info("Generating and saving plots") 42 LOG.info("Generating and saving plots")
39 43
40 if not hasattr(self.best_model, "predict_proba"): 44 if not hasattr(self.best_model, "predict_proba"):
41 import types
42 self.best_model.predict_proba = types.MethodType( 45 self.best_model.predict_proba = types.MethodType(
43 predict_proba, self.best_model) 46 predict_proba, self.best_model
47 )
44 LOG.warning( 48 LOG.warning(
45 f"The model {type(self.best_model).__name__}\ 49 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
46 does not support `predict_proba`. \ 50 )
47 Applying monkey patch.")
48 51
49 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', 52 plots = [
50 'error', 'class_report', 'learning', 'calibration', 53 'confusion_matrix',
51 'vc', 'dimension', 'manifold', 'rfe', 'feature', 54 'auc',
52 'feature_all'] 55 'threshold',
56 'pr',
57 'error',
58 'class_report',
59 'learning',
60 'calibration',
61 'vc',
62 'dimension',
63 'manifold',
64 'rfe',
65 'feature',
66 'feature_all',
67 ]
53 for plot_name in plots: 68 for plot_name in plots:
54 try: 69 try:
55 if plot_name == 'auc' and not self.exp.is_multiclass: 70 if plot_name == "threshold":
56 plot_path = self.exp.plot_model(self.best_model, 71 plot_path = self.exp.plot_model(
57 plot=plot_name, 72 self.best_model,
58 save=True, 73 plot=plot_name,
59 plot_kwargs={ 74 save=True,
60 'micro': False, 75 plot_kwargs={"binary": True, "percentage": True},
61 'macro': False, 76 )
62 'per_class': False,
63 'binary': True
64 })
65 self.plots[plot_name] = plot_path 77 self.plots[plot_name] = plot_path
66 continue 78 elif plot_name == "auc" and not self.exp.is_multiclass:
67 79 plot_path = self.exp.plot_model(
68 plot_path = self.exp.plot_model(self.best_model, 80 self.best_model,
69 plot=plot_name, save=True) 81 plot=plot_name,
70 self.plots[plot_name] = plot_path 82 save=True,
83 plot_kwargs={
84 "micro": False,
85 "macro": False,
86 "per_class": False,
87 "binary": True,
88 },
89 )
90 self.plots[plot_name] = plot_path
91 else:
92 plot_path = self.exp.plot_model(
93 self.best_model, plot=plot_name, save=True
94 )
95 self.plots[plot_name] = plot_path
71 except Exception as e: 96 except Exception as e:
72 LOG.error(f"Error generating plot {plot_name}: {e}") 97 LOG.error(f"Error generating plot {plot_name}: {e}")
73 continue 98 continue
74 99
75 def generate_plots_explainer(self): 100 def generate_plots_explainer(self):
76 LOG.info("Generating and saving plots from explainer") 101 from explainerdashboard import ClassifierExplainer
77 102
78 from explainerdashboard import ClassifierExplainer 103 LOG.info("Generating explainer plots")
79 104
80 X_test = self.exp.X_test_transformed.copy() 105 X_test = self.exp.X_test_transformed.copy()
81 y_test = self.exp.y_test_transformed 106 y_test = self.exp.y_test_transformed
107 explainer = ClassifierExplainer(self.best_model, X_test, y_test)
82 108
109 # a dict to hold the raw Figure objects or callables
110 self.explainer_plots: Dict[str, Figure] = {}
111
112 # these go into the Test tab
113 for key, fn in [
114 ("roc_auc", explainer.plot_roc_auc),
115 ("pr_auc", explainer.plot_pr_auc),
116 ("lift_curve", explainer.plot_lift_curve),
117 ("confusion_matrix", explainer.plot_confusion_matrix),
118 ("threshold", explainer.plot_precision), # Percentage vs probability
119 ("cumulative_precision", explainer.plot_cumulative_precision),
120 ]:
121 try:
122 self.explainer_plots[key] = fn()
123 except Exception as e:
124 LOG.error(f"Error generating explainer plot {key}: {e}")
125
126 # mean SHAP importances
83 try: 127 try:
84 explainer = ClassifierExplainer(self.best_model, X_test, y_test) 128 self.explainer_plots["shap_mean"] = explainer.plot_importances()
85 self.expaliner = explainer
86 plots_explainer_html = ""
87 except Exception as e: 129 except Exception as e:
88 LOG.error(f"Error creating explainer: {e}") 130 LOG.warning(f"Could not generate shap_mean: {e}")
89 self.plots_explainer_html = None
90 return
91 131
132 # permutation importances
92 try: 133 try:
93 fig_importance = explainer.plot_importances() 134 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances(
94 plots_explainer_html += add_plot_to_html(fig_importance) 135 kind="permutation"
95 plots_explainer_html += add_hr_to_html() 136 )
96 except Exception as e: 137 except Exception as e:
97 LOG.error(f"Error generating plot importance(mean shap): {e}") 138 LOG.warning(f"Could not generate shap_perm: {e}")
98 139
99 try: 140 # PDPs for each feature (appended last)
100 fig_importance_perm = explainer.plot_importances( 141 valid_feats = []
101 kind="permutation") 142 for feat in self.features_name:
102 plots_explainer_html += add_plot_to_html(fig_importance_perm) 143 if feat in explainer.X.columns or feat in explainer.onehot_cols:
103 plots_explainer_html += add_hr_to_html() 144 valid_feats.append(feat)
104 except Exception as e: 145 else:
105 LOG.error(f"Error generating plot importance(permutation): {e}") 146 LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
106 147
107 # try: 148 for feat in valid_feats:
108 # fig_shap = explainer.plot_shap_summary() 149 # wrap each PDP call to catch any unexpected AssertionErrors
109 # plots_explainer_html += add_plot_to_html(fig_shap, 150 def make_pdp_plotter(f):
110 # include_plotlyjs=False) 151 def _plot():
111 # except Exception as e: 152 try:
112 # LOG.error(f"Error generating plot shap: {e}") 153 return explainer.plot_pdp(f)
154 except AssertionError as ae:
155 LOG.warning(f"PDP AssertionError for {f!r}: {ae}")
156 return None
157 except Exception as e:
158 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
159 return None
160 return _plot
113 161
114 # try: 162 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
115 # fig_contributions = explainer.plot_contributions(
116 # index=0)
117 # plots_explainer_html += add_plot_to_html(
118 # fig_contributions, include_plotlyjs=False)
119 # except Exception as e:
120 # LOG.error(f"Error generating plot contributions: {e}")
121
122 # try:
123 # for feature in self.features_name:
124 # fig_dependence = explainer.plot_dependence(col=feature)
125 # plots_explainer_html += add_plot_to_html(fig_dependence)
126 # except Exception as e:
127 # LOG.error(f"Error generating plot dependencies: {e}")
128
129 try:
130 for feature in self.features_name:
131 fig_pdp = explainer.plot_pdp(feature)
132 plots_explainer_html += add_plot_to_html(fig_pdp)
133 plots_explainer_html += add_hr_to_html()
134 except Exception as e:
135 LOG.error(f"Error generating plot pdp: {e}")
136
137 try:
138 for feature in self.features_name:
139 fig_interaction = explainer.plot_interaction(
140 col=feature, interact_col=feature)
141 plots_explainer_html += add_plot_to_html(fig_interaction)
142 except Exception as e:
143 LOG.error(f"Error generating plot interactions: {e}")
144
145 try:
146 for feature in self.features_name:
147 fig_interactions_importance = \
148 explainer.plot_interactions_importance(
149 col=feature)
150 plots_explainer_html += add_plot_to_html(
151 fig_interactions_importance)
152 plots_explainer_html += add_hr_to_html()
153 except Exception as e:
154 LOG.error(f"Error generating plot interactions importance: {e}")
155
156 # try:
157 # for feature in self.features_name:
158 # fig_interactions_detailed = \
159 # explainer.plot_interactions_detailed(
160 # col=feature)
161 # plots_explainer_html += add_plot_to_html(
162 # fig_interactions_detailed)
163 # except Exception as e:
164 # LOG.error(f"Error generating plot interactions detailed: {e}")
165
166 try:
167 fig_precision = explainer.plot_precision()
168 plots_explainer_html += add_plot_to_html(fig_precision)
169 plots_explainer_html += add_hr_to_html()
170 except Exception as e:
171 LOG.error(f"Error generating plot precision: {e}")
172
173 try:
174 fig_cumulative_precision = explainer.plot_cumulative_precision()
175 plots_explainer_html += add_plot_to_html(fig_cumulative_precision)
176 plots_explainer_html += add_hr_to_html()
177 except Exception as e:
178 LOG.error(f"Error generating plot cumulative precision: {e}")
179
180 try:
181 fig_classification = explainer.plot_classification()
182 plots_explainer_html += add_plot_to_html(fig_classification)
183 plots_explainer_html += add_hr_to_html()
184 except Exception as e:
185 LOG.error(f"Error generating plot classification: {e}")
186
187 try:
188 fig_confusion_matrix = explainer.plot_confusion_matrix()
189 plots_explainer_html += add_plot_to_html(fig_confusion_matrix)
190 plots_explainer_html += add_hr_to_html()
191 except Exception as e:
192 LOG.error(f"Error generating plot confusion matrix: {e}")
193
194 try:
195 fig_lift_curve = explainer.plot_lift_curve()
196 plots_explainer_html += add_plot_to_html(fig_lift_curve)
197 plots_explainer_html += add_hr_to_html()
198 except Exception as e:
199 LOG.error(f"Error generating plot lift curve: {e}")
200
201 try:
202 fig_roc_auc = explainer.plot_roc_auc()
203 plots_explainer_html += add_plot_to_html(fig_roc_auc)
204 plots_explainer_html += add_hr_to_html()
205 except Exception as e:
206 LOG.error(f"Error generating plot roc auc: {e}")
207
208 try:
209 fig_pr_auc = explainer.plot_pr_auc()
210 plots_explainer_html += add_plot_to_html(fig_pr_auc)
211 plots_explainer_html += add_hr_to_html()
212 except Exception as e:
213 LOG.error(f"Error generating plot pr auc: {e}")
214
215 self.plots_explainer_html = plots_explainer_html