metrics_logic.py @ 0:375c36923da1 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000

from collections import OrderedDict
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
    log_loss,
    matthews_corrcoef,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    precision_score,
    r2_score,
    recall_score,
    roc_auc_score,
)


# -------------------- Transparent Metrics (task-aware) -------------------- #

def _safe_y_proba_to_array(y_proba) -> Optional[np.ndarray]:
    """Convert predictor.predict_proba output (array/DataFrame/dict) to np.ndarray or None."""
    if y_proba is None:
        return None
    if isinstance(y_proba, pd.DataFrame):
        return y_proba.values
    if isinstance(y_proba, (list, tuple)):
        return np.asarray(y_proba)
    if isinstance(y_proba, np.ndarray):
        return y_proba
    if isinstance(y_proba, dict):
        try:
            return np.vstack([np.asarray(v) for _, v in sorted(y_proba.items())]).T
        except Exception:
            return None
    return None

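# Illustrative example for _safe_y_proba_to_array (not part of the original
# module): a two-class probability DataFrame becomes an (n_rows, n_classes)
# ndarray; a dict of per-class scores is column-stacked in sorted key order.
#
#   proba_df = pd.DataFrame({"no": [0.8, 0.3], "yes": [0.2, 0.7]})
#   _safe_y_proba_to_array(proba_df)                      # array([[0.8, 0.2], [0.3, 0.7]])
#   _safe_y_proba_to_array({"no": [0.8], "yes": [0.2]})   # array([[0.8, 0.2]])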

def _specificity_from_cm(cm: np.ndarray) -> float:
    """Specificity (TNR) for binary confusion matrix."""
    if cm.shape != (2, 2):
        return np.nan
    tn, fp, fn, tp = cm.ravel()
    denom = (tn + fp)
    return float(tn / denom) if denom > 0 else np.nan

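# Illustrative example for _specificity_from_cm (not part of the original
# module): with tn=50, fp=10, fn=5, tp=35, specificity = 50 / (50 + 10) = 0.8333...
#
#   _specificity_from_cm(np.array([[50, 10], [5, 35]]))   # ~0.8333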

def _compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> "OrderedDict[str, float]":
    mse = mean_squared_error(y_true, y_pred)
    rmse = float(np.sqrt(mse))
    mae = mean_absolute_error(y_true, y_pred)
    # Avoid division by zero using clip
    mape = float(np.mean(np.abs((y_true - y_pred) / np.clip(np.abs(y_true), 1e-12, None))) * 100.0)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    metrics = OrderedDict()
    metrics["MSE"] = mse
    metrics["RMSE"] = rmse
    metrics["MAE"] = mae
    metrics["MAPE_%"] = mape
    metrics["R2"] = r2
    metrics["MedianAE"] = medae
    return metrics

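# Illustrative example for _compute_regression_metrics (not part of the original
# module): y_true = [1.0, 2.0, 4.0] and y_pred = [1.5, 2.0, 3.0] give
# MSE = 1.25 / 3 ~ 0.4167, RMSE ~ 0.6455, MAE = 0.5, MAPE_% = 25.0,
# R2 ~ 0.7321 and MedianAE = 0.5.
#
#   _compute_regression_metrics(np.array([1.0, 2.0, 4.0]), np.array([1.5, 2.0, 3.0]))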

def _compute_binary_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray],
    predictor
) -> "OrderedDict[str, float]":
    metrics = OrderedDict()
    classes_sorted = np.sort(pd.unique(y_true))
    # Choose the largest class in sorted order as the positive class
    pos_label = classes_sorted[-1]

    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Precision"] = precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["Recall_(Sensitivity/TPR)"] = recall_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["F1-Score"] = f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0)

    try:
        cm = confusion_matrix(y_true, y_pred, labels=classes_sorted)
        metrics["Specificity_(TNR)"] = _specificity_from_cm(cm)
    except Exception:
        metrics["Specificity_(TNR)"] = np.nan

    # Probabilistic metrics
    if y_proba is not None:
        # Pick the column of the positive class
        if y_proba.ndim == 1:
            pos_scores = y_proba
        else:
            pos_col_idx = -1
            try:
                if hasattr(predictor, "class_labels") and predictor.class_labels:
                    pos_col_idx = list(predictor.class_labels).index(pos_label)
            except Exception:
                pos_col_idx = -1
            pos_scores = y_proba[:, pos_col_idx]
        try:
            metrics["ROC-AUC"] = roc_auc_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["ROC-AUC"] = np.nan
        try:
            metrics["PR-AUC"] = average_precision_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["PR-AUC"] = np.nan
        try:
            if y_proba.ndim == 1:
                y_proba_ll = np.column_stack([1 - pos_scores, pos_scores])
            else:
                y_proba_ll = y_proba
            metrics["LogLoss"] = log_loss(y_true, y_proba_ll, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
    else:
        metrics["ROC-AUC"] = np.nan
        metrics["PR-AUC"] = np.nan
        metrics["LogLoss"] = np.nan

    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    return metrics

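# Illustrative example for _compute_binary_metrics (not part of the original
# module). With a 1-D probability vector the predictor argument is never
# touched, so None can stand in for it; the positive class is the last label
# in sorted order ("yes" here).
#
#   y_true = pd.Series(["no", "yes", "yes", "no"])
#   y_pred = pd.Series(["no", "yes", "no", "no"])
#   proba = np.array([0.1, 0.9, 0.4, 0.2])   # P(class "yes") per row
#   _compute_binary_metrics(y_true, y_pred, proba, predictor=None)
#   # Accuracy 0.75, Recall 0.5, plus ROC-AUC / PR-AUC / LogLoss from `proba`.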

def _compute_multiclass_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray]
) -> "OrderedDict[str, float]":
    metrics = OrderedDict()
    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Macro Precision"] = precision_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro Recall"] = recall_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro F1"] = f1_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Weighted Precision"] = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted Recall"] = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted F1"] = f1_score(y_true, y_pred, average="weighted", zero_division=0)

    try:
        metrics["Cohen_Kappa"] = cohen_kappa_score(y_true, y_pred)
    except Exception:
        metrics["Cohen_Kappa"] = np.nan
    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    # Probabilistic metrics
    classes_sorted = np.sort(pd.unique(y_true))
    if y_proba is not None and y_proba.ndim == 2:
        try:
            metrics["LogLoss"] = log_loss(y_true, y_proba, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
        # Macro ROC-AUC / PR-AUC via OVR
        try:
            class_to_index = {c: i for i, c in enumerate(classes_sorted)}
            y_true_idx = np.vectorize(class_to_index.get)(y_true)
            metrics["ROC-AUC_macro"] = roc_auc_score(y_true_idx, y_proba, multi_class="ovr", average="macro")
        except Exception:
            metrics["ROC-AUC_macro"] = np.nan
        try:
            Y_true_ind = np.zeros_like(y_proba)
            idx_map = {c: i for i, c in enumerate(classes_sorted)}
            Y_true_ind[np.arange(y_proba.shape[0]), np.vectorize(idx_map.get)(y_true)] = 1
            metrics["PR-AUC_macro"] = average_precision_score(Y_true_ind, y_proba, average="macro")
        except Exception:
            metrics["PR-AUC_macro"] = np.nan
    else:
        metrics["LogLoss"] = np.nan
        metrics["ROC-AUC_macro"] = np.nan
        metrics["PR-AUC_macro"] = np.nan

    return metrics

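# Illustrative example for _compute_multiclass_metrics (not part of the original
# module). The probability columns are assumed to follow the sorted class order,
# which is what log_loss(labels=...) and the OVR ROC-AUC expect here.
#
#   y_true = pd.Series(["a", "b", "c", "a"])
#   y_pred = pd.Series(["a", "b", "b", "a"])
#   proba = np.array([[0.7, 0.2, 0.1],
#                     [0.1, 0.8, 0.1],
#                     [0.2, 0.5, 0.3],
#                     [0.6, 0.3, 0.1]])
#   _compute_multiclass_metrics(y_true, y_pred, proba)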

def aggregate_metrics(list_of_dicts):
    """Aggregate a list of metrics dicts (per split) into mean/std."""
    agg_mean = {}
    agg_std = {}
    for split in ("Train", "Validation", "Test", "Test (external)"):
        keys = set()
        for m in list_of_dicts:
            if isinstance(m, dict) and split in m:
                keys.update(m[split].keys())
        if not keys:
            continue
        agg_mean[split] = {}
        agg_std[split] = {}
        for k in keys:
            vals = [m[split][k] for m in list_of_dicts if split in m and k in m[split]]
            numeric_vals = []
            for v in vals:
                try:
                    numeric_vals.append(float(v))
                except Exception:
                    pass
            if numeric_vals:
                agg_mean[split][k] = float(np.mean(numeric_vals))
                agg_std[split][k] = float(np.std(numeric_vals, ddof=0))
            else:
                agg_mean[split][k] = vals[-1] if vals else None
                agg_std[split][k] = None
    return agg_mean, agg_std

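# Illustrative example for aggregate_metrics (not part of the original module):
# aggregating two hypothetical runs over the "Test" split.
#
#   run_a = {"Test": {"Accuracy": 0.90, "F1-Score": 0.88}}
#   run_b = {"Test": {"Accuracy": 0.80, "F1-Score": 0.84}}
#   means, stds = aggregate_metrics([run_a, run_b])
#   # means["Test"]["Accuracy"] ~ 0.85; stds["Test"]["Accuracy"] ~ 0.05 (population std, ddof=0)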

def compute_metrics_for_split(
    predictor,
    df: pd.DataFrame,
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,  # NEW: optional decision threshold (binary only)
) -> "OrderedDict[str, float]":
    """Compute transparency metrics for one split (Train/Val/Test) based on task type."""
    # Prepare inputs
    features = df.drop(columns=[target_col], errors="ignore")
    y_true_series = df[target_col].reset_index(drop=True)

    # Probabilities (if available)
    y_proba = None
    try:
        y_proba_raw = predictor.predict_proba(features)
        y_proba = _safe_y_proba_to_array(y_proba_raw)
    except Exception:
        y_proba = None

    # Labels (optionally thresholded for binary)
    y_pred_series = None
    if problem_type == "binary" and (threshold is not None) and (y_proba is not None):
        classes_sorted = np.sort(pd.unique(y_true_series))
        pos_label = classes_sorted[-1]
        neg_label = classes_sorted[0]
        if y_proba.ndim == 1:
            pos_scores = y_proba
        else:
            pos_col_idx = -1
            try:
                if hasattr(predictor, "class_labels") and predictor.class_labels:
                    pos_col_idx = list(predictor.class_labels).index(pos_label)
            except Exception:
                pos_col_idx = -1
            pos_scores = y_proba[:, pos_col_idx]
        y_pred_series = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
    else:
        # Fall back to the model's default label prediction (argmax / 0.5 equivalent)
        y_pred_series = pd.Series(predictor.predict(features)).reset_index(drop=True)

    if problem_type == "regression":
        y_true_arr = np.asarray(y_true_series, dtype=float)
        y_pred_arr = np.asarray(y_pred_series, dtype=float)
        return _compute_regression_metrics(y_true_arr, y_pred_arr)

    if problem_type == "binary":
        return _compute_binary_metrics(y_true_series, y_pred_series, y_proba, predictor)

    # multiclass
    return _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)

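# Illustrative call for compute_metrics_for_split (not part of the original
# module; `predictor` and `test_df` are assumed to exist). For a binary task a
# threshold of 0.3 labels a row positive whenever its positive-class
# probability is >= 0.3 instead of using the predictor's default cut-off.
#
#   compute_metrics_for_split(predictor, test_df, target_col="label",
#                             problem_type="binary", threshold=0.3)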

def evaluate_all_transparency(
    predictor,
    train_df: Optional[pd.DataFrame],
    val_df: Optional[pd.DataFrame],
    test_df: Optional[pd.DataFrame],
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
    """
    Evaluate Train/Val/Test with the transparent metrics suite.
    Returns:
      - metrics_table: DataFrame with index=Metric, columns subset of [Train, Validation, Test]
      - raw_dict: nested dict {split -> {metric -> value}}
    """
    split_results: Dict[str, Dict[str, float]] = {}
    splits = []

    # IMPORTANT: do NOT apply threshold to Train/Val
    if train_df is not None and len(train_df):
        split_results["Train"] = compute_metrics_for_split(predictor, train_df, target_col, problem_type, threshold=None)
        splits.append("Train")
    if val_df is not None and len(val_df):
        split_results["Validation"] = compute_metrics_for_split(predictor, val_df, target_col, problem_type, threshold=None)
        splits.append("Validation")
    if test_df is not None and len(test_df):
        split_results["Test"] = compute_metrics_for_split(predictor, test_df, target_col, problem_type, threshold=threshold)
        splits.append("Test")

    # Preserve order from the first split; include any extras from others
    order_source = split_results[splits[0]] if splits else {}
    all_metrics = list(order_source.keys())
    for s in splits[1:]:
        for m in split_results[s].keys():
            if m not in all_metrics:
                all_metrics.append(m)

    metrics_table = pd.DataFrame(index=all_metrics, columns=splits, dtype=float)
    for s in splits:
        for m, v in split_results[s].items():
            metrics_table.loc[m, s] = v

    return metrics_table, split_results
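

# Usage sketch (illustrative; assumes an AutoGluon-style predictor exposing
# predict / predict_proba and, optionally, class_labels -- none of this is part
# of the original module):
#
#   metrics_table, raw = evaluate_all_transparency(
#       predictor,
#       train_df=train_df,
#       val_df=val_df,
#       test_df=test_df,
#       target_col="label",
#       problem_type="binary",
#       threshold=0.35,   # applied to the Test split only
#   )
#   print(metrics_table.round(4))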