comparison pycaret_classification.py @ 0:209b663a4f62 draft

planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author goeckslab
date Wed, 18 Jun 2025 15:38:19 +0000
parents
children 11fdac5affb3
comparison
equal deleted inserted replaced
-1:000000000000 0:209b663a4f62
1 import logging
2
3 from base_model_trainer import BaseModelTrainer
4 from dashboard import generate_classifier_explainer_dashboard
5 from pycaret.classification import ClassificationExperiment
6 from utils import add_hr_to_html, add_plot_to_html, predict_proba
7
# Module-level logger, named after this module.
LOG = logging.getLogger(name=__name__)
9
10
class ClassificationModelTrainer(BaseModelTrainer):
    """Model trainer for PyCaret classification tasks.

    Extends ``BaseModelTrainer`` with a PyCaret ``ClassificationExperiment``
    and classification-specific reporting: an explainer-dashboard HTML
    export, PyCaret diagnostic plots, and a set of explainerdashboard plots
    rendered to a single HTML string.

    The methods below read ``self.best_model``, ``self.plots`` and
    ``self.features_name``. This class does not set them; presumably
    ``BaseModelTrainer`` populates them during training (verify there).
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the trainer.

        All arguments are forwarded unchanged to ``BaseModelTrainer``; see
        it for their meaning. A fresh ``ClassificationExperiment`` is then
        stored as ``self.exp``.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Export the interactive explainer dashboard for the best model.

        The HTML is written to ``dashboard.html`` relative to the current
        working directory, not inside ``self.output_dir``.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def _ensure_predict_proba(self):
        """Bind ``utils.predict_proba`` onto the best model if it lacks one."""
        if hasattr(self.best_model, "predict_proba"):
            return
        import types
        self.best_model.predict_proba = types.MethodType(
            predict_proba, self.best_model)
        # Use implicit string concatenation. A backslash-continued f-string
        # would copy the source indentation into the log message.
        LOG.warning(
            f"The model {type(self.best_model).__name__} "
            "does not support `predict_proba`. Applying monkey patch.")

    def generate_plots(self):
        """Render PyCaret diagnostic plots for ``self.best_model``.

        If the model has no ``predict_proba``, a fallback is bound first.
        For each plot name, ``self.exp.plot_model(..., save=True)`` is
        called and its return value is stored in ``self.plots[plot_name]``.
        A plot that raises is logged and skipped, so one unsupported plot
        does not stop the rest.
        """
        LOG.info("Generating and saving plots")
        self._ensure_predict_proba()

        plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                 'error', 'class_report', 'learning', 'calibration',
                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
                 'feature_all']
        for plot_name in plots:
            try:
                extra = {}
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    # Non-multiclass target: request only the binary ROC
                    # curve, with micro/macro/per-class curves disabled.
                    extra['plot_kwargs'] = {
                        'micro': False,
                        'macro': False,
                        'per_class': False,
                        'binary': True,
                    }
                self.plots[plot_name] = self.exp.plot_model(
                    self.best_model, plot=plot_name, save=True, **extra)
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")

    def generate_plots_explainer(self):
        """Render explainerdashboard plots for the best model as HTML.

        A ``ClassifierExplainer`` is built on PyCaret's transformed test
        split and stored as ``self.explainer``. The old misspelled
        ``self.expaliner`` is also kept as an alias.

        Each plot that renders successfully is converted with
        ``add_plot_to_html`` and followed by ``add_hr_to_html``. The pieces
        are concatenated into ``self.plots_explainer_html``. Any plot that
        fails is logged and skipped. Per-feature plots are guarded per
        feature, so one bad feature does not drop the remaining ones.

        If the explainer itself cannot be built,
        ``self.plots_explainer_html`` is set to ``None``.
        """
        LOG.info("Generating and saving plots from explainer")

        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
        except Exception as e:
            LOG.error(f"Error creating explainer: {e}")
            self.plots_explainer_html = None
            return
        self.explainer = explainer
        # Misspelled alias kept so any existing readers of it still work.
        self.expaliner = explainer

        html_parts = []

        def render(label, method_name, *args, **kwargs):
            """Append HTML for explainer.<method_name>(...) or log failure."""
            try:
                fig = getattr(explainer, method_name)(*args, **kwargs)
                html_parts.append(add_plot_to_html(fig))
                html_parts.append(add_hr_to_html())
            except Exception as e:
                LOG.error(f"Error generating plot {label}: {e}")

        render("importance(mean shap)", "plot_importances")
        render("importance(permutation)", "plot_importances",
               kind="permutation")

        try:
            features = list(self.features_name)
        except (AttributeError, TypeError) as e:
            LOG.error(f"Error reading feature names for explainer plots: {e}")
            features = []

        for feature in features:
            render(f"pdp ({feature})", "plot_pdp", feature)

        for feature in features:
            # NOTE(review): col and interact_col are the same feature, so
            # this plots each feature's interaction with itself. Confirm
            # this is intended.
            render(f"interactions ({feature})", "plot_interaction",
                   col=feature, interact_col=feature)

        for feature in features:
            render(f"interactions importance ({feature})",
                   "plot_interactions_importance", col=feature)

        # Plots deliberately left out (previously commented-out code):
        # plot_shap_summary, plot_contributions(index=0), plot_dependence
        # and plot_interactions_detailed.

        for label, method_name in (
                ("precision", "plot_precision"),
                ("cumulative precision", "plot_cumulative_precision"),
                ("classification", "plot_classification"),
                ("confusion matrix", "plot_confusion_matrix"),
                ("lift curve", "plot_lift_curve"),
                ("roc auc", "plot_roc_auc"),
                ("pr auc", "plot_pr_auc")):
            render(label, method_name)

        self.plots_explainer_html = "".join(html_parts)