goeckslab/pycaret_predict: diff pycaret_classification.py @ 17:c5c324ac29fc (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | a2aeeb754d76 |
| children | (none) |
```diff
--- a/pycaret_classification.py	Fri Nov 28 22:28:26 2025 +0000
+++ b/pycaret_classification.py	Sat Dec 06 14:20:36 2025 +0000
@@ -8,7 +8,14 @@
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
 from pycaret.classification import ClassificationExperiment
-from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    matthews_corrcoef,
+    precision_recall_curve,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from utils import predict_proba
 
 LOG = logging.getLogger(__name__)
@@ -137,58 +144,36 @@
         # a dict to hold the raw Figure objects or callables
         self.explainer_plots: Dict[str, go.Figure] = {}
 
+        y_true, y_pred, label_values, y_scores = self._get_test_predictions()
+
+        # — Classification report (Plotly table) —
+        try:
+            fig_report = self._build_classification_report_fig(
+                y_true, y_pred, label_values
+            )
+            if fig_report is not None:
+                self.explainer_plots["class_report"] = fig_report
+        except Exception as e:
+            LOG.warning(f"Could not generate Plotly classification report: {e}")
+
+        # — Confusion matrix with actual labels —
+        try:
+            fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values)
+            if fig_cm is not None:
+                self.explainer_plots["confusion_matrix"] = fig_cm
+        except Exception as e:
+            LOG.warning(f"Could not generate Plotly confusion matrix: {e}")
+
         # --- Threshold-aware overrides for CM / ROC / PR ---
         prob_thresh = getattr(self, "probability_threshold", None)
 
         # Only for binary classification and when threshold is provided
         if (prob_thresh is not None) and (not self.exp.is_multiclass):
-            X = self.exp.X_test_transformed
-            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
-
-            # Get positive-class scores (robust defaults)
-            classes = list(getattr(self.best_model, "classes_", [0, 1]))
-            try:
-                pos_idx = classes.index(1) if 1 in classes else 1
-            except Exception:
-                pos_idx = 1
-
-            proba = self.best_model.predict_proba(X)
-            y_scores = proba[:, pos_idx]
-
-            # Derive label names consistently
-            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
-            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
-
-            # ---- Confusion Matrix @ threshold ----
-            try:
-                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
-                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
-                fig_cm = go.Figure(
-                    data=go.Heatmap(
-                        z=cm,
-                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
-                        y=[f"True {neg_label}", f"True {pos_label}"],
-                        text=cm,
-                        texttemplate="%{text}",
-                        colorscale="Blues",
-                        showscale=False,
-                    )
-                )
-                fig_cm.update_layout(
-                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
-                    xaxis_title="Predicted label",
-                    yaxis_title="True label",
-                )
-                _apply_report_layout(fig_cm)
-                self.explainer_plots["confusion_matrix"] = fig_cm
-            except Exception as e:
-                LOG.warning(
-                    f"Threshold-aware confusion matrix failed; falling back: {e}"
-                )
-
             # ---- ROC with threshold marker ----
             try:
-                fpr, tpr, thr = roc_curve(y, y_scores)
+                if y_scores is None:
+                    raise ValueError("Predicted probabilities unavailable")
+                fpr, tpr, thr = roc_curve(y_true, y_scores)
                 roc_auc = auc(fpr, tpr)
                 fig_roc = go.Figure()
                 fig_roc.add_scatter(
@@ -219,7 +204,9 @@
 
             # ---- PR with threshold marker ----
             try:
-                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+                if y_scores is None:
+                    raise ValueError("Predicted probabilities unavailable")
+                precision, recall, thr_pr = precision_recall_curve(y_true, y_scores)
                 pr_auc = auc(recall, precision)
                 fig_pr = go.Figure()
                 fig_pr.add_scatter(
@@ -304,3 +291,182 @@
                 return _plot
 
             self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
+
+    def _get_test_predictions(self):
+        """
+        Return y_true, y_pred, label list, and (optionally) positive-class
+        probabilities when available. Ensures predictions respect the optional
+        probability threshold for binary tasks.
+        """
+        y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+        X_test = self.exp.X_test_transformed
+        prob_thresh = getattr(self, "probability_threshold", None)
+
+        y_scores = None
+        try:
+            proba = self.best_model.predict_proba(X_test)
+            y_scores = proba
+        except Exception:
+            LOG.debug("predict_proba unavailable for test predictions.")
+
+        try:
+            if (
+                prob_thresh is not None
+                and not self.exp.is_multiclass
+                and y_scores is not None
+                and y_scores.ndim == 2
+                and y_scores.shape[1] > 1
+            ):
+                classes = list(getattr(self.best_model, "classes_", []))
+                try:
+                    pos_idx = classes.index(1) if 1 in classes else 1
+                except Exception:
+                    pos_idx = 1
+                neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+                pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+                neg_label = classes[neg_idx] if len(classes) > neg_idx else 0
+                y_pred = np.where(y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label)
+                y_scores = y_scores[:, pos_idx]
+            else:
+                y_pred = self.best_model.predict(X_test)
+        except Exception as exc:
+            LOG.warning("Falling back to raw predict for test predictions: %s", exc)
+            y_pred = self.best_model.predict(X_test)
+
+        y_pred = pd.Series(y_pred).reset_index(drop=True)
+        if y_scores is not None:
+            y_scores = np.asarray(y_scores)
+            if y_scores.ndim > 1 and y_scores.shape[1] == 1:
+                y_scores = y_scores.ravel()
+            if self.exp.is_multiclass and y_scores.ndim > 1:
+                # Avoid passing multiclass score matrices to ROC/PR utilities
+                y_scores = None
+        label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True))
+        return y_true, y_pred, label_values.tolist(), y_scores
+
+    def _threshold_suffix(self) -> str:
+        """
+        Build a suffix like ' (threshold=0.50)' for binary tasks; omit for
+        multiclass where thresholds are not applied.
+        """
+        if getattr(self, "task_type", None) != "classification":
+            return ""
+        if getattr(self.exp, "is_multiclass", False):
+            return ""
+        prob_thresh = getattr(self, "probability_threshold", None)
+        if prob_thresh is None:
+            return " (threshold=0.50)"
+        try:
+            return f" (threshold={float(prob_thresh):.2f})"
+        except Exception:
+            return f" (threshold={prob_thresh})"
+
+    def _build_confusion_matrix_fig(self, y_true, y_pred, labels):
+        def _label_sort_key(lbl):
+            try:
+                return (0, float(lbl))
+            except Exception:
+                return (1, str(lbl))
+
+        ordered_labels = sorted(labels, key=_label_sort_key)
+        cm = confusion_matrix(y_true, y_pred, labels=ordered_labels)
+        label_names = [str(lbl) for lbl in ordered_labels]
+        fig_cm = go.Figure(
+            data=go.Heatmap(
+                z=cm,
+                x=[f"Pred {lbl}" for lbl in label_names],
+                y=[f"True {lbl}" for lbl in label_names],
+                text=cm,
+                texttemplate="%{text}",
+                colorscale="Blues",
+                showscale=False,
+            )
+        )
+        fig_cm.update_layout(
+            title=f"Confusion Matrix{self._threshold_suffix()}",
+            xaxis_title=f"Predicted label ({self.target})",
+            yaxis_title=f"True label ({self.target})",
+        )
+        fig_cm.update_xaxes(
+            type="category",
+            categoryorder="array",
+            categoryarray=[f"Pred {lbl}" for lbl in label_names],
+        )
+        fig_cm.update_yaxes(
+            type="category",
+            categoryorder="array",
+            categoryarray=[f"True {lbl}" for lbl in label_names],
+            autorange="reversed",
+        )
+        _apply_report_layout(fig_cm)
+        return fig_cm
+
+    def _build_classification_report_fig(self, y_true, y_pred, labels):
+        precision, recall, f1, support = precision_recall_fscore_support(
+            y_true, y_pred, labels=labels, zero_division=0
+        )
+        mcc_scores = []
+        for lbl in labels:
+            y_true_bin = (y_true == lbl).astype(int)
+            y_pred_bin = (y_pred == lbl).astype(int)
+            try:
+                mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin)
+            except Exception:
+                mcc_val = 0.0
+            mcc_scores.append(mcc_val)
+
+        label_names = [str(lbl) for lbl in labels]
+        metrics = ["precision", "recall", "f1", "support"]
+
+        max_support = float(max(support) if len(support) else 0)
+        z_rows = []
+        text_rows = []
+        for i, lbl in enumerate(label_names):
+            norm_support = (support[i] / max_support) if max_support else 0.0
+            z_rows.append(
+                [
+                    precision[i],
+                    recall[i],
+                    f1[i],
+                    norm_support,
+                ]
+            )
+            text_rows.append(
+                [
+                    f"{precision[i]:.3f}",
+                    f"{recall[i]:.3f}",
+                    f"{f1[i]:.3f}",
+                    f"{int(support[i])}",
+                ]
+            )
+
+        fig = go.Figure(
+            data=go.Heatmap(
+                z=z_rows,
+                x=metrics,
+                y=label_names,
+                colorscale="YlOrRd",
+                zmin=0,
+                zmax=1,
+                colorbar=dict(title="Scale"),
+                text=text_rows,
+                texttemplate="%{text}",
+                hovertemplate="Label=%{y}<br>Metric=%{x}<br>Value=%{text}<extra></extra>",
+            )
+        )
+        fig.update_yaxes(
+            title_text=f"Label ({self.target})",
+            autorange="reversed",
+            type="category",
+            tickmode="array",
+            tickvals=label_names,
+            ticktext=label_names,
+            showgrid=False,
+        )
+        fig.update_xaxes(title_text="", tickangle=45)
+        fig.update_layout(
+            title=f"Per-Class Metrics{self._threshold_suffix()}",
+            margin=dict(l=70, r=60, t=70, b=80),
+        )
+        _apply_report_layout(fig)
+        return fig
```
