Mercurial > repos > goeckslab > tabular_learner
comparison pycaret_classification.py @ 0:209b663a4f62 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
| author | goeckslab |
|---|---|
| date | Wed, 18 Jun 2025 15:38:19 +0000 |
| parents | |
| children | 11fdac5affb3 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:209b663a4f62 |
|---|---|
| 1 import logging | |
| 2 | |
| 3 from base_model_trainer import BaseModelTrainer | |
| 4 from dashboard import generate_classifier_explainer_dashboard | |
| 5 from pycaret.classification import ClassificationExperiment | |
| 6 from utils import add_hr_to_html, add_plot_to_html, predict_proba | |
| 7 | |
| 8 LOG = logging.getLogger(__name__) | |
| 9 | |
| 10 | |
class ClassificationModelTrainer(BaseModelTrainer):
    """Trainer for classification tasks built on PyCaret.

    Extends ``BaseModelTrainer`` with a ``ClassificationExperiment`` and
    with classification-specific plots, explainer plots and dashboard.
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the base trainer and create the experiment.

        All arguments are forwarded unchanged to
        ``BaseModelTrainer.__init__``; ``self.exp`` is then set to a new
        ``ClassificationExperiment``.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Build the classifier explainer dashboard for ``self.best_model``
        and save it as ``dashboard.html``.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Generate the PyCaret plots for ``self.best_model``.

        Each plot is requested through ``self.exp.plot_model`` with
        ``save=True`` and the returned value is stored in
        ``self.plots[plot_name]``. A plot that fails is logged and skipped
        so the remaining plots are still attempted.
        """
        LOG.info("Generating and saving plots")

        if not hasattr(self.best_model, "predict_proba"):
            # Bind the fallback ``predict_proba`` from ``utils`` to this
            # model instance so callers that expect the method find it.
            import types
            self.best_model.predict_proba = types.MethodType(
                predict_proba, self.best_model)
            # Implicit concatenation (rather than backslash continuations
            # inside the literal) keeps source indentation out of the
            # logged message.
            LOG.warning(
                f"The model {type(self.best_model).__name__} "
                "does not support `predict_proba`. "
                "Applying monkey patch.")

        plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                 'error', 'class_report', 'learning', 'calibration',
                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
                 'feature_all']
        for plot_name in plots:
            try:
                extra = {}
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    # Non-multiclass AUC plot: turn off the micro, macro
                    # and per-class curves and request the binary curve.
                    extra['plot_kwargs'] = {
                        'micro': False,
                        'macro': False,
                        'per_class': False,
                        'binary': True
                    }
                self.plots[plot_name] = self.exp.plot_model(
                    self.best_model, plot=plot_name, save=True, **extra)
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")

    @staticmethod
    def _append_explainer_plots(html, description, make_figures,
                                add_hr=True):
        """Append one group of explainer figures to ``html``.

        ``make_figures`` is called with no arguments and must return an
        iterable of figures. Each figure is converted with
        ``add_plot_to_html`` and, when ``add_hr`` is true, followed by
        ``add_hr_to_html()``. The first exception ends the group and is
        logged using ``description``; whatever was appended before the
        failure is kept. Returns the (possibly extended) HTML string.
        """
        try:
            for figure in make_figures():
                html += add_plot_to_html(figure)
                if add_hr:
                    html += add_hr_to_html()
        except Exception as e:
            LOG.error(f"Error generating plot {description}: {e}")
        return html

    def generate_plots_explainer(self):
        """Generate explainer plots and store them as one HTML string.

        Builds a ``ClassifierExplainer`` from ``self.best_model`` and the
        experiment's transformed test data, then renders a fixed sequence
        of plots. The concatenated HTML is stored in
        ``self.plots_explainer_html``; it is set to ``None`` when the
        explainer itself cannot be created. Individual plot groups that
        fail are logged and skipped.
        """
        LOG.info("Generating and saving plots from explainer")

        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
            # NOTE(review): attribute name looks like a typo of
            # "explainer"; kept as-is because code outside this class may
            # read it under this spelling - confirm before renaming.
            self.expaliner = explainer
        except Exception as e:
            LOG.error(f"Error creating explainer: {e}")
            self.plots_explainer_html = None
            return

        # Ordered (description, figure factory, add separator) specs.
        # Factories are lambdas so that every attribute lookup and plot
        # call happens inside the helper's try/except. Per-feature groups
        # use generators, so a failing feature ends only its own group.
        #
        # Explainer plots currently disabled: shap summary, contributions
        # (index=0), per-feature dependence, per-feature detailed
        # interactions.
        plot_specs = (
            ("importance(mean shap)",
             lambda: (explainer.plot_importances(),),
             True),
            ("importance(permutation)",
             lambda: (explainer.plot_importances(kind="permutation"),),
             True),
            ("pdp",
             lambda: (explainer.plot_pdp(feature)
                      for feature in self.features_name),
             True),
            # NOTE(review): each feature is paired with itself as the
            # interaction column, and no separator follows these plots -
            # both kept as in the existing behaviour; confirm intent.
            ("interactions",
             lambda: (explainer.plot_interaction(
                          col=feature, interact_col=feature)
                      for feature in self.features_name),
             False),
            ("interactions importance",
             lambda: (explainer.plot_interactions_importance(col=feature)
                      for feature in self.features_name),
             True),
            ("precision",
             lambda: (explainer.plot_precision(),),
             True),
            ("cumulative precision",
             lambda: (explainer.plot_cumulative_precision(),),
             True),
            ("classification",
             lambda: (explainer.plot_classification(),),
             True),
            ("confusion matrix",
             lambda: (explainer.plot_confusion_matrix(),),
             True),
            ("lift curve",
             lambda: (explainer.plot_lift_curve(),),
             True),
            ("roc auc",
             lambda: (explainer.plot_roc_auc(),),
             True),
            ("pr auc",
             lambda: (explainer.plot_pr_auc(),),
             True),
        )

        plots_explainer_html = ""
        for description, make_figures, add_hr in plot_specs:
            plots_explainer_html = self._append_explainer_plots(
                plots_explainer_html, description, make_figures, add_hr)

        self.plots_explainer_html = plots_explainer_html
