Mercurial repository goeckslab/multimodal_learner: test_pipeline.py @ 0:375c36923da1 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
```python
from __future__ import annotations

import logging
from typing import Dict, Optional

import pandas as pd
from plot_logic import infer_problem_type
from training_pipeline import evaluate_predictor_all_splits, fit_summary_safely

logger = logging.getLogger(__name__)


def run_autogluon_test_experiment(
    predictor,
    data_ctx: Dict[str, pd.DataFrame],
    target_column: str,
    eval_metric: Optional[str] = None,
    ag_config: Optional[dict] = None,
    problem_type: Optional[str] = None,
) -> Dict[str, object]:
    """
    Evaluate a trained predictor on train/val/test splits using prepared data_ctx.

    data_ctx is typically the context returned by ``run_autogluon_experiment``:
        {
            "train": df_train,
            "val": df_val,
            "test_internal": df_test_internal,
            "test_external": df_test_external,
            "threshold": threshold,
        }
    """
    if predictor is None:
        raise ValueError("predictor is required for evaluation.")
    if data_ctx is None:
        raise ValueError("data_ctx is required; usually from run_autogluon_experiment.")

    df_train = data_ctx.get("train")
    df_val = data_ctx.get("val")
    df_test_internal = data_ctx.get("test_internal")
    df_test_external = data_ctx.get("test_external")

    # Resolve the decision threshold: a value stored in data_ctx takes
    # precedence over one supplied through ag_config.
    threshold = None
    if ag_config is not None:
        threshold = ag_config.get("threshold", threshold)
    threshold = data_ctx.get("threshold", threshold)

    if problem_type is None:
        # Prefer inferring from training data and predictor metadata
        base_df = df_train if df_train is not None else df_test_external
        problem_type = infer_problem_type(predictor, base_df, target_column)

    # Prefer the external test set when available, otherwise fall back to the
    # internal test split.
    df_test_final = df_test_external if df_test_external is not None else df_test_internal
    raw_metrics, ag_by_split = evaluate_predictor_all_splits(
        predictor=predictor,
        df_train=df_train,
        df_val=df_val,
        df_test=df_test_final,
        label_col=target_column,
        problem_type=problem_type,
        eval_metric=eval_metric,
        threshold_test=threshold,
        df_test_external=df_test_external,
    )

    summary = fit_summary_safely(predictor)

    result = {
        "problem_type": problem_type,
        "raw_metrics": raw_metrics,
        "ag_eval": ag_by_split,
        "fit_summary": summary,
    }
    logger.info("Evaluation complete; splits: %s", list(raw_metrics.keys()))
    return result
```
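Below is a minimal usage sketch, not part of the repository: it assumes `predictor` is an already-trained AutoGluon predictor and that `df_train`, `df_val`, and `df_test` are pandas DataFrames produced by the same preprocessing step that normally builds the context via ``run_autogluon_experiment``; the ``label`` column name and the ``roc_auc`` metric are illustrative placeholders.

```python
# Hypothetical usage sketch. Assumes `predictor`, `df_train`, `df_val`, and
# `df_test` already exist from an earlier training/preprocessing step; the
# names and values below are illustrative, not defined by this module.
from test_pipeline import run_autogluon_test_experiment

# data_ctx mirrors the shape documented in the function's docstring.
data_ctx = {
    "train": df_train,
    "val": df_val,
    "test_internal": df_test,
    "test_external": None,   # no external hold-out set in this example
    "threshold": 0.5,        # forwarded to evaluation as threshold_test
}

result = run_autogluon_test_experiment(
    predictor=predictor,          # trained predictor from the training step
    data_ctx=data_ctx,
    target_column="label",
    eval_metric="roc_auc",        # optional; forwarded to the evaluation helpers
)

print(result["problem_type"])
print(list(result["raw_metrics"].keys()))  # which splits were evaluated
```

The returned dictionary bundles the inferred problem type, the per-split raw metrics, AutoGluon's own evaluation output, and the fit summary, matching the keys assembled at the end of the function.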
