Mercurial > repos > goeckslab > tabular_learner
comparison pycaret_regression.py @ 0:209b663a4f62 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author | goeckslab |
---|---|
date | Wed, 18 Jun 2025 15:38:19 +0000 |
parents | |
children | 11fdac5affb3 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:209b663a4f62 |
---|---|
1 import logging | |
2 | |
3 from base_model_trainer import BaseModelTrainer | |
4 from dashboard import generate_regression_explainer_dashboard | |
5 from pycaret.regression import RegressionExperiment | |
6 from utils import add_hr_to_html, add_plot_to_html | |
7 | |
8 LOG = logging.getLogger(__name__) | |
9 | |
10 | |
class RegressionModelTrainer(BaseModelTrainer):
    """Regression-specific trainer built on a PyCaret ``RegressionExperiment``.

    Adds PyCaret regression plots, explainerdashboard plots and an explainer
    dashboard on top of ``BaseModelTrainer``. Attributes such as
    ``best_model``, ``plots`` and ``features_name`` are read here but not set
    here; presumably ``BaseModelTrainer`` provides them — TODO confirm.
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Forward all arguments to ``BaseModelTrainer`` and create the
        PyCaret regression experiment stored in ``self.exp``."""
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = RegressionExperiment()

    def save_dashboard(self):
        """Build the regression explainer dashboard and save it as HTML.

        NOTE(review): the dashboard is written to the relative path
        ``dashboard.html`` (the current working directory), not to
        ``output_dir`` — confirm this is intended.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_regression_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Save PyCaret regression plots for ``self.best_model``.

        For each plot name, the value returned by
        ``self.exp.plot_model(..., save=True)`` is stored in ``self.plots``
        under that name. A plot that raises is logged and skipped so the
        remaining plots are still attempted.
        """
        LOG.info("Generating and saving plots")
        plots = ['residuals', 'error', 'cooks',
                 'learning', 'vc', 'manifold',
                 'rfe', 'feature', 'feature_all']
        for plot_name in plots:
            try:
                plot_path = self.exp.plot_model(self.best_model,
                                                plot=plot_name, save=True)
                self.plots[plot_name] = plot_path
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")

    @staticmethod
    def _render_explainer_plot(description, plot_fn, *args, **kwargs):
        """Return the HTML for one explainer figure followed by a rule.

        Calls ``plot_fn(*args, **kwargs)`` and renders the figure with
        ``add_plot_to_html`` + ``add_hr_to_html``. Any failure is logged
        using ``description`` and an empty string is returned, so one broken
        plot does not prevent the others from being rendered.
        """
        try:
            fig = plot_fn(*args, **kwargs)
            return add_plot_to_html(fig) + add_hr_to_html()
        except Exception as e:
            LOG.error(f"Error generating plot {description}: {e}")
            return ""

    def generate_plots_explainer(self):
        """Render explainerdashboard plots for the test split into HTML.

        Builds a ``RegressionExplainer`` on PyCaret's transformed test data
        (``X_test_transformed`` / ``y_test_transformed``) and stores the
        concatenated plot HTML in ``self.plots_explainer_html``. If the
        explainer cannot be created, ``self.plots_explainer_html`` is set to
        ``None``. Failures of individual plots — including a single feature
        in the per-feature plots — are logged and skipped.
        """
        LOG.info("Generating and saving plots from explainer")

        from explainerdashboard import RegressionExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            explainer = RegressionExplainer(self.best_model, X_test, y_test)
        except Exception as e:
            LOG.error(f"Error creating explainer: {e}")
            self.plots_explainer_html = None
            return

        self.explainer = explainer
        # Alias under the previously misspelled name for existing readers.
        self.expaliner = explainer

        render = self._render_explainer_plot
        parts = [
            render("importance", explainer.plot_importances),
            # Permutation importances are selected via the ``kind`` argument
            # of plot_importances; there is no separate permutation method.
            render("importance permutation",
                   explainer.plot_importances, kind="permutation"),
        ]

        # Per-feature plots are rendered one feature at a time so a feature
        # that cannot be plotted (e.g. a name absent from the transformed
        # columns) only drops its own plot, not all remaining ones.
        parts.extend(
            render(f"partial dependence for feature {feature}",
                   explainer.plot_pdp, feature)
            for feature in self.features_name)
        parts.extend(
            render(f"interactions importance for feature {feature}",
                   explainer.plot_interactions_importance, col=feature)
            for feature in self.features_name)
        # Pairwise SHAP interaction plots (plot_interaction) are not generated.

        # Regression specific plots
        parts.append(render("prediction vs actual",
                            explainer.plot_predicted_vs_actual))
        parts.append(render("residuals", explainer.plot_residuals))
        parts.extend(
            render(f"residuals vs feature {feature}",
                   explainer.plot_residuals_vs_feature, feature)
            for feature in self.features_name)

        self.plots_explainer_html = "".join(parts)