Mercurial > repos > goeckslab > tabular_learner
comparison pycaret_classification.py @ 0:209b663a4f62 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author | goeckslab |
---|---|
date | Wed, 18 Jun 2025 15:38:19 +0000 |
parents | |
children | 11fdac5affb3 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:209b663a4f62 |
---|---|
1 import logging | |
2 | |
3 from base_model_trainer import BaseModelTrainer | |
4 from dashboard import generate_classifier_explainer_dashboard | |
5 from pycaret.classification import ClassificationExperiment | |
6 from utils import add_hr_to_html, add_plot_to_html, predict_proba | |
7 | |
# Module-level logger named after this module; the trainer class below uses it
# for progress (info), monkey-patch (warning) and plot-failure (error) messages.
LOG = logging.getLogger(__name__)
9 | |
10 | |
class ClassificationModelTrainer(BaseModelTrainer):
    """Model trainer specialised for classification with PyCaret.

    Builds on ``BaseModelTrainer`` by using a ``ClassificationExperiment`` as
    the experiment object and by providing the classification-specific
    reporting steps: the explainer dashboard, the PyCaret diagnostic plots and
    the explainerdashboard figures rendered as an HTML fragment.
    """

    # Plot identifiers requested from ``plot_model`` by ``generate_plots``,
    # in the order they are generated.
    _PLOT_NAMES = (
        'confusion_matrix', 'auc', 'threshold', 'pr',
        'error', 'class_report', 'learning', 'calibration',
        'vc', 'dimension', 'manifold', 'rfe', 'feature',
        'feature_all',
    )

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the base trainer and attach a classification experiment.

        All arguments are forwarded unchanged to ``BaseModelTrainer``; their
        exact meaning is defined there.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Build the classifier explainer dashboard and save it as HTML.

        The dashboard is written to the relative path ``dashboard.html``.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Generate the PyCaret diagnostic plots for the best model.

        Each plot in ``_PLOT_NAMES`` is attempted independently. The value
        returned by ``plot_model`` is stored in ``self.plots`` under the plot
        name; a failing plot is logged and skipped so the others still run.
        """
        LOG.info("Generating and saving plots")

        if not hasattr(self.best_model, "predict_proba"):
            import types

            # Bind the fallback from ``utils`` as a method on this instance
            # so code calling ``predict_proba`` on the model still works.
            self.best_model.predict_proba = types.MethodType(
                predict_proba, self.best_model)
            # Implicit concatenation keeps source indentation out of the
            # emitted message (backslash continuations inside a string
            # literal would embed it).
            LOG.warning(
                "The model %s does not support `predict_proba`. "
                "Applying monkey patch.",
                type(self.best_model).__name__)

        for plot_name in self._PLOT_NAMES:
            try:
                extra = {}
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    # Binary problem: request only the single binary ROC
                    # curve and disable the averaged / per-class curves.
                    extra['plot_kwargs'] = {
                        'micro': False,
                        'macro': False,
                        'per_class': False,
                        'binary': True,
                    }
                self.plots[plot_name] = self.exp.plot_model(
                    self.best_model, plot=plot_name, save=True, **extra)
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")

    @staticmethod
    def _append_explainer_section(fragments, label, make_figures,
                                  add_rule=True):
        """Render one group of explainer figures into ``fragments``.

        ``make_figures`` is called inside the ``try`` and must return an
        iterable of figures; each one is converted with ``add_plot_to_html``
        and appended. When every figure succeeded and ``add_rule`` is true, a
        horizontal rule is appended as a separator. On any exception the
        fragments appended so far are kept, the error is logged using
        ``label`` and the section is abandoned.
        """
        try:
            for figure in make_figures():
                fragments.append(add_plot_to_html(figure))
            if add_rule:
                fragments.append(add_hr_to_html())
        except Exception as e:
            LOG.error(f"Error generating plot {label}: {e}")

    def generate_plots_explainer(self):
        """Render explainerdashboard figures into ``self.plots_explainer_html``.

        A ``ClassifierExplainer`` is built from the best model and the
        transformed test split of the experiment. If that fails,
        ``self.plots_explainer_html`` is set to ``None`` and the method
        returns. Otherwise each plot section is attempted independently and
        the concatenated HTML of the successful ones is stored.
        """
        LOG.info("Generating and saving plots from explainer")

        # Imported here, as in the original code, so explainerdashboard is
        # only loaded when explainer plots are actually requested.
        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
        except Exception as e:
            LOG.error(f"Error creating explainer: {e}")
            self.plots_explainer_html = None
            return

        self.explainer = explainer
        # Historical misspelling, kept as an alias in case other code reads it.
        self.expaliner = explainer

        # (label used in the error message, figure factory, add separator?)
        # Per-feature sections use generator expressions so figures produced
        # before a failure are still kept.
        # NOTE: SHAP summary, contributions, dependence and detailed
        # interaction plots were disabled (commented out) in the original
        # implementation and are intentionally not generated here.
        # NOTE(review): the "interactions" section passes the same feature as
        # both ``col`` and ``interact_col`` - confirm this is intended.
        sections = [
            ("importance(mean shap)",
             lambda: [explainer.plot_importances()],
             True),
            ("importance(permutation)",
             lambda: [explainer.plot_importances(kind="permutation")],
             True),
            ("pdp",
             lambda: (explainer.plot_pdp(feature)
                      for feature in self.features_name),
             True),
            ("interactions",
             lambda: (explainer.plot_interaction(
                          col=feature, interact_col=feature)
                      for feature in self.features_name),
             False),
            ("interactions importance",
             lambda: (explainer.plot_interactions_importance(col=feature)
                      for feature in self.features_name),
             True),
            ("precision",
             lambda: [explainer.plot_precision()],
             True),
            ("cumulative precision",
             lambda: [explainer.plot_cumulative_precision()],
             True),
            ("classification",
             lambda: [explainer.plot_classification()],
             True),
            ("confusion matrix",
             lambda: [explainer.plot_confusion_matrix()],
             True),
            ("lift curve",
             lambda: [explainer.plot_lift_curve()],
             True),
            ("roc auc",
             lambda: [explainer.plot_roc_auc()],
             True),
            ("pr auc",
             lambda: [explainer.plot_pr_auc()],
             True),
        ]

        fragments = []
        for label, make_figures, add_rule in sections:
            self._append_explainer_section(
                fragments, label, make_figures, add_rule)

        self.plots_explainer_html = "".join(fragments)