Mercurial > repos > goeckslab > pycaret_predict
diff pycaret_regression.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | ccd798db5abb |
children |
line wrap: on
line diff
--- a/pycaret_regression.py Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_regression.py Fri Jul 25 19:02:32 2025 +0000
@@ -3,21 +3,21 @@
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_regression_explainer_dashboard
 from pycaret.regression import RegressionExperiment
-from utils import add_hr_to_html, add_plot_to_html
 
 LOG = logging.getLogger(__name__)
 
 
 class RegressionModelTrainer(BaseModelTrainer):
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         super().__init__(
             input_file,
             target_col,
@@ -25,24 +25,35 @@
             task_type,
             random_seed,
             test_file,
-            **kwargs)
+            **kwargs,
+        )
+        # The BaseModelTrainer.setup_pycaret will set self.exp appropriately
+        # But we reassign here for clarity
         self.exp = RegressionExperiment()
 
     def save_dashboard(self):
         LOG.info("Saving explainer dashboard")
-        dashboard = generate_regression_explainer_dashboard(self.exp,
-                                                            self.best_model)
+        dashboard = generate_regression_explainer_dashboard(self.exp, self.best_model)
         dashboard.save_html("dashboard.html")
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
-        plots = ['residuals', 'error', 'cooks',
-                 'learning', 'vc', 'manifold',
-                 'rfe', 'feature', 'feature_all']
+        plots = [
+            "residuals",
+            "error",
+            "cooks",
+            "learning",
+            "vc",
+            "manifold",
+            "rfe",
+            "feature",
+            "feature_all",
+        ]
         for plot_name in plots:
             try:
-                plot_path = self.exp.plot_model(self.best_model,
-                                                plot=plot_name, save=True)
+                plot_path = self.exp.plot_model(
+                    self.best_model, plot=plot_name, save=True
+                )
                 self.plots[plot_name] = plot_path
             except Exception as e:
                 LOG.error(f"Error generating plot {plot_name}: {e}")
@@ -58,79 +69,60 @@
 
         try:
             explainer = RegressionExplainer(self.best_model, X_test, y_test)
-            self.expaliner = explainer
-            plots_explainer_html = ""
         except Exception as e:
             LOG.error(f"Error creating explainer: {e}")
-            self.plots_explainer_html = None
             return
 
+        # --- 1) SHAP mean impact (average absolute SHAP values) ---
         try:
-            fig_importance = explainer.plot_importances()
-            plots_explainer_html += add_plot_to_html(fig_importance)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance: {e}")
-
-        try:
-            fig_importance_permutation = \
-                explainer.plot_importances_permutation(
-                    kind="permutation")
-            plots_explainer_html += add_plot_to_html(
-                fig_importance_permutation)
-            plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_mean"] = explainer.plot_importances()
         except Exception as e:
-            LOG.error(f"Error generating plot importance permutation: {e}")
+            LOG.error(f"Error generating SHAP mean importance: {e}")
 
-        try:
-            for feature in self.features_name:
-                fig_shap = explainer.plot_pdp(feature)
-                plots_explainer_html += add_plot_to_html(fig_shap)
-                plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot shap dependence: {e}")
-
-        # try:
-        #     for feature in self.features_name:
-        #         fig_interaction = explainer.plot_interaction(col=feature)
-        #         plots_explainer_html += add_plot_to_html(fig_interaction)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot shap interaction: {e}")
-
+        # --- 2) SHAP permutation importance ---
         try:
-            for feature in self.features_name:
-                fig_interactions_importance = \
-                    explainer.plot_interactions_importance(
-                        col=feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_interactions_importance)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_perm"] = explainer.plot_importances_permutation(
+                kind="permutation"
+            )
         except Exception as e:
-            LOG.error(f"Error generating plot shap summary: {e}")
+            LOG.error(f"Error generating SHAP permutation importance: {e}")
 
-        # Regression specific plots
-        try:
-            fig_pred_actual = explainer.plot_predicted_vs_actual()
-            plots_explainer_html += add_plot_to_html(fig_pred_actual)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot prediction vs actual: {e}")
+        # Pre-filter features so we never call PDP or residual-vs-feature on missing cols
+        valid_feats = []
+        for feat in self.features_name:
+            if feat in explainer.X.columns or feat in explainer.onehot_cols:
+                valid_feats.append(feat)
+            else:
+                LOG.warning(f"Skipping feature {feat!r}: not found in explainer data")
 
-        try:
-            fig_residuals = explainer.plot_residuals()
-            plots_explainer_html += add_plot_to_html(fig_residuals)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot residuals: {e}")
+        # --- 3) Partial Dependence Plots (PDPs) per feature ---
+        for feature in valid_feats:
+            try:
+                fig_pdp = explainer.plot_pdp(feature)
+                self.explainer_plots[f"pdp__{feature}"] = fig_pdp
+            except AssertionError as ae:
+                LOG.warning(f"PDP AssertionError for {feature!r}: {ae}")
+            except Exception as e:
+                LOG.error(f"Error generating PDP for {feature}: {e}")
 
+        # --- 4) Predicted vs Actual plot ---
         try:
-            for feature in self.features_name:
-                fig_residuals_vs_feature = \
-                    explainer.plot_residuals_vs_feature(feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_residuals_vs_feature)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["predicted_vs_actual"] = explainer.plot_predicted_vs_actual()
+        except Exception as e:
+            LOG.error(f"Error generating Predicted vs Actual plot: {e}")
+
+        # --- 5) Global residuals distribution ---
+        try:
+            self.explainer_plots["residuals"] = explainer.plot_residuals()
         except Exception as e:
-            LOG.error(f"Error generating plot residuals vs feature: {e}")
+            LOG.error(f"Error generating Residuals plot: {e}")
 
-        self.plots_explainer_html = plots_explainer_html
+        # --- 6) Residuals vs each feature ---
+        for feature in valid_feats:
+            try:
+                fig_res_vs_feat = explainer.plot_residuals_vs_feature(feature)
+                self.explainer_plots[f"residuals_vs_feature__{feature}"] = fig_res_vs_feat
+            except AssertionError as ae:
+                LOG.warning(f"Residuals-vs-feature AssertionError for {feature!r}: {ae}")
+            except Exception as e:
+                LOG.error(f"Error generating Residuals vs {feature}: {e}")