diff metrics_logic.py @ 8:a48e750cfd25 draft

planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
author goeckslab
date Fri, 30 Jan 2026 14:20:49 +0000
parents 375c36923da1
--- a/metrics_logic.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/metrics_logic.py	Fri Jan 30 14:20:49 2026 +0000
@@ -18,6 +18,7 @@
     r2_score,
     recall_score,
     roc_auc_score,
+    roc_curve
 )
 
 
@@ -69,16 +70,40 @@
     return metrics
 
 
+def _get_binary_scores(
+    y_true: pd.Series,
+    y_proba: Optional[np.ndarray],
+    predictor,
+) -> Tuple[np.ndarray, object, Optional[np.ndarray]]:
+    classes_sorted = np.sort(pd.unique(y_true))
+    pos_label = classes_sorted[-1]
+    pos_scores = None
+    if y_proba is not None:
+        if y_proba.ndim == 1:
+            pos_scores = y_proba
+        else:
+            pos_col_idx = -1
+            try:
+                if hasattr(predictor, "class_labels") and predictor.class_labels:
+                    pos_col_idx = list(predictor.class_labels).index(pos_label)
+            except Exception:
+                pos_col_idx = -1
+            pos_scores = y_proba[:, pos_col_idx]
+    return classes_sorted, pos_label, pos_scores
+
+
 def _compute_binary_metrics(
     y_true: pd.Series,
     y_pred: pd.Series,
     y_proba: Optional[np.ndarray],
-    predictor
+    predictor,
+    classes_sorted: Optional[np.ndarray] = None,
+    pos_label: Optional[object] = None,
+    pos_scores: Optional[np.ndarray] = None,
 ) -> "OrderedDict[str, float]":
     metrics = OrderedDict()
-    classes_sorted = np.sort(pd.unique(y_true))
-    # Choose the lexicographically larger class as "positive"
-    pos_label = classes_sorted[-1]
+    if classes_sorted is None or pos_label is None or pos_scores is None:
+        classes_sorted, pos_label, pos_scores = _get_binary_scores(y_true, y_proba, predictor)
 
     metrics["Accuracy"] = accuracy_score(y_true, y_pred)
     metrics["Precision"] = precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
@@ -92,18 +117,7 @@
         metrics["Specificity_(TNR)"] = np.nan
 
     # Probabilistic metrics
-    if y_proba is not None:
-        # pick column of positive class
-        if y_proba.ndim == 1:
-            pos_scores = y_proba
-        else:
-            pos_col_idx = -1
-            try:
-                if hasattr(predictor, "class_labels") and predictor.class_labels:
-                    pos_col_idx = list(predictor.class_labels).index(pos_label)
-            except Exception:
-                pos_col_idx = -1
-            pos_scores = y_proba[:, pos_col_idx]
+    if y_proba is not None and pos_scores is not None:
         try:
             metrics["ROC-AUC"] = roc_auc_score(y_true == pos_label, pos_scores)
         except Exception:
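
The probabilistic block now reuses the precomputed pos_scores instead of re-deriving the positive column. The comparison y_true == pos_label turns the labels into the boolean target that roc_auc_score accepts; a small standalone check using the same toy values as above:

    import numpy as np
    import pandas as pd
    from sklearn.metrics import roc_auc_score

    y_true = pd.Series(["no", "yes", "yes", "no"])
    pos_scores = np.array([0.2, 0.7, 0.9, 0.4])
    # (y_true == "yes") -> [False, True, True, False]; booleans work as a binary target
    print(roc_auc_score(y_true == "yes", pos_scores))  # 1.0: every positive outscores every negative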
@@ -221,7 +235,8 @@
     target_col: str,
     problem_type: str,
     threshold: Optional[float] = None,    # <— NEW
-) -> "OrderedDict[str, float]":
+    return_curve: bool = False,
+) -> "OrderedDict[str, float] | Tuple[OrderedDict[str, float], Optional[dict]]":
     """Compute transparency metrics for one split (Train/Val/Test) based on task type."""
     # Prepare inputs
     features = df.drop(columns=[target_col], errors="ignore")
@@ -235,22 +250,14 @@
     except Exception:
         y_proba = None
 
+    classes_sorted = pos_label = pos_scores = None
+    if problem_type == "binary":
+        classes_sorted, pos_label, pos_scores = _get_binary_scores(y_true_series, y_proba, predictor)
+
     # Labels (optionally thresholded for binary)
     y_pred_series = None
-    if problem_type == "binary" and (threshold is not None) and (y_proba is not None):
-        classes_sorted = np.sort(pd.unique(y_true_series))
-        pos_label = classes_sorted[-1]
+    if problem_type == "binary" and (threshold is not None) and (pos_scores is not None):
         neg_label = classes_sorted[0]
-        if y_proba.ndim == 1:
-            pos_scores = y_proba
-        else:
-            pos_col_idx = -1
-            try:
-                if hasattr(predictor, "class_labels") and predictor.class_labels:
-                    pos_col_idx = list(predictor.class_labels).index(pos_label)
-            except Exception:
-                pos_col_idx = -1
-            pos_scores = y_proba[:, pos_col_idx]
         y_pred_series = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
     else:
         # Fall back to model's default label prediction (argmax / 0.5 equivalent)
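
With an explicit threshold, Test-split labels now come straight from pos_scores rather than from the predictor's default decision rule. A minimal sketch of the np.where line added above, with hypothetical values:

    import numpy as np
    import pandas as pd

    pos_scores = np.array([0.2, 0.7, 0.9, 0.4])
    pos_label, neg_label, threshold = "yes", "no", 0.35
    y_pred = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
    # -> ['no', 'yes', 'yes', 'yes']; 0.4 clears the 0.35 cutoff, unlike a default 0.5 rule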
@@ -259,13 +266,35 @@
     if problem_type == "regression":
         y_true_arr = np.asarray(y_true_series, dtype=float)
         y_pred_arr = np.asarray(y_pred_series, dtype=float)
-        return _compute_regression_metrics(y_true_arr, y_pred_arr)
+        metrics = _compute_regression_metrics(y_true_arr, y_pred_arr)
+        return (metrics, None) if return_curve else metrics
 
     if problem_type == "binary":
-        return _compute_binary_metrics(y_true_series, y_pred_series, y_proba, predictor)
+        metrics = _compute_binary_metrics(
+            y_true_series,
+            y_pred_series,
+            y_proba,
+            predictor,
+            classes_sorted=classes_sorted,
+            pos_label=pos_label,
+            pos_scores=pos_scores,
+        )
+        roc_curve_data = None
+        if return_curve and pos_scores is not None and pos_label is not None:
+            try:
+                fpr, tpr, thresholds = roc_curve(y_true_series == pos_label, pos_scores)
+                roc_curve_data = {
+                    "fpr": fpr.tolist(),
+                    "tpr": tpr.tolist(),
+                    "thresholds": thresholds.tolist(),
+                }
+            except Exception:
+                roc_curve_data = None
+        return (metrics, roc_curve_data) if return_curve else metrics
 
     # multiclass
-    return _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)
+    metrics = _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)
+    return (metrics, None) if return_curve else metrics
 
 
 def evaluate_all_transparency(
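
Existing callers of compute_metrics_for_split are unaffected: without return_curve the function still returns only the metrics OrderedDict. Opting in yields a (metrics, curve) pair, where curve is None for regression, multiclass, or when probabilities are unavailable. A usage sketch mirroring the calls in evaluate_all_transparency below; test_df, "target", and 0.4 are placeholder values:

    # unchanged call: metrics only
    metrics = compute_metrics_for_split(predictor, test_df, "target", "binary", threshold=0.4)

    # opt in to the ROC curve points as well
    metrics, curve = compute_metrics_for_split(
        predictor, test_df, "target", "binary", threshold=0.4, return_curve=True
    )
    if curve is not None:
        fpr, tpr = curve["fpr"], curve["tpr"]  # plain lists, ready for JSON or plotting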
@@ -276,25 +305,57 @@
     target_col: str,
     problem_type: str,
     threshold: Optional[float] = None,
-) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
+) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]], Dict[str, dict]]:
     """
     Evaluate Train/Val/Test with the transparent metrics suite.
     Returns:
       - metrics_table: DataFrame with index=Metric, columns subset of [Train, Validation, Test]
       - raw_dict: nested dict {split -> {metric -> value}}
+      - roc_curves: nested dict {split -> {fpr, tpr, thresholds}} (binary only)
     """
     split_results: Dict[str, Dict[str, float]] = {}
+    roc_curves: Dict[str, dict] = {}
     splits = []
 
     # IMPORTANT: do NOT apply threshold to Train/Val
     if train_df is not None and len(train_df):
-        split_results["Train"] = compute_metrics_for_split(predictor, train_df, target_col, problem_type, threshold=None)
+        train_metrics, train_curve = compute_metrics_for_split(
+            predictor,
+            train_df,
+            target_col,
+            problem_type,
+            threshold=None,
+            return_curve=True,
+        )
+        split_results["Train"] = train_metrics
+        if train_curve:
+            roc_curves["Train"] = train_curve
         splits.append("Train")
     if val_df is not None and len(val_df):
-        split_results["Validation"] = compute_metrics_for_split(predictor, val_df, target_col, problem_type, threshold=None)
+        val_metrics, val_curve = compute_metrics_for_split(
+            predictor,
+            val_df,
+            target_col,
+            problem_type,
+            threshold=None,
+            return_curve=True,
+        )
+        split_results["Validation"] = val_metrics
+        if val_curve:
+            roc_curves["Validation"] = val_curve
         splits.append("Validation")
     if test_df is not None and len(test_df):
-        split_results["Test"] = compute_metrics_for_split(predictor, test_df, target_col, problem_type, threshold=threshold)
+        test_metrics, test_curve = compute_metrics_for_split(
+            predictor,
+            test_df,
+            target_col,
+            problem_type,
+            threshold=threshold,
+            return_curve=True,
+        )
+        split_results["Test"] = test_metrics
+        if test_curve:
+            roc_curves["Test"] = test_curve
         splits.append("Test")
 
     # Preserve order from the first split; include any extras from others
@@ -310,4 +371,4 @@
         for m, v in split_results[s].items():
             metrics_table.loc[m, s] = v
 
-    return metrics_table, split_results
+    return metrics_table, split_results, roc_curves
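
Because evaluate_all_transparency now returns a third value, callers that unpack two values need a one-line update. A hedged usage sketch; the leading parameters are not visible in this diff, so the sketch assumes they are the predictor and the Train/Validation/Test DataFrames, matching the names used in the body, and "target", "binary", and 0.4 are placeholders:

    metrics_table, raw_dict, roc_curves = evaluate_all_transparency(
        predictor, train_df, val_df, test_df, "target", "binary", threshold=0.4
    )
    # roc_curves stays empty for regression and multiclass; otherwise it is keyed by split name
    for split, curve in roc_curves.items():
        print(split, len(curve["fpr"]), "ROC points")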