# pycaret_classification.py @ changeset 12:e674b9e946fb (draft, default, tip)
# Repository: goeckslab/pycaret_predict (Mercurial)
# Commit: planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
# Author: goeckslab | Date: Mon, 08 Sep 2025 22:39:12 +0000 | Parent: 1aed7d47c5ec
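"""Classification model trainer for the pycaret_predict Galaxy tool.

This summary is added for orientation and reflects what the code below does:
it wraps a PyCaret ClassificationExperiment, saves the standard PyCaret
diagnostic plots, builds threshold-aware confusion-matrix/ROC/PR figures for
binary tasks, collects SHAP and permutation importances plus per-feature
PDPs, and writes an explainerdashboard HTML report.
"""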
import logging
import types
from typing import Dict

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from base_model_trainer import BaseModelTrainer
from dashboard import generate_classifier_explainer_dashboard
from pycaret.classification import ClassificationExperiment
from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
from utils import predict_proba

LOG = logging.getLogger(__name__)


def _apply_report_layout(fig: go.Figure) -> go.Figure:
    # Give the left side more space for the y-axis title/ticks and let the
    # axes auto-reserve room.
    fig.update_xaxes(automargin=True, title_standoff=12)
    fig.update_yaxes(automargin=True, title_standoff=12)
    fig.update_layout(
        autosize=True,
        margin=dict(l=120, r=40, t=60, b=60),  # bump 'l' if you still see clipping
    )
    return fig
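# A minimal usage sketch for the layout helper (illustrative only; the trace
# below is synthetic and not part of the training pipeline):
#
#     fig = go.Figure(go.Scatter(x=[0.0, 0.5, 1.0], y=[0.0, 0.7, 1.0]))
#     fig = _apply_report_layout(fig)  # automargin plus fixed report margins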
class ClassificationModelTrainer(BaseModelTrainer):
    def __init__(
        self,
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file=None,
        **kwargs,
    ):
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs,
        )
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        LOG.info("Generating and saving plots")

        if not hasattr(self.best_model, "predict_proba"):
            self.best_model.predict_proba = types.MethodType(
                predict_proba, self.best_model
            )
            LOG.warning(
                f"The model {type(self.best_model).__name__} does not support "
                "`predict_proba`. Applying monkey patch."
            )

        plots = [
            "auc",
            "threshold",
            "pr",
            "error",
            "class_report",
            "learning",
            "calibration",
            "vc",
            "dimension",
            "manifold",
            "rfe",
            "feature",
            "feature_all",
        ]
        for plot_name in plots:
            try:
                if plot_name == "threshold":
                    plot_path = self.exp.plot_model(
                        self.best_model,
                        plot=plot_name,
                        save=True,
                        plot_kwargs={"binary": True, "percentage": True},
                    )
                elif plot_name == "auc" and not self.exp.is_multiclass:
                    plot_path = self.exp.plot_model(
                        self.best_model,
                        plot=plot_name,
                        save=True,
                        plot_kwargs={
                            "micro": False,
                            "macro": False,
                            "per_class": False,
                            "binary": True,
                        },
                    )
                else:
                    plot_path = self.exp.plot_model(
                        self.best_model, plot=plot_name, save=True
                    )
                self.plots[plot_name] = plot_path
            except Exception as e:
                LOG.error(f"Error generating plot {plot_name}: {e}")
                continue

    def generate_plots_explainer(self):
        from explainerdashboard import ClassifierExplainer

        LOG.info("Generating explainer plots")

        # Ensure predict_proba is available here too.
        if not hasattr(self.best_model, "predict_proba"):
            self.best_model.predict_proba = types.MethodType(
                predict_proba, self.best_model
            )
            LOG.warning(
                f"The model {type(self.best_model).__name__} does not support "
                "`predict_proba`. Applying monkey patch."
            )

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed
        explainer = ClassifierExplainer(self.best_model, X_test, y_test)

        # A dict to hold the raw Figure objects or callables.
        self.explainer_plots: Dict[str, go.Figure] = {}

        # --- Threshold-aware overrides for CM / ROC / PR ---
        prob_thresh = getattr(self, "probability_threshold", None)

        # Only for binary classification and when a threshold is provided.
        if (prob_thresh is not None) and (not self.exp.is_multiclass):
            X = self.exp.X_test_transformed
            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)

            # Get positive-class scores (robust defaults).
            classes = list(getattr(self.best_model, "classes_", [0, 1]))
            try:
                pos_idx = classes.index(1) if 1 in classes else 1
            except Exception:
                pos_idx = 1
            proba = self.best_model.predict_proba(X)
            y_scores = proba[:, pos_idx]

            # Derive label names consistently.
            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0

            # ---- Confusion matrix @ threshold ----
            try:
                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
                fig_cm = go.Figure(
                    data=go.Heatmap(
                        z=cm,
                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
                        y=[f"True {neg_label}", f"True {pos_label}"],
                        text=cm,
                        texttemplate="%{text}",
                        colorscale="Blues",
                        showscale=False,
                    )
                )
                fig_cm.update_layout(
                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
                    xaxis_title="Predicted label",
                    yaxis_title="True label",
                )
                _apply_report_layout(fig_cm)
                self.explainer_plots["confusion_matrix"] = fig_cm
            except Exception as e:
                LOG.warning(
                    f"Threshold-aware confusion matrix failed; falling back: {e}"
                )

            # ---- ROC with threshold marker ----
            try:
                fpr, tpr, thr = roc_curve(y, y_scores)
                roc_auc = auc(fpr, tpr)
                fig_roc = go.Figure()
                fig_roc.add_scatter(
                    x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
                )
                if len(thr):
                    mask = np.isfinite(thr)
                    if mask.any():
                        idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
                        idx = np.where(mask)[0][idx_local]
                        if 0 <= idx < len(fpr):
                            fig_roc.add_scatter(
                                x=[fpr[idx]],
                                y=[tpr[idx]],
                                mode="markers",
                                name=f"@ {prob_thresh:.2f}",
                                marker=dict(size=10),
                            )
                fig_roc.update_layout(
                    title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
                    xaxis_title="False Positive Rate",
                    yaxis_title="True Positive Rate",
                )
                _apply_report_layout(fig_roc)
                self.explainer_plots["roc_auc"] = fig_roc
            except Exception as e:
                LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")

            # ---- PR with threshold marker ----
            try:
                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
                pr_auc = auc(recall, precision)
                fig_pr = go.Figure()
                fig_pr.add_scatter(
                    x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
                )
                if len(thr_pr):
                    idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
                    # Note: thr_pr has length len(precision) - 1.
                    idx_pr = max(0, min(idx_pr, len(recall) - 1))
                    fig_pr.add_scatter(
                        x=[recall[idx_pr]],
                        y=[precision[idx_pr]],
                        mode="markers",
                        name=f"@ {prob_thresh:.2f}",
                        marker=dict(size=10),
                    )
                fig_pr.update_layout(
                    title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
                    xaxis_title="Recall",
                    yaxis_title="Precision",
                )
                _apply_report_layout(fig_pr)
                self.explainer_plots["pr_auc"] = fig_pr
            except Exception as e:
                LOG.warning(f"Threshold marker on PR failed; falling back: {e}")
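        # Indexing note for the threshold markers above (facts about
        # scikit-learn, not this repo): precision_recall_curve returns
        # len(thresholds) == len(precision) - 1, and roc_curve's thresholds
        # may start with an out-of-range sentinel (np.inf in recent
        # scikit-learn), hence the isfinite mask and the index clamping.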
("confusion_matrix", explainer.plot_confusion_matrix), ("threshold", explainer.plot_precision), # percentage vs probability ("cumulative_precision", explainer.plot_cumulative_precision), ]: if key in self.explainer_plots: continue try: fig = fn() if fig is not None: self.explainer_plots[key] = fig except Exception as e: LOG.error(f"Error generating explainer plot {key}: {e}") # mean SHAP importances try: self.explainer_plots["shap_mean"] = explainer.plot_importances() except Exception as e: LOG.warning(f"Could not generate shap_mean: {e}") # permutation importances try: self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( kind="permutation" ) except Exception as e: LOG.warning(f"Could not generate shap_perm: {e}") # PDPs for each feature (appended last) valid_feats = [] for feat in self.features_name: if feat in explainer.X.columns or feat in explainer.onehot_cols: valid_feats.append(feat) else: LOG.warning( f"Skipping PDP for feature {feat!r}: not found in explainer data" ) for feat in valid_feats: # wrap each PDP call to catch any unexpected AssertionErrors def make_pdp_plotter(f): def _plot(): try: return explainer.plot_pdp(f) except AssertionError as ae: LOG.warning(f"PDP AssertionError for {f!r}: {ae}") return None except Exception as e: LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") return None return _plot self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)