comparison pycaret_regression.py @ 0:1f20fe57fdee draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 04:59:43 +0000
parents
children 0314dad38aaa
comparison
equal deleted inserted replaced
-1:000000000000 0:1f20fe57fdee
1 import logging
2
3 from base_model_trainer import BaseModelTrainer
4
5 from dashboard import generate_regression_explainer_dashboard
6
7 from pycaret.regression import RegressionExperiment
8
9 from utils import add_hr_to_html, add_plot_to_html
10
11 LOG = logging.getLogger(__name__)
12
13
class RegressionModelTrainer(BaseModelTrainer):
    """Regression flavour of ``BaseModelTrainer`` backed by PyCaret.

    The constructor forwards every argument to ``BaseModelTrainer`` and
    installs a ``RegressionExperiment`` as ``self.exp``. The remaining
    methods produce the explainer dashboard and the plot artefacts for
    ``self.best_model``.
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the base trainer and create the PyCaret experiment.

        All parameters are passed through unchanged, and in the same
        order, to ``BaseModelTrainer.__init__``.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = RegressionExperiment()

    def save_dashboard(self):
        """Build the explainer dashboard and save it as ``dashboard.html``.

        The file name is relative, so the dashboard is written to the
        current working directory.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_regression_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Save the PyCaret diagnostic plots for ``self.best_model``.

        Every value returned by ``plot_model`` is stored in ``self.plots``
        under the plot's name. A plot that fails is logged and skipped so
        that the remaining plots are still produced.
        """
        LOG.info("Generating and saving plots")
        plot_names = (
            'residuals', 'error', 'cooks',
            'learning', 'vc', 'manifold',
            'rfe', 'feature', 'feature_all',
        )
        for plot_name in plot_names:
            try:
                plot_path = self.exp.plot_model(
                    self.best_model, plot=plot_name, save=True)
                self.plots[plot_name] = plot_path
            except Exception as e:
                LOG.error("Error generating plot %s: %s", plot_name, e)

    @staticmethod
    def _plot_html(explainer, method_name, description, *args, **kwargs):
        """Return the HTML for one explainer plot, or ``""`` on failure.

        ``method_name`` is looked up on ``explainer`` inside the ``try`` so
        that a plot method missing from the installed explainer is logged
        like any other plotting error instead of aborting the report.
        """
        try:
            figure = getattr(explainer, method_name)(*args, **kwargs)
            return add_plot_to_html(figure)
        except Exception as e:
            LOG.error("Error generating plot %s: %s", description, e)
            return ""

    def generate_plots_explainer(self):
        """Render the explainer plots into ``self.plots_explainer_html``.

        A ``RegressionExplainer`` is built from the experiment's
        transformed test split. Each plot is generated independently: a
        failure is logged and only that plot is left out. Every group of
        plots that produced output is followed by a horizontal rule.
        """
        LOG.info("Generating and saving plots from explainer")

        # Imported lazily, as in the original code, so the module can be
        # imported without loading explainerdashboard.
        from explainerdashboard import RegressionExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        explainer = RegressionExplainer(self.best_model, X_test, y_test)
        # ``expaliner`` is the historical (misspelled) attribute name; it is
        # kept for existing readers and mirrored under the correct spelling.
        self.expaliner = explainer
        self.explainer = explainer

        features = getattr(self, "features_name", None)
        if features is None:
            LOG.warning(
                "No feature names available; skipping per-feature plots")
            features = []
        features = list(features)

        render = self._plot_html
        # One inner list per report section, evaluated in display order.
        # NOTE: per-feature ``explainer.plot_interaction`` plots were
        # disabled (commented out) in the original code and stay omitted.
        sections = [
            [render(explainer, "plot_importances", "importance")],
            [render(explainer, "plot_importances",
                    "importance permutation", kind="permutation")],
            [render(explainer, "plot_pdp",
                    f"partial dependence for {feature}", feature)
             for feature in features],
            [render(explainer, "plot_interactions_importance",
                    f"interactions importance for {feature}", col=feature)
             for feature in features],
            # Regression specific plots
            [render(explainer, "plot_predicted_vs_actual",
                    "prediction vs actual")],
            [render(explainer, "plot_residuals", "residuals")],
            [render(explainer, "plot_residuals_vs_feature",
                    f"residuals vs feature for {feature}", feature)
             for feature in features],
        ]

        html_parts = []
        for fragments in sections:
            section_html = "".join(fragments)
            # A section whose plots all failed contributes nothing, so no
            # dangling separator is emitted for it.
            if section_html:
                html_parts.append(section_html + add_hr_to_html())

        self.plots_explainer_html = "".join(html_parts)