comparison test_pipeline.py @ 0:375c36923da1 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000
from __future__ import annotations

import logging
from typing import Dict, Optional

import pandas as pd
from plot_logic import infer_problem_type
from training_pipeline import evaluate_predictor_all_splits, fit_summary_safely

logger = logging.getLogger(__name__)


def run_autogluon_test_experiment(
    predictor,
    data_ctx: Dict[str, pd.DataFrame],
    target_column: str,
    eval_metric: Optional[str] = None,
    ag_config: Optional[dict] = None,
    problem_type: Optional[str] = None,
) -> Dict[str, object]:
    """
    Evaluate a trained predictor on train/val/test splits using a prepared
    ``data_ctx``.

    ``data_ctx`` is typically the context returned by ``run_autogluon_experiment``::

        {
            "train": df_train,
            "val": df_val,
            "test_internal": df_test_internal,
            "test_external": df_test_external,
            "threshold": threshold,
        }
    """
    if predictor is None:
        raise ValueError("predictor is required for evaluation.")
    if data_ctx is None:
        raise ValueError(
            "data_ctx is required; it is usually produced by run_autogluon_experiment."
        )

    df_train = data_ctx.get("train")
    df_val = data_ctx.get("val")
    df_test_internal = data_ctx.get("test_internal")
    df_test_external = data_ctx.get("test_external")

    # Resolve the decision threshold: a value stored in data_ctx takes
    # precedence over one supplied via ag_config.
    threshold = None
    if ag_config is not None:
        threshold = ag_config.get("threshold", threshold)
    threshold = data_ctx.get("threshold", threshold)

    if problem_type is None:
        # Prefer inferring from training data and predictor metadata;
        # fall back to the external test split when no training data is given.
        base_df = df_train if df_train is not None else df_test_external
        problem_type = infer_problem_type(predictor, base_df, target_column)

    # Evaluate on the external test set when available, otherwise the internal one.
    df_test_final = df_test_external if df_test_external is not None else df_test_internal
    raw_metrics, ag_by_split = evaluate_predictor_all_splits(
        predictor=predictor,
        df_train=df_train,
        df_val=df_val,
        df_test=df_test_final,
        label_col=target_column,
        problem_type=problem_type,
        eval_metric=eval_metric,
        threshold_test=threshold,
        df_test_external=df_test_external,
    )

    summary = fit_summary_safely(predictor)

    result = {
        "problem_type": problem_type,
        "raw_metrics": raw_metrics,
        "ag_eval": ag_by_split,
        "fit_summary": summary,
    }
    logger.info("Evaluation complete; splits: %s", list(raw_metrics.keys()))
    return result
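

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by the Galaxy tool): a
# minimal smoke test, assuming AutoGluon's tabular API is installed. The
# synthetic frame, the split sizes, and the "target" column name are
# placeholders invented for this sketch; plot_logic and training_pipeline
# are assumed to behave as documented above.
if __name__ == "__main__":
    import numpy as np
    from autogluon.tabular import TabularPredictor

    rng = np.random.default_rng(0)
    df = pd.DataFrame(rng.normal(size=(300, 4)), columns=list("abcd"))
    df["target"] = (df["a"] > 0).astype(int)
    df_train, df_val, df_test = df.iloc[:200], df.iloc[200:250], df.iloc[250:]

    demo_predictor = TabularPredictor(label="target", verbosity=0).fit(df_train)
    demo_ctx = {
        "train": df_train,
        "val": df_val,
        "test_internal": df_test,
        "test_external": None,
        "threshold": None,
    }
    out = run_autogluon_test_experiment(demo_predictor, demo_ctx, target_column="target")
    print(out["problem_type"], sorted(out["raw_metrics"]))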