Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_classification.py @ 0:1f20fe57fdee draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author | goeckslab |
---|---|
date | Wed, 11 Dec 2024 04:59:43 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:1f20fe57fdee |
---|---|
1 import logging | |
2 | |
3 from base_model_trainer import BaseModelTrainer | |
4 | |
5 from dashboard import generate_classifier_explainer_dashboard | |
6 | |
7 from pycaret.classification import ClassificationExperiment | |
8 | |
9 from utils import add_hr_to_html, add_plot_to_html | |
10 | |
LOG = logging.getLogger(__name__)


class ClassificationModelTrainer(BaseModelTrainer):
    """Trainer for classification tasks backed by PyCaret.

    Wraps a ``pycaret.classification.ClassificationExperiment`` (stored in
    ``self.exp``) and adds classification-specific reporting:

    * PyCaret diagnostic plots (``generate_plots``),
    * explainerdashboard plots rendered to an HTML fragment
      (``generate_plots_explainer``),
    * a standalone explainer dashboard HTML file (``save_dashboard``).

    Attributes such as ``best_model``, ``plots`` and ``features_name`` are
    read here but are not set in this class. Presumably they are provided by
    ``BaseModelTrainer`` or the training flow; verify against that class.
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the base trainer and create a classification experiment.

        All arguments are forwarded unchanged to ``BaseModelTrainer``; see
        that class for their meaning.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Write an explainer dashboard for ``self.best_model`` to HTML.

        The file is written to the relative path ``dashboard.html``, i.e.
        into the current working directory.
        """
        LOG.info("Saving explainer dashboard")
        # NOTE(review): ``output_dir`` given to __init__ is not used here;
        # confirm callers rely on the file landing in the working directory.
        dashboard = generate_classifier_explainer_dashboard(self.exp,
                                                            self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Render PyCaret diagnostic plots for ``self.best_model``.

        Each plot is produced with ``plot_model(..., save=True)``. The value
        returned by that call is stored in ``self.plots`` under the plot's
        name. A plot that fails (for example, one the model type does not
        support) is logged and skipped, so the remaining plots are still
        generated.
        """
        LOG.info("Generating and saving plots")
        plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                 'error', 'class_report', 'learning', 'calibration',
                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
                 'feature_all']
        for plot_name in plots:
            try:
                call_kwargs = {'plot': plot_name, 'save': True}
                # For non-multiclass problems, switch the AUC plot to its
                # binary form by disabling micro/macro/per-class curves.
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    call_kwargs['plot_kwargs'] = {
                        'micro': False,
                        'macro': False,
                        'per_class': False,
                        'binary': True,
                    }
                plot_path = self.exp.plot_model(self.best_model,
                                                **call_kwargs)
                self.plots[plot_name] = plot_path
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")

    @staticmethod
    def _render_explainer_plot(explainer, label, method_name, *args,
                               separator=True, **kwargs):
        """Return the HTML for one explainer plot, or ``''`` if it fails.

        Calls ``explainer.<method_name>(*args, **kwargs)`` and converts the
        figure with ``add_plot_to_html``. When ``separator`` is true, it also
        appends ``add_hr_to_html()``. Any exception, including a missing
        method, is logged under ``label`` and swallowed. This keeps the rest
        of the report intact.
        """
        try:
            fig = getattr(explainer, method_name)(*args, **kwargs)
            html = add_plot_to_html(fig)
            if separator:
                html += add_hr_to_html()
        except Exception as e:
            LOG.error(f"Error generating plot {label}: {e}")
            return ''
        return html

    def generate_plots_explainer(self):
        """Render explainerdashboard plots into ``self.plots_explainer_html``.

        Builds a ``ClassifierExplainer`` for ``self.best_model`` on the
        experiment's transformed test split and stores it on
        ``self.explainer``. It then concatenates the HTML of every plot that
        renders successfully. Each plot is attempted independently, and so
        is each feature of the per-feature plots. A failure is logged and
        only that plot is skipped.
        """
        LOG.info("Generating and saving plots from explainer")

        # Imported inside the method, as in the original code; presumably
        # so the dependency is only loaded when explainer plots are built.
        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
        self.explainer = explainer
        # Historical misspelled attribute name, kept as an alias so any code
        # that still reads it continues to work.
        self.expaliner = explainer

        features = getattr(self, 'features_name', None)
        if features is None:
            LOG.error("Error generating per-feature explainer plots: "
                      "no feature names available")
            features = []

        render = self._render_explainer_plot
        parts = [
            render(explainer, 'importance(mean shap)', 'plot_importances'),
            render(explainer, 'importance(permutation)', 'plot_importances',
                   kind='permutation'),
        ]
        parts += [render(explainer, f'pdp ({feature})', 'plot_pdp', feature)
                  for feature in features]
        # NOTE(review): each feature is paired with itself
        # (col == interact_col), as in the original; confirm this is
        # intended. These plots are not followed by a separator, also as
        # in the original.
        parts += [render(explainer, f'interactions ({feature})',
                         'plot_interaction', col=feature,
                         interact_col=feature, separator=False)
                  for feature in features]
        parts += [render(explainer, f'interactions importance ({feature})',
                         'plot_interactions_importance', col=feature)
                  for feature in features]

        # Model-level performance plots, rendered in this order.
        summary_plots = [
            ('precision', 'plot_precision'),
            ('cumulative precision', 'plot_cumulative_precision'),
            ('classification', 'plot_classification'),
            ('confusion matrix', 'plot_confusion_matrix'),
            ('lift curve', 'plot_lift_curve'),
            ('roc auc', 'plot_roc_auc'),
            ('pr auc', 'plot_pr_auc'),
        ]
        parts += [render(explainer, label, method_name)
                  for label, method_name in summary_plots]

        self.plots_explainer_html = ''.join(parts)