Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_regression.py @ 0:1f20fe57fdee draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
| author | goeckslab |
|---|---|
| date | Wed, 11 Dec 2024 04:59:43 +0000 |
| parents | |
| children | 0314dad38aaa |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:1f20fe57fdee |
|---|---|
| 1 import logging | |
| 2 | |
| 3 from base_model_trainer import BaseModelTrainer | |
| 4 | |
| 5 from dashboard import generate_regression_explainer_dashboard | |
| 6 | |
| 7 from pycaret.regression import RegressionExperiment | |
| 8 | |
| 9 from utils import add_hr_to_html, add_plot_to_html | |
| 10 | |
| 11 LOG = logging.getLogger(__name__) | |
| 12 | |
| 13 | |
class RegressionModelTrainer(BaseModelTrainer):
    """Trainer that runs the shared training workflow as a regression task.

    The constructor forwards all arguments to ``BaseModelTrainer`` and then
    installs a PyCaret ``RegressionExperiment`` as ``self.exp``. The methods
    below produce the regression-specific dashboard and plots.
    """

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the base trainer and create the regression experiment.

        All parameters are passed through unchanged to
        ``BaseModelTrainer.__init__``.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = RegressionExperiment()

    def save_dashboard(self):
        """Build the explainer dashboard for the best model and save it.

        The dashboard is written to ``dashboard.html`` (a relative path).
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_regression_explainer_dashboard(self.exp,
                                                            self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Render PyCaret diagnostic plots for the best model.

        Each successfully saved plot's path is stored in ``self.plots`` under
        the plot name. A plot that fails is logged and skipped so the
        remaining plots are still attempted.
        """
        LOG.info("Generating and saving plots")
        plots = ['residuals', 'error', 'cooks',
                 'learning', 'vc', 'manifold',
                 'rfe', 'feature', 'feature_all']
        for plot_name in plots:
            try:
                plot_path = self.exp.plot_model(self.best_model,
                                                plot=plot_name, save=True)
                self.plots[plot_name] = plot_path
            except Exception as e:
                # Best effort: not every estimator supports every plot type.
                LOG.error(f"Error generating plot {plot_name}: {e}")

    def generate_plots_explainer(self):
        """Render explainerdashboard figures into one HTML fragment.

        A ``RegressionExplainer`` is built from the best model and the
        experiment's transformed test split. Every plot section is attempted
        independently; a failing section is logged and omitted. The
        concatenated HTML is stored in ``self.plots_explainer_html``.
        """
        LOG.info("Generating and saving plots from explainer")

        from explainerdashboard import RegressionExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        explainer = RegressionExplainer(self.best_model, X_test, y_test)
        # NOTE(review): "expaliner" is a historical misspelling. It is kept
        # because code outside this view may read it; the correctly spelled
        # alias is provided alongside it.
        self.expaliner = explainer
        self.explainer = explainer

        html_parts = []

        def add_section(description, make_figures):
            """Append one plot section; log and drop it on any failure.

            ``make_figures`` is called inside the ``try`` and must return an
            iterable of figures. Figures rendered before a failure stay in
            the output; the trailing rule is only added when the whole
            section succeeds.
            """
            # NOTE(review): the rule is emitted once per section, after all
            # figures of that section - confirm against the upstream layout.
            try:
                for figure in make_figures():
                    html_parts.append(add_plot_to_html(figure))
                html_parts.append(add_hr_to_html())
            except Exception as e:
                LOG.error(f"Error generating plot {description}: {e}")

        add_section(
            "importance",
            lambda: [explainer.plot_importances()])

        # explainerdashboard selects permutation importances through the
        # ``kind`` argument of ``plot_importances``; there is no separate
        # ``plot_importances_permutation`` method, so the previous call
        # always failed and this section was silently missing.
        add_section(
            "importance permutation",
            lambda: [explainer.plot_importances(kind="permutation")])

        add_section(
            "pdp",
            lambda: (explainer.plot_pdp(feature)
                     for feature in self.features_name))

        # Per-feature ``explainer.plot_interaction(col=feature)`` plots were
        # disabled (commented out) in the original code and remain omitted.

        add_section(
            "interactions importance",
            lambda: (explainer.plot_interactions_importance(col=feature)
                     for feature in self.features_name))

        # Regression specific plots
        add_section(
            "prediction vs actual",
            lambda: [explainer.plot_predicted_vs_actual()])

        add_section(
            "residuals",
            lambda: [explainer.plot_residuals()])

        add_section(
            "residuals vs feature",
            lambda: (explainer.plot_residuals_vs_feature(feature)
                     for feature in self.features_name))

        self.plots_explainer_html = "".join(html_parts)
