multimodal_learner / metrics_logic.py @ 0:375c36923da1 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
from collections import OrderedDict
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
    log_loss,
    matthews_corrcoef,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    precision_score,
    r2_score,
    recall_score,
    roc_auc_score,
)


# -------------------- Transparent Metrics (task-aware) -------------------- #

def _safe_y_proba_to_array(y_proba) -> Optional[np.ndarray]:
    """Convert predictor.predict_proba output (array/DataFrame/dict) to np.ndarray or None."""
    if y_proba is None:
        return None
    if isinstance(y_proba, pd.DataFrame):
        return y_proba.values
    if isinstance(y_proba, (list, tuple)):
        return np.asarray(y_proba)
    if isinstance(y_proba, np.ndarray):
        return y_proba
    if isinstance(y_proba, dict):
        try:
            return np.vstack([np.asarray(v) for _, v in sorted(y_proba.items())]).T
        except Exception:
            return None
    return None

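# Illustrative sketch (not part of the original module): the dict branch above sorts
# probabilities by class label and stacks them into an (n_samples, n_classes) array.
# The labels "no"/"yes" and the values below are made up.
def _example_safe_y_proba_to_array():
    proba_dict = {"yes": [0.9, 0.2], "no": [0.1, 0.8]}
    arr = _safe_y_proba_to_array(proba_dict)
    # Columns follow sorted dict keys: column 0 -> "no", column 1 -> "yes"
    assert arr.shape == (2, 2)
    return arr
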
def _specificity_from_cm(cm: np.ndarray) -> float:
    """Specificity (TNR) for a binary confusion matrix."""
    if cm.shape != (2, 2):
        return np.nan
    tn, fp, fn, tp = cm.ravel()
    denom = (tn + fp)
    return float(tn / denom) if denom > 0 else np.nan

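# Illustrative sketch with made-up counts: specificity = TN / (TN + FP), read off a
# 2x2 confusion matrix laid out as [[TN, FP], [FN, TP]] (sklearn's convention when
# the negative label sorts first).
def _example_specificity():
    cm = np.array([[90, 10],   # TN, FP
                   [5, 95]])   # FN, TP
    return _specificity_from_cm(cm)  # 90 / (90 + 10) = 0.9
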
def _compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> "OrderedDict[str, float]":
    mse = mean_squared_error(y_true, y_pred)
    rmse = float(np.sqrt(mse))
    mae = mean_absolute_error(y_true, y_pred)
    # Avoid division by zero using clip
    mape = float(np.mean(np.abs((y_true - y_pred) / np.clip(np.abs(y_true), 1e-12, None))) * 100.0)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    metrics = OrderedDict()
    metrics["MSE"] = mse
    metrics["RMSE"] = rmse
    metrics["MAE"] = mae
    metrics["MAPE_%"] = mape
    metrics["R2"] = r2
    metrics["MedianAE"] = medae
    return metrics

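# Illustrative sketch with made-up values: RMSE is derived from sklearn's MSE, and
# MAPE clips |y_true| at 1e-12 so targets equal to zero do not divide by zero.
def _example_regression_metrics():
    y_true = np.array([1.0, 2.0, 4.0])
    y_pred = np.array([1.5, 2.0, 3.0])
    # MAE = (0.5 + 0.0 + 1.0) / 3 = 0.5; MAPE = (50% + 0% + 25%) / 3 = 25%
    return _compute_regression_metrics(y_true, y_pred)
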
def _compute_binary_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray],
    predictor
) -> "OrderedDict[str, float]":
    metrics = OrderedDict()
    classes_sorted = np.sort(pd.unique(y_true))
    # Choose the lexicographically larger class as "positive"
    pos_label = classes_sorted[-1]

    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Precision"] = precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["Recall_(Sensitivity/TPR)"] = recall_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["F1-Score"] = f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0)

    try:
        cm = confusion_matrix(y_true, y_pred, labels=classes_sorted)
        metrics["Specificity_(TNR)"] = _specificity_from_cm(cm)
    except Exception:
        metrics["Specificity_(TNR)"] = np.nan

    # Probabilistic metrics
    if y_proba is not None:
        # pick column of positive class
        if y_proba.ndim == 1:
            pos_scores = y_proba
        else:
            pos_col_idx = -1
            try:
                if hasattr(predictor, "class_labels") and predictor.class_labels:
                    pos_col_idx = list(predictor.class_labels).index(pos_label)
            except Exception:
                pos_col_idx = -1
            pos_scores = y_proba[:, pos_col_idx]
        try:
            metrics["ROC-AUC"] = roc_auc_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["ROC-AUC"] = np.nan
        try:
            metrics["PR-AUC"] = average_precision_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["PR-AUC"] = np.nan
        try:
            if y_proba.ndim == 1:
                y_proba_ll = np.column_stack([1 - pos_scores, pos_scores])
            else:
                y_proba_ll = y_proba
            metrics["LogLoss"] = log_loss(y_true, y_proba_ll, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
    else:
        metrics["ROC-AUC"] = np.nan
        metrics["PR-AUC"] = np.nan
        metrics["LogLoss"] = np.nan

    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    return metrics

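# Illustrative sketch: the lexicographically larger label ("yes" here) is treated as
# the positive class. The data and the probability column order are made up; without
# a usable `predictor.class_labels`, the code falls back to the last column as the
# positive-class score, so column 1 below plays that role.
def _example_binary_metrics():
    y_true = pd.Series(["no", "yes", "yes", "no"])
    y_pred = pd.Series(["no", "yes", "no", "no"])
    proba = np.array([[0.8, 0.2],
                      [0.3, 0.7],
                      [0.6, 0.4],
                      [0.9, 0.1]])
    return _compute_binary_metrics(y_true, y_pred, proba, predictor=None)
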
def _compute_multiclass_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray]
) -> "OrderedDict[str, float]":
    metrics = OrderedDict()
    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Macro Precision"] = precision_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro Recall"] = recall_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro F1"] = f1_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Weighted Precision"] = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted Recall"] = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted F1"] = f1_score(y_true, y_pred, average="weighted", zero_division=0)

    try:
        metrics["Cohen_Kappa"] = cohen_kappa_score(y_true, y_pred)
    except Exception:
        metrics["Cohen_Kappa"] = np.nan
    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    # Probabilistic metrics
    classes_sorted = np.sort(pd.unique(y_true))
    if y_proba is not None and y_proba.ndim == 2:
        try:
            metrics["LogLoss"] = log_loss(y_true, y_proba, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
        # Macro ROC-AUC / PR-AUC via OVR
        try:
            class_to_index = {c: i for i, c in enumerate(classes_sorted)}
            y_true_idx = np.vectorize(class_to_index.get)(y_true)
            metrics["ROC-AUC_macro"] = roc_auc_score(y_true_idx, y_proba, multi_class="ovr", average="macro")
        except Exception:
            metrics["ROC-AUC_macro"] = np.nan
        try:
            Y_true_ind = np.zeros_like(y_proba)
            idx_map = {c: i for i, c in enumerate(classes_sorted)}
            Y_true_ind[np.arange(y_proba.shape[0]), np.vectorize(idx_map.get)(y_true)] = 1
            metrics["PR-AUC_macro"] = average_precision_score(Y_true_ind, y_proba, average="macro")
        except Exception:
            metrics["PR-AUC_macro"] = np.nan
    else:
        metrics["LogLoss"] = np.nan
        metrics["ROC-AUC_macro"] = np.nan
        metrics["PR-AUC_macro"] = np.nan

    return metrics

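# Illustrative sketch with made-up data: the probability columns are assumed to follow
# the sorted class order ("a", "b", "c"), which is what the OVR ROC-AUC and PR-AUC
# computations above rely on.
def _example_multiclass_metrics():
    y_true = pd.Series(["a", "b", "c", "a"])
    y_pred = pd.Series(["a", "b", "b", "a"])
    proba = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.2, 0.5, 0.3],
                      [0.6, 0.3, 0.1]])
    return _compute_multiclass_metrics(y_true, y_pred, proba)
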
def aggregate_metrics(list_of_dicts):
    """Aggregate a list of metrics dicts (per split) into mean/std."""
    agg_mean = {}
    agg_std = {}
    for split in ("Train", "Validation", "Test", "Test (external)"):
        keys = set()
        for m in list_of_dicts:
            if isinstance(m, dict) and split in m:
                keys.update(m[split].keys())
        if not keys:
            continue
        agg_mean[split] = {}
        agg_std[split] = {}
        for k in keys:
            vals = [m[split][k] for m in list_of_dicts if split in m and k in m[split]]
            numeric_vals = []
            for v in vals:
                try:
                    numeric_vals.append(float(v))
                except Exception:
                    pass
            if numeric_vals:
                agg_mean[split][k] = float(np.mean(numeric_vals))
                agg_std[split][k] = float(np.std(numeric_vals, ddof=0))
            else:
                agg_mean[split][k] = vals[-1] if vals else None
                agg_std[split][k] = None
    return agg_mean, agg_std

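# Illustrative sketch: aggregating the same metric across two hypothetical folds/runs.
def _example_aggregate_metrics():
    fold_metrics = [
        {"Test": {"Accuracy": 0.80, "F1-Score": 0.75}},
        {"Test": {"Accuracy": 0.90, "F1-Score": 0.85}},
    ]
    means, stds = aggregate_metrics(fold_metrics)
    # means["Test"]["Accuracy"] is about 0.85; stds["Test"]["Accuracy"] is about 0.05
    # (population std, ddof=0).
    return means, stds
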
def compute_metrics_for_split(
    predictor,
    df: pd.DataFrame,
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,  # optional decision threshold for binary tasks
) -> "OrderedDict[str, float]":
    """Compute transparency metrics for one split (Train/Val/Test) based on task type."""
    # Prepare inputs
    features = df.drop(columns=[target_col], errors="ignore")
    y_true_series = df[target_col].reset_index(drop=True)

    # Probabilities (if available)
    y_proba = None
    try:
        y_proba_raw = predictor.predict_proba(features)
        y_proba = _safe_y_proba_to_array(y_proba_raw)
    except Exception:
        y_proba = None

    # Labels (optionally thresholded for binary)
    y_pred_series = None
    if problem_type == "binary" and (threshold is not None) and (y_proba is not None):
        classes_sorted = np.sort(pd.unique(y_true_series))
        pos_label = classes_sorted[-1]
        neg_label = classes_sorted[0]
        if y_proba.ndim == 1:
            pos_scores = y_proba
        else:
            pos_col_idx = -1
            try:
                if hasattr(predictor, "class_labels") and predictor.class_labels:
                    pos_col_idx = list(predictor.class_labels).index(pos_label)
            except Exception:
                pos_col_idx = -1
            pos_scores = y_proba[:, pos_col_idx]
        y_pred_series = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
    else:
        # Fall back to the model's default label prediction (argmax / 0.5 equivalent)
        y_pred_series = pd.Series(predictor.predict(features)).reset_index(drop=True)

    if problem_type == "regression":
        y_true_arr = np.asarray(y_true_series, dtype=float)
        y_pred_arr = np.asarray(y_pred_series, dtype=float)
        return _compute_regression_metrics(y_true_arr, y_pred_arr)

    if problem_type == "binary":
        return _compute_binary_metrics(y_true_series, y_pred_series, y_proba, predictor)

    # multiclass
    return _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)

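# Illustrative sketch (hypothetical, not part of the original module): a minimal
# duck-typed stand-in for an AutoGluon-style predictor, just enough to exercise
# compute_metrics_for_split without a trained model.
class _ConstantBinaryPredictor:
    class_labels = ["no", "yes"]

    def predict(self, features: pd.DataFrame) -> pd.Series:
        # A real predictor would use `features`; this stub always predicts "yes".
        return pd.Series(["yes"] * len(features))

    def predict_proba(self, features: pd.DataFrame) -> pd.DataFrame:
        n = len(features)
        return pd.DataFrame({"no": [0.3] * n, "yes": [0.7] * n})


def _example_compute_metrics_for_split():
    df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0],
                       "label": ["no", "yes", "yes", "no"]})
    # With threshold=0.8, the constant positive score of 0.7 falls below the cut-off,
    # so every row is labelled "no"; with threshold=None the stub's predict() is used.
    return compute_metrics_for_split(_ConstantBinaryPredictor(), df, target_col="label",
                                     problem_type="binary", threshold=0.8)
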
def evaluate_all_transparency(
    predictor,
    train_df: Optional[pd.DataFrame],
    val_df: Optional[pd.DataFrame],
    test_df: Optional[pd.DataFrame],
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
    """
    Evaluate Train/Val/Test with the transparent metrics suite.

    Returns:
        - metrics_table: DataFrame with index=Metric, columns a subset of [Train, Validation, Test]
        - raw_dict: nested dict {split -> {metric -> value}}
    """
    split_results: Dict[str, Dict[str, float]] = {}
    splits = []

    # IMPORTANT: do NOT apply the threshold to Train/Val
    if train_df is not None and len(train_df):
        split_results["Train"] = compute_metrics_for_split(predictor, train_df, target_col, problem_type, threshold=None)
        splits.append("Train")
    if val_df is not None and len(val_df):
        split_results["Validation"] = compute_metrics_for_split(predictor, val_df, target_col, problem_type, threshold=None)
        splits.append("Validation")
    if test_df is not None and len(test_df):
        split_results["Test"] = compute_metrics_for_split(predictor, test_df, target_col, problem_type, threshold=threshold)
        splits.append("Test")

    # Preserve order from the first split; include any extras from others
    order_source = split_results[splits[0]] if splits else {}
    all_metrics = list(order_source.keys())
    for s in splits[1:]:
        for m in split_results[s].keys():
            if m not in all_metrics:
                all_metrics.append(m)

    metrics_table = pd.DataFrame(index=all_metrics, columns=splits, dtype=float)
    for s in splits:
        for m, v in split_results[s].items():
            metrics_table.loc[m, s] = v

    return metrics_table, split_results
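

# Illustrative end-to-end sketch: evaluating the hypothetical stub predictor defined
# above on tiny made-up splits. Per the logic in evaluate_all_transparency, the
# threshold is applied to the Test column only.
def _example_evaluate_all_transparency():
    df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0],
                       "label": ["no", "yes", "yes", "no"]})
    table, raw = evaluate_all_transparency(
        predictor=_ConstantBinaryPredictor(),
        train_df=df, val_df=df, test_df=df,
        target_col="label", problem_type="binary", threshold=0.5,
    )
    # `table` has metrics as rows and the Train/Validation/Test splits as columns.
    return table, raw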
