diff metrics_logic.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/metrics_logic.py Tue Dec 09 23:49:47 2025 +0000
@@ -0,0 +1,313 @@

```python
from collections import OrderedDict
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
    log_loss,
    matthews_corrcoef,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    precision_score,
    r2_score,
    recall_score,
    roc_auc_score,
)


# -------------------- Transparent Metrics (task-aware) -------------------- #

def _safe_y_proba_to_array(y_proba) -> Optional[np.ndarray]:
    """Convert predictor.predict_proba output (array/DataFrame/dict) to np.ndarray or None."""
    if y_proba is None:
        return None
    if isinstance(y_proba, pd.DataFrame):
        return y_proba.values
    if isinstance(y_proba, (list, tuple)):
        return np.asarray(y_proba)
    if isinstance(y_proba, np.ndarray):
        return y_proba
    if isinstance(y_proba, dict):
        try:
            return np.vstack([np.asarray(v) for _, v in sorted(y_proba.items())]).T
        except Exception:
            return None
    return None


def _specificity_from_cm(cm: np.ndarray) -> float:
    """Specificity (TNR) for binary confusion matrix."""
    if cm.shape != (2, 2):
        return np.nan
    tn, fp, fn, tp = cm.ravel()
    denom = (tn + fp)
    return float(tn / denom) if denom > 0 else np.nan


def _compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> "OrderedDict[str, float]":
    mse = mean_squared_error(y_true, y_pred)
    rmse = float(np.sqrt(mse))
    mae = mean_absolute_error(y_true, y_pred)
    # Avoid division by zero using clip
    mape = float(np.mean(np.abs((y_true - y_pred) / np.clip(np.abs(y_true), 1e-12, None))) * 100.0)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    metrics = OrderedDict()
    metrics["MSE"] = mse
    metrics["RMSE"] = rmse
    metrics["MAE"] = mae
    metrics["MAPE_%"] = mape
    metrics["R2"] = r2
    metrics["MedianAE"] = medae
    return metrics


def _compute_binary_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray],
    predictor
) -> "OrderedDict[str, float]":
    metrics = OrderedDict()
    classes_sorted = np.sort(pd.unique(y_true))
    # Choose the lexicographically larger class as "positive"
    pos_label = classes_sorted[-1]

    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Precision"] = precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["Recall_(Sensitivity/TPR)"] = recall_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["F1-Score"] = f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0)

    try:
        cm = confusion_matrix(y_true, y_pred, labels=classes_sorted)
        metrics["Specificity_(TNR)"] = _specificity_from_cm(cm)
    except Exception:
        metrics["Specificity_(TNR)"] = np.nan

    # Probabilistic metrics
    if y_proba is not None:
        # pick column of positive class
        if y_proba.ndim == 1:
            pos_scores = y_proba
        else:
            pos_col_idx = -1
            try:
                if hasattr(predictor, "class_labels") and predictor.class_labels:
                    pos_col_idx = list(predictor.class_labels).index(pos_label)
            except Exception:
                pos_col_idx = -1
            pos_scores = y_proba[:, pos_col_idx]
        try:
            metrics["ROC-AUC"] = roc_auc_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["ROC-AUC"] = np.nan
        try:
            metrics["PR-AUC"] = average_precision_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["PR-AUC"] = np.nan
        try:
            if y_proba.ndim == 1:
                y_proba_ll = np.column_stack([1 - pos_scores, pos_scores])
            else:
                y_proba_ll = y_proba
            metrics["LogLoss"] = log_loss(y_true, y_proba_ll, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
    else:
        metrics["ROC-AUC"] = np.nan
        metrics["PR-AUC"] = np.nan
        metrics["LogLoss"] = np.nan

    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    return metrics


def _compute_multiclass_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray]
) -> "OrderedDict[str, float]":
    metrics = OrderedDict()
    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Macro Precision"] = precision_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro Recall"] = recall_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro F1"] = f1_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Weighted Precision"] = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted Recall"] = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted F1"] = f1_score(y_true, y_pred, average="weighted", zero_division=0)

    try:
        metrics["Cohen_Kappa"] = cohen_kappa_score(y_true, y_pred)
    except Exception:
        metrics["Cohen_Kappa"] = np.nan
    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    # Probabilistic metrics
    classes_sorted = np.sort(pd.unique(y_true))
    if y_proba is not None and y_proba.ndim == 2:
        try:
            metrics["LogLoss"] = log_loss(y_true, y_proba, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
        # Macro ROC-AUC / PR-AUC via OVR
        try:
            class_to_index = {c: i for i, c in enumerate(classes_sorted)}
            y_true_idx = np.vectorize(class_to_index.get)(y_true)
            metrics["ROC-AUC_macro"] = roc_auc_score(y_true_idx, y_proba, multi_class="ovr", average="macro")
        except Exception:
            metrics["ROC-AUC_macro"] = np.nan
        try:
            Y_true_ind = np.zeros_like(y_proba)
            idx_map = {c: i for i, c in enumerate(classes_sorted)}
            Y_true_ind[np.arange(y_proba.shape[0]), np.vectorize(idx_map.get)(y_true)] = 1
            metrics["PR-AUC_macro"] = average_precision_score(Y_true_ind, y_proba, average="macro")
        except Exception:
            metrics["PR-AUC_macro"] = np.nan
    else:
        metrics["LogLoss"] = np.nan
        metrics["ROC-AUC_macro"] = np.nan
        metrics["PR-AUC_macro"] = np.nan

    return metrics


def aggregate_metrics(list_of_dicts):
    """Aggregate a list of metrics dicts (per split) into mean/std."""
    agg_mean = {}
    agg_std = {}
    for split in ("Train", "Validation", "Test", "Test (external)"):
        keys = set()
        for m in list_of_dicts:
            if isinstance(m, dict) and split in m:
                keys.update(m[split].keys())
        if not keys:
            continue
        agg_mean[split] = {}
        agg_std[split] = {}
        for k in keys:
            vals = [m[split][k] for m in list_of_dicts if split in m and k in m[split]]
            numeric_vals = []
            for v in vals:
                try:
                    numeric_vals.append(float(v))
                except Exception:
                    pass
            if numeric_vals:
                agg_mean[split][k] = float(np.mean(numeric_vals))
                agg_std[split][k] = float(np.std(numeric_vals, ddof=0))
            else:
                agg_mean[split][k] = vals[-1] if vals else None
                agg_std[split][k] = None
    return agg_mean, agg_std


def compute_metrics_for_split(
    predictor,
    df: pd.DataFrame,
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,  # <— NEW
) -> "OrderedDict[str, float]":
    """Compute transparency metrics for one split (Train/Val/Test) based on task type."""
    # Prepare inputs
    features = df.drop(columns=[target_col], errors="ignore")
    y_true_series = df[target_col].reset_index(drop=True)

    # Probabilities (if available)
    y_proba = None
    try:
        y_proba_raw = predictor.predict_proba(features)
        y_proba = _safe_y_proba_to_array(y_proba_raw)
    except Exception:
        y_proba = None

    # Labels (optionally thresholded for binary)
    y_pred_series = None
    if problem_type == "binary" and (threshold is not None) and (y_proba is not None):
        classes_sorted = np.sort(pd.unique(y_true_series))
        pos_label = classes_sorted[-1]
        neg_label = classes_sorted[0]
        if y_proba.ndim == 1:
            pos_scores = y_proba
        else:
            pos_col_idx = -1
            try:
                if hasattr(predictor, "class_labels") and predictor.class_labels:
                    pos_col_idx = list(predictor.class_labels).index(pos_label)
            except Exception:
                pos_col_idx = -1
            pos_scores = y_proba[:, pos_col_idx]
        y_pred_series = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
    else:
        # Fall back to model's default label prediction (argmax / 0.5 equivalent)
        y_pred_series = pd.Series(predictor.predict(features)).reset_index(drop=True)

    if problem_type == "regression":
        y_true_arr = np.asarray(y_true_series, dtype=float)
        y_pred_arr = np.asarray(y_pred_series, dtype=float)
        return _compute_regression_metrics(y_true_arr, y_pred_arr)

    if problem_type == "binary":
        return _compute_binary_metrics(y_true_series, y_pred_series, y_proba, predictor)

    # multiclass
    return _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)


def evaluate_all_transparency(
    predictor,
    train_df: Optional[pd.DataFrame],
    val_df: Optional[pd.DataFrame],
    test_df: Optional[pd.DataFrame],
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
    """
    Evaluate Train/Val/Test with the transparent metrics suite.
    Returns:
      - metrics_table: DataFrame with index=Metric, columns subset of [Train, Validation, Test]
      - raw_dict: nested dict {split -> {metric -> value}}
    """
    split_results: Dict[str, Dict[str, float]] = {}
    splits = []

    # IMPORTANT: do NOT apply threshold to Train/Val
    if train_df is not None and len(train_df):
        split_results["Train"] = compute_metrics_for_split(predictor, train_df, target_col, problem_type, threshold=None)
        splits.append("Train")
    if val_df is not None and len(val_df):
        split_results["Validation"] = compute_metrics_for_split(predictor, val_df, target_col, problem_type, threshold=None)
        splits.append("Validation")
    if test_df is not None and len(test_df):
        split_results["Test"] = compute_metrics_for_split(predictor, test_df, target_col, problem_type, threshold=threshold)
        splits.append("Test")

    # Preserve order from the first split; include any extras from others
    order_source = split_results[splits[0]] if splits else {}
    all_metrics = list(order_source.keys())
    for s in splits[1:]:
        for m in split_results[s].keys():
            if m not in all_metrics:
                all_metrics.append(m)

    metrics_table = pd.DataFrame(index=all_metrics, columns=splits, dtype=float)
    for s in splits:
        for m, v in split_results[s].items():
            metrics_table.loc[m, s] = v

    return metrics_table, split_results
```
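
For orientation, below is a small, self-contained usage sketch; it is not part of the changeset. The `ToyPredictor` class, the feature/label columns, and the threshold value are illustrative assumptions only; the module itself just expects an object exposing `predict`, `predict_proba`, and (optionally) `class_labels`.

```python
# Hypothetical usage sketch for metrics_logic.py (not part of the commit).
import numpy as np
import pandas as pd

from metrics_logic import (
    aggregate_metrics,
    compute_metrics_for_split,
    evaluate_all_transparency,
)


class ToyPredictor:
    """Minimal stand-in exposing the attributes metrics_logic relies on."""

    class_labels = ["neg", "pos"]

    def predict_proba(self, features: pd.DataFrame) -> pd.DataFrame:
        # Score rises with the single feature "x"; column order matches class_labels.
        pos = 1.0 / (1.0 + np.exp(-features["x"].to_numpy()))
        return pd.DataFrame({"neg": 1.0 - pos, "pos": pos})

    def predict(self, features: pd.DataFrame) -> pd.Series:
        proba = self.predict_proba(features)
        return pd.Series(np.where(proba["pos"] >= 0.5, "pos", "neg"))


df = pd.DataFrame(
    {
        "x": [-2.0, -1.0, 0.5, 1.5, 2.5, -0.5],
        "label": ["neg", "neg", "pos", "pos", "pos", "neg"],
    }
)
predictor = ToyPredictor()

# Score a single split with the default decision rule vs. an explicit threshold.
default_metrics = compute_metrics_for_split(predictor, df, "label", "binary")
tuned_metrics = compute_metrics_for_split(predictor, df, "label", "binary", threshold=0.35)

# Full Train/Validation/Test table; the threshold is applied to the Test split only.
table, raw = evaluate_all_transparency(predictor, df, df, df, "label", "binary", threshold=0.35)
print(table.round(3))

# Mean/std across repeated runs (here the same run twice, just to show the shape).
mean_by_split, std_by_split = aggregate_metrics([raw, raw])
```

Note that `evaluate_all_transparency` passes `threshold=None` for Train and Validation, so those columns always reflect the predictor's default decision rule; only the Test column is affected by a custom threshold.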
