diff pycaret_classification.py @ 12:e674b9e946fb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| author   | goeckslab                       |
|----------|---------------------------------|
| date     | Mon, 08 Sep 2025 22:39:12 +0000 |
| parents  | 1aed7d47c5ec                    |
| children |                                 |
--- a/pycaret_classification.py	Fri Aug 22 21:13:30 2025 +0000
+++ b/pycaret_classification.py	Mon Sep 08 22:39:12 2025 +0000
@@ -2,15 +2,29 @@
 import types
 from typing import Dict
 
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
-from plotly.graph_objects import Figure
 from pycaret.classification import ClassificationExperiment
+from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
 from utils import predict_proba
 
 LOG = logging.getLogger(__name__)
 
 
+def _apply_report_layout(fig: go.Figure) -> go.Figure:
+    # Give the left side more space for y-axis title/ticks and let axes auto-reserve room
+    fig.update_xaxes(automargin=True, title_standoff=12)
+    fig.update_yaxes(automargin=True, title_standoff=12)
+    fig.update_layout(
+        autosize=True,
+        margin=dict(l=120, r=40, t=60, b=60),  # bump 'l' if you still see clipping
+    )
+    return fig
+
+
 class ClassificationModelTrainer(BaseModelTrainer):
     def __init__(
         self,
@@ -50,20 +64,19 @@
         )
 
         plots = [
-            'confusion_matrix',
-            'auc',
-            'threshold',
-            'pr',
-            'error',
-            'class_report',
-            'learning',
-            'calibration',
-            'vc',
-            'dimension',
-            'manifold',
-            'rfe',
-            'feature',
-            'feature_all',
+            "auc",
+            "threshold",
+            "pr",
+            "error",
+            "class_report",
+            "learning",
+            "calibration",
+            "vc",
+            "dimension",
+            "manifold",
+            "rfe",
+            "feature",
+            "feature_all",
         ]
         for plot_name in plots:
            try:
@@ -102,24 +115,146 @@
         LOG.info("Generating explainer plots")
 
+        # Ensure predict_proba is available here too
+        if not hasattr(self.best_model, "predict_proba"):
+            self.best_model.predict_proba = types.MethodType(
+                predict_proba, self.best_model
+            )
+            LOG.warning(
+                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
+            )
+
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
         explainer = ClassifierExplainer(self.best_model, X_test, y_test)
 
         # a dict to hold the raw Figure objects or callables
-        self.explainer_plots: Dict[str, Figure] = {}
+        self.explainer_plots: Dict[str, go.Figure] = {}
+
+        # --- Threshold-aware overrides for CM / ROC / PR ---
+        prob_thresh = getattr(self, "probability_threshold", None)
+
+        # Only for binary classification and when threshold is provided
+        if (prob_thresh is not None) and (not self.exp.is_multiclass):
+            X = self.exp.X_test_transformed
+            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+
+            # Get positive-class scores (robust defaults)
+            classes = list(getattr(self.best_model, "classes_", [0, 1]))
+            try:
+                pos_idx = classes.index(1) if 1 in classes else 1
+            except Exception:
+                pos_idx = 1
+
+            proba = self.best_model.predict_proba(X)
+            y_scores = proba[:, pos_idx]
+
+            # Derive label names consistently
+            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
+
+            # ---- Confusion Matrix @ threshold ----
+            try:
+                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
+                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
+                fig_cm = go.Figure(
+                    data=go.Heatmap(
+                        z=cm,
+                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
+                        y=[f"True {neg_label}", f"True {pos_label}"],
+                        text=cm,
+                        texttemplate="%{text}",
+                        colorscale="Blues",
+                        showscale=False,
+                    )
+                )
+                fig_cm.update_layout(
+                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
+                    xaxis_title="Predicted label",
+                    yaxis_title="True label",
+                )
+                _apply_report_layout(fig_cm)
+                self.explainer_plots["confusion_matrix"] = fig_cm
+            except Exception as e:
+                LOG.warning(
+                    f"Threshold-aware confusion matrix failed; falling back: {e}"
+                )
 
-        # these go into the Test tab
+            # ---- ROC with threshold marker ----
+            try:
+                fpr, tpr, thr = roc_curve(y, y_scores)
+                roc_auc = auc(fpr, tpr)
+                fig_roc = go.Figure()
+                fig_roc.add_scatter(
+                    x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
+                )
+                if len(thr):
+                    mask = np.isfinite(thr)
+                    if mask.any():
+                        idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
+                        idx = np.where(mask)[0][idx_local]
+                        if 0 <= idx < len(fpr):
+                            fig_roc.add_scatter(
+                                x=[fpr[idx]],
+                                y=[tpr[idx]],
+                                mode="markers",
+                                name=f"@ {prob_thresh:.2f}",
+                                marker=dict(size=10),
+                            )
+                fig_roc.update_layout(
+                    title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
+                    xaxis_title="False Positive Rate",
+                    yaxis_title="True Positive Rate",
+                )
+                _apply_report_layout(fig_roc)
+                self.explainer_plots["roc_auc"] = fig_roc
+            except Exception as e:
+                LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")
+
+            # ---- PR with threshold marker ----
+            try:
+                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+                pr_auc = auc(recall, precision)
+                fig_pr = go.Figure()
+                fig_pr.add_scatter(
+                    x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
+                )
+                if len(thr_pr):
+                    idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
+                    # note: thr_pr has length = len(precision) - 1
+                    idx_pr = max(0, min(idx_pr, len(recall) - 1))
+                    fig_pr.add_scatter(
+                        x=[recall[idx_pr]],
+                        y=[precision[idx_pr]],
+                        mode="markers",
+                        name=f"@ {prob_thresh:.2f}",
+                        marker=dict(size=10),
+                    )
+                fig_pr.update_layout(
+                    title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
+                    xaxis_title="Recall",
+                    yaxis_title="Precision",
+                )
+                _apply_report_layout(fig_pr)
+                self.explainer_plots["pr_auc"] = fig_pr
+            except Exception as e:
+                LOG.warning(f"Threshold marker on PR failed; falling back: {e}")
+
+        # these go into the Test tab (don't overwrite overrides)
         for key, fn in [
             ("roc_auc", explainer.plot_roc_auc),
             ("pr_auc", explainer.plot_pr_auc),
             ("lift_curve", explainer.plot_lift_curve),
             ("confusion_matrix", explainer.plot_confusion_matrix),
-            ("threshold", explainer.plot_precision),  # Percentage vs probability
+            ("threshold", explainer.plot_precision),  # percentage vs probability
             ("cumulative_precision", explainer.plot_cumulative_precision),
         ]:
+            if key in self.explainer_plots:
+                continue
             try:
-                self.explainer_plots[key] = fn()
+                fig = fn()
+                if fig is not None:
+                    self.explainer_plots[key] = fig
             except Exception as e:
                 LOG.error(f"Error generating explainer plot {key}: {e}")
 
@@ -143,7 +278,9 @@
             if feat in explainer.X.columns or feat in explainer.onehot_cols:
                 valid_feats.append(feat)
             else:
-                LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
+                LOG.warning(
+                    f"Skipping PDP for feature {feat!r}: not found in explainer data"
+                )
 
         for feat in valid_feats:
             # wrap each PDP call to catch any unexpected AssertionErrors
@@ -157,6 +294,7 @@
                 except Exception as e:
                     LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
                     return None
+
            return _plot
 
         self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)