diff pycaret_classification.py @ 17:c5c324ac29fc draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
author goeckslab
date Sat, 06 Dec 2025 14:20:36 +0000
parents a2aeeb754d76
children
--- a/pycaret_classification.py	Fri Nov 28 22:28:26 2025 +0000
+++ b/pycaret_classification.py	Sat Dec 06 14:20:36 2025 +0000
@@ -8,7 +8,14 @@
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
 from pycaret.classification import ClassificationExperiment
-from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    matthews_corrcoef,
+    precision_recall_curve,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from utils import predict_proba
 
 LOG = logging.getLogger(__name__)
@@ -137,58 +144,36 @@
         # a dict to hold the raw Figure objects or callables
         self.explainer_plots: Dict[str, go.Figure] = {}
 
+        y_true, y_pred, label_values, y_scores = self._get_test_predictions()
+
+        # ---- Classification report (per-class metrics heatmap) ----
+        try:
+            fig_report = self._build_classification_report_fig(
+                y_true, y_pred, label_values
+            )
+            if fig_report is not None:
+                self.explainer_plots["class_report"] = fig_report
+        except Exception as e:
+            LOG.warning(f"Could not generate Plotly classification report: {e}")
+
+        # ---- Confusion matrix with actual labels ----
+        try:
+            fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values)
+            if fig_cm is not None:
+                self.explainer_plots["confusion_matrix"] = fig_cm
+        except Exception as e:
+            LOG.warning(f"Could not generate Plotly confusion matrix: {e}")
+
         # --- Threshold-aware overrides for CM / ROC / PR ---
         prob_thresh = getattr(self, "probability_threshold", None)
 
         # Only for binary classification and when threshold is provided
         if (prob_thresh is not None) and (not self.exp.is_multiclass):
-            X = self.exp.X_test_transformed
-            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
-
-            # Get positive-class scores (robust defaults)
-            classes = list(getattr(self.best_model, "classes_", [0, 1]))
-            try:
-                pos_idx = classes.index(1) if 1 in classes else 1
-            except Exception:
-                pos_idx = 1
-
-            proba = self.best_model.predict_proba(X)
-            y_scores = proba[:, pos_idx]
-
-            # Derive label names consistently
-            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
-            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
-
-            # ---- Confusion Matrix @ threshold ----
-            try:
-                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
-                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
-                fig_cm = go.Figure(
-                    data=go.Heatmap(
-                        z=cm,
-                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
-                        y=[f"True {neg_label}", f"True {pos_label}"],
-                        text=cm,
-                        texttemplate="%{text}",
-                        colorscale="Blues",
-                        showscale=False,
-                    )
-                )
-                fig_cm.update_layout(
-                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
-                    xaxis_title="Predicted label",
-                    yaxis_title="True label",
-                )
-                _apply_report_layout(fig_cm)
-                self.explainer_plots["confusion_matrix"] = fig_cm
-            except Exception as e:
-                LOG.warning(
-                    f"Threshold-aware confusion matrix failed; falling back: {e}"
-                )
-
             # ---- ROC with threshold marker ----
             try:
-                fpr, tpr, thr = roc_curve(y, y_scores)
+                if y_scores is None:
+                    raise ValueError("Predicted probabilities unavailable")
+                fpr, tpr, thr = roc_curve(y_true, y_scores)
                 roc_auc = auc(fpr, tpr)
                 fig_roc = go.Figure()
                 fig_roc.add_scatter(
@@ -219,7 +204,9 @@
 
             # ---- PR with threshold marker ----
             try:
-                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+                if y_scores is None:
+                    raise ValueError("Predicted probabilities unavailable")
+                precision, recall, thr_pr = precision_recall_curve(y_true, y_scores)
                 pr_auc = auc(recall, precision)
                 fig_pr = go.Figure()
                 fig_pr.add_scatter(
@@ -304,3 +291,182 @@
                 return _plot
 
             self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
+
+    def _get_test_predictions(self):
+        """
+        Return y_true, y_pred, the observed label list, and prediction scores
+        when available: positive-class probabilities for thresholded binary
+        tasks, the raw probability matrix for other binary tasks, and None
+        when probabilities are unavailable or the task is multiclass.
+        Predictions respect the optional probability threshold for binary
+        tasks.
+        """
+        y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+        X_test = self.exp.X_test_transformed
+        prob_thresh = getattr(self, "probability_threshold", None)
+
+        y_scores = None
+        try:
+            y_scores = self.best_model.predict_proba(X_test)
+        except Exception:
+            LOG.debug("predict_proba unavailable for test predictions.")
+
+        try:
+            if (
+                prob_thresh is not None
+                and not self.exp.is_multiclass
+                and y_scores is not None
+                and y_scores.ndim == 2
+                and y_scores.shape[1] > 1
+            ):
+                classes = list(getattr(self.best_model, "classes_", []))
+                try:
+                    pos_idx = classes.index(1) if 1 in classes else 1
+                except Exception:
+                    pos_idx = 1
+                neg_idx = (1 - pos_idx) if y_scores.shape[1] > 1 else 0
+                pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+                neg_label = classes[neg_idx] if len(classes) > neg_idx else 0
+                y_pred = np.where(
+                    y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label
+                )
+                y_scores = y_scores[:, pos_idx]
+            else:
+                y_pred = self.best_model.predict(X_test)
+        except Exception as exc:
+            LOG.warning("Falling back to raw predict for test predictions: %s", exc)
+            y_pred = self.best_model.predict(X_test)
+
+        y_pred = pd.Series(y_pred).reset_index(drop=True)
+        if y_scores is not None:
+            y_scores = np.asarray(y_scores)
+            if y_scores.ndim > 1 and y_scores.shape[1] == 1:
+                y_scores = y_scores.ravel()
+            if self.exp.is_multiclass and y_scores.ndim > 1:
+                # Avoid passing multiclass score matrices to ROC/PR utilities
+                y_scores = None
+        label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True))
+        return y_true, y_pred, label_values.tolist(), y_scores
+
+    def _threshold_suffix(self) -> str:
+        """
+        Build a suffix like ' (threshold=0.50)' for binary tasks; omit for
+        multiclass where thresholds are not applied.
+        """
+        if getattr(self, "task_type", None) != "classification":
+            return ""
+        if getattr(self.exp, "is_multiclass", False):
+            return ""
+        prob_thresh = getattr(self, "probability_threshold", None)
+        if prob_thresh is None:
+            return " (threshold=0.50)"
+        try:
+            return f" (threshold={float(prob_thresh):.2f})"
+        except Exception:
+            return f" (threshold={prob_thresh})"
+
+    def _build_confusion_matrix_fig(self, y_true, y_pred, labels):
+        def _label_sort_key(lbl):
+            try:
+                return (0, float(lbl))
+            except Exception:
+                return (1, str(lbl))
+
+        ordered_labels = sorted(labels, key=_label_sort_key)
+        cm = confusion_matrix(y_true, y_pred, labels=ordered_labels)
+        label_names = [str(lbl) for lbl in ordered_labels]
+        fig_cm = go.Figure(
+            data=go.Heatmap(
+                z=cm,
+                x=[f"Pred {lbl}" for lbl in label_names],
+                y=[f"True {lbl}" for lbl in label_names],
+                text=cm,
+                texttemplate="%{text}",
+                colorscale="Blues",
+                showscale=False,
+            )
+        )
+        fig_cm.update_layout(
+            title=f"Confusion Matrix{self._threshold_suffix()}",
+            xaxis_title=f"Predicted label ({self.target})",
+            yaxis_title=f"True label ({self.target})",
+        )
+        fig_cm.update_xaxes(
+            type="category",
+            categoryorder="array",
+            categoryarray=[f"Pred {lbl}" for lbl in label_names],
+        )
+        fig_cm.update_yaxes(
+            type="category",
+            categoryorder="array",
+            categoryarray=[f"True {lbl}" for lbl in label_names],
+            autorange="reversed",
+        )
+        _apply_report_layout(fig_cm)
+        return fig_cm
+
+    def _build_classification_report_fig(self, y_true, y_pred, labels):
+        precision, recall, f1, support = precision_recall_fscore_support(
+            y_true, y_pred, labels=labels, zero_division=0
+        )
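+        # Per-class MCC: score each label one-vs-rest against the others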
+        mcc_scores = []
+        for lbl in labels:
+            y_true_bin = (y_true == lbl).astype(int)
+            y_pred_bin = (y_pred == lbl).astype(int)
+            try:
+                mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin)
+            except Exception:
+                mcc_val = 0.0
+            mcc_scores.append(mcc_val)
+
+        label_names = [str(lbl) for lbl in labels]
+        metrics = ["precision", "recall", "f1", "mcc", "support"]
+
+        max_support = float(max(support) if len(support) else 0)
+        z_rows = []
+        text_rows = []
+        for i, lbl in enumerate(label_names):
+            norm_support = (support[i] / max_support) if max_support else 0.0
+            z_rows.append(
+                [
+                    precision[i],
+                    recall[i],
+                    f1[i],
+                    mcc_scores[i],
+                    norm_support,
+                ]
+            )
+            text_rows.append(
+                [
+                    f"{precision[i]:.3f}",
+                    f"{recall[i]:.3f}",
+                    f"{f1[i]:.3f}",
+                    f"{mcc_scores[i]:.3f}",
+                    f"{int(support[i])}",
+                ]
+            )
+
+        fig = go.Figure(
+            data=go.Heatmap(
+                z=z_rows,
+                x=metrics,
+                y=label_names,
+                colorscale="YlOrRd",
+                zmin=0,
+                zmax=1,
+                colorbar=dict(title="Scale"),
+                text=text_rows,
+                texttemplate="%{text}",
+                hovertemplate="Label=%{y}<br>Metric=%{x}<br>Value=%{text}<extra></extra>",
+            )
+        )
+        fig.update_yaxes(
+            title_text=f"Label ({self.target})",
+            autorange="reversed",
+            type="category",
+            tickmode="array",
+            tickvals=label_names,
+            ticktext=label_names,
+            showgrid=False,
+        )
+        fig.update_xaxes(title_text="", tickangle=45)
+        fig.update_layout(
+            title=f"Per-Class Metrics{self._threshold_suffix()}",
+            margin=dict(l=70, r=60, t=70, b=80),
+        )
+        _apply_report_layout(fig)
+        return fig
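
For context: the threshold handling introduced in _get_test_predictions reduces to comparing positive-class probabilities from predict_proba against the configured cut-off. A minimal standalone sketch of that rule, assuming a plain scikit-learn classifier (the LogisticRegression model, toy data, and 0.35 threshold are illustrative, not part of this commit):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import confusion_matrix

    # Toy stand-ins for exp.X_test_transformed / y_test_transformed and
    # self.probability_threshold; values here are illustrative only.
    X, y = make_classification(n_samples=200, n_features=5, random_state=0)
    model = LogisticRegression().fit(X, y)
    prob_thresh = 0.35

    classes = list(model.classes_)
    pos_idx = classes.index(1) if 1 in classes else 1
    pos_label, neg_label = classes[pos_idx], classes[1 - pos_idx]

    # Same rule as the helper: predict positive when the positive-class
    # probability reaches the threshold.
    y_scores = model.predict_proba(X)[:, pos_idx]
    y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
    print(confusion_matrix(y, y_pred, labels=[neg_label, pos_label]))

The per-class values in _build_classification_report_fig likewise come straight from scikit-learn. A small worked example of the one-vs-rest MCC used there (the labels below are made up for illustration):

    import numpy as np
    from sklearn.metrics import matthews_corrcoef, precision_recall_fscore_support

    # Made-up multiclass labels purely for illustration.
    y_true = np.array(["cat", "dog", "dog", "cat", "bird", "dog"])
    y_pred = np.array(["cat", "dog", "cat", "cat", "bird", "dog"])
    labels = ["bird", "cat", "dog"]

    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, zero_division=0
    )

    # One label at a time: binarize one-vs-rest, then score with MCC.
    for lbl, p, r, f, s in zip(labels, precision, recall, f1, support):
        mcc = matthews_corrcoef((y_true == lbl).astype(int), (y_pred == lbl).astype(int))
        print(f"{lbl}: p={p:.2f} r={r:.2f} f1={f:.2f} mcc={mcc:.2f} n={int(s)}")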