diff training_pipeline.py @ 8:a48e750cfd25 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
author goeckslab
date Fri, 30 Jan 2026 14:20:49 +0000
parents 871957823d0c
--- a/training_pipeline.py	Wed Jan 28 19:56:37 2026 +0000
+++ b/training_pipeline.py	Fri Jan 30 14:20:49 2026 +0000
@@ -316,7 +316,7 @@
     eval_metric: Optional[str],
     threshold_test: Optional[float],
     df_test_external: Optional[pd.DataFrame] = None,
-) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
+) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]], Dict[str, dict]]:
     """
-    Returns (raw_metrics, ag_scores_by_split)
+    Returns (raw_metrics, ag_scores_by_split, roc_curves)
       - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic)
@@ -335,7 +335,7 @@
         ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)
 
     # Transparent suite (threshold on Test handled inside metrics_logic)
-    _, raw_metrics = evaluate_all_transparency(
+    _, raw_metrics, roc_curves = evaluate_all_transparency(
         predictor=predictor,
         train_df=df_train,
         val_df=df_val,
@@ -346,12 +346,20 @@
     )
 
     if df_test_external is not None and df_test_external is not df_test and len(df_test_external):
-        raw_metrics["Test (external)"] = compute_metrics_for_split(
-            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
+        ext_metrics, ext_curve = compute_metrics_for_split(
+            predictor,
+            df_test_external,
+            label_col,
+            problem_type,
+            threshold=threshold_test,
+            return_curve=True,
         )
+        raw_metrics["Test (external)"] = ext_metrics
+        if ext_curve:
+            roc_curves["Test (external)"] = ext_curve
         ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)
 
-    return raw_metrics, ag_by_split
+    return raw_metrics, ag_by_split, roc_curves
 
 
 def fit_summary_safely(predictor) -> Optional[dict]:
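
Net effect of this changeset: the evaluation helper now returns a third value, a dict of ROC curve data keyed by split name ("Test (external)" included when an external test frame is supplied), alongside the transparent metric suite and the AutoGluon scores. Below is a minimal caller-side sketch of one way that third value might be consumed. The helper's real name is cut off above the first hunk, and the "fpr"/"tpr" keys inside each curve dict are assumptions for illustration; the diff only shows that one curve dict is stored per split.

    from typing import Dict

    import matplotlib.pyplot as plt


    def plot_roc_curves(roc_curves: Dict[str, dict], out_path: str = "roc.png") -> None:
        """Overlay one ROC trace per evaluation split on a single axis."""
        fig, ax = plt.subplots()
        for split_name, curve in roc_curves.items():
            # Assumed payload shape: parallel arrays of false/true positive rates.
            ax.plot(curve["fpr"], curve["tpr"], label=split_name)
        ax.plot([0, 1], [0, 1], linestyle="--", label="chance")
        ax.set_xlabel("False positive rate")
        ax.set_ylabel("True positive rate")
        ax.legend()
        fig.savefig(out_path)


    # raw_metrics, ag_by_split, roc_curves = ...  # hypothetical call to the
    # evaluation helper above, which returns a 3-tuple after this change
    # plot_roc_curves(roc_curves)

Returning the curves in a split-keyed dict keeps plotting decoupled from metric computation: the `if ext_curve:` guard in the third hunk means a split is simply absent from the dict when no curve could be computed, so consumers can iterate over whatever is present without special-casing.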