comparison pycaret_regression.py @ 0:209b663a4f62 draft

planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author goeckslab
date Wed, 18 Jun 2025 15:38:19 +0000
parents
children 11fdac5affb3
comparison
equal deleted inserted replaced
-1:000000000000 0:209b663a4f62
1 import logging
2
3 from base_model_trainer import BaseModelTrainer
4 from dashboard import generate_regression_explainer_dashboard
5 from pycaret.regression import RegressionExperiment
6 from utils import add_hr_to_html, add_plot_to_html
7
8 LOG = logging.getLogger(__name__)
9
10
class RegressionModelTrainer(BaseModelTrainer):
    """Regression flavour of ``BaseModelTrainer`` backed by PyCaret.

    The constructor forwards its arguments to ``BaseModelTrainer`` and
    installs a ``RegressionExperiment`` as ``self.exp``.  The methods below
    read ``self.best_model``, ``self.plots`` and ``self.features_name``,
    which are not set in this class (presumably maintained by
    ``BaseModelTrainer`` -- confirm against the base class).
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Forward all arguments to the base trainer and create the experiment.

        Every argument is passed through unchanged (and in the same order)
        to ``BaseModelTrainer.__init__``; ``test_file`` defaults to ``None``.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = RegressionExperiment()

    def save_dashboard(self):
        """Build the regression explainer dashboard and save it as HTML.

        The output file name is the relative path ``dashboard.html``, so it
        is resolved against the current working directory.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_regression_explainer_dashboard(self.exp,
                                                            self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Ask PyCaret to render and save each diagnostic plot.

        The value returned by ``plot_model`` is stored in ``self.plots``
        under the plot name.  A plot that fails is logged and skipped so
        the remaining plots are still attempted.
        """
        LOG.info("Generating and saving plots")
        plots = ['residuals', 'error', 'cooks',
                 'learning', 'vc', 'manifold',
                 'rfe', 'feature', 'feature_all']
        for plot_name in plots:
            try:
                plot_path = self.exp.plot_model(self.best_model,
                                                plot=plot_name, save=True)
                self.plots[plot_name] = plot_path
            except Exception as e:
                # Best effort: one failing plot must not abort the others.
                LOG.error("Error generating plot %s: %s", plot_name, e)

    def _append_explainer_plots(self, html_parts, description, plot_func,
                                per_feature=False):
        """Render one family of explainer plots and append their HTML.

        Args:
            html_parts: list of HTML fragments, extended in place.
            description: plot family name used in the error log message.
            plot_func: callable returning a figure.  It is called with no
                argument, or once per entry of ``self.features_name`` (with
                the feature as its only argument) when ``per_feature`` is
                true.
            per_feature: whether to render one figure per feature.

        Any exception is logged and swallowed.  Fragments appended before
        the failure are kept; the rest of this plot family is skipped.
        """
        try:
            if per_feature:
                # ``self.features_name`` is read inside the ``try`` so that a
                # missing attribute is logged like any other plotting error.
                figures = (plot_func(feature)
                           for feature in self.features_name)
            else:
                figures = (plot_func() for _ in (None,))
            for figure in figures:
                html_parts.append(add_plot_to_html(figure))
                html_parts.append(add_hr_to_html())
        except Exception as e:
            LOG.error("Error generating plot %s: %s", description, e)

    def generate_plots_explainer(self):
        """Render explainerdashboard plots into ``self.plots_explainer_html``.

        A ``RegressionExplainer`` is built from the best model and the
        experiment's transformed test split.  If that fails,
        ``self.plots_explainer_html`` is set to ``None``.  Otherwise it is
        set to the concatenated HTML of every plot that could be rendered,
        each followed by the separator returned by ``add_hr_to_html``;
        plot families that fail are logged and left out.
        """
        LOG.info("Generating and saving plots from explainer")

        # Imported lazily, as in the original implementation.
        from explainerdashboard import RegressionExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        try:
            explainer = RegressionExplainer(self.best_model, X_test, y_test)
        except Exception as e:
            LOG.error("Error creating explainer: %s", e)
            self.plots_explainer_html = None
            return

        self.explainer = explainer
        # Historical misspelling, kept because code outside this class may
        # still read it.
        self.expaliner = explainer

        html_parts = []

        self._append_explainer_plots(
            html_parts, "importance", explainer.plot_importances)

        # Permutation importances are selected through the ``kind`` argument
        # of ``plot_importances``.
        self._append_explainer_plots(
            html_parts, "importance permutation",
            lambda: explainer.plot_importances(kind="permutation"))

        self._append_explainer_plots(
            html_parts, "pdp", explainer.plot_pdp, per_feature=True)

        # NOTE: per-feature ``explainer.plot_interaction(col=feature)`` plots
        # were commented out in the original implementation and are not
        # rendered here either.

        self._append_explainer_plots(
            html_parts, "interactions importance",
            lambda feature: explainer.plot_interactions_importance(
                col=feature),
            per_feature=True)

        # Regression specific plots
        self._append_explainer_plots(
            html_parts, "prediction vs actual",
            explainer.plot_predicted_vs_actual)

        self._append_explainer_plots(
            html_parts, "residuals", explainer.plot_residuals)

        self._append_explainer_plots(
            html_parts, "residuals vs feature",
            explainer.plot_residuals_vs_feature, per_feature=True)

        self.plots_explainer_html = "".join(html_parts)