diff test_pipeline.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
```diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test_pipeline.py	Tue Dec 09 23:49:47 2025 +0000
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import logging
+from typing import Dict, Optional
+
+import pandas as pd
+from plot_logic import infer_problem_type
+from training_pipeline import evaluate_predictor_all_splits, fit_summary_safely
+
+logger = logging.getLogger(__name__)
+
+
+def run_autogluon_test_experiment(
+    predictor,
+    data_ctx: Dict[str, pd.DataFrame],
+    target_column: str,
+    eval_metric: Optional[str] = None,
+    ag_config: Optional[dict] = None,
+    problem_type: Optional[str] = None,
+) -> Dict[str, object]:
+    """
+    Evaluate a trained predictor on train/val/test splits using prepared data_ctx.
+
+    data_ctx is typically the context returned by ``run_autogluon_experiment``:
+        {
+            "train": df_train,
+            "val": df_val,
+            "test_internal": df_test_internal,
+            "test_external": df_test_external,
+            "threshold": threshold,
+        }
+    """
+    if predictor is None:
+        raise ValueError("predictor is required for evaluation.")
+    if data_ctx is None:
+        raise ValueError("data_ctx is required; usually from run_autogluon_experiment.")
+
+    df_train = data_ctx.get("train")
+    df_val = data_ctx.get("val")
+    df_test_internal = data_ctx.get("test_internal")
+    df_test_external = data_ctx.get("test_external")
+    threshold = None
+    if ag_config is not None:
+        threshold = ag_config.get("threshold", threshold)
+    threshold = data_ctx.get("threshold", threshold)
+
+    if problem_type is None:
+        # Prefer inferring from training data and predictor metadata
+        base_df = df_train if df_train is not None else df_test_external
+        problem_type = infer_problem_type(predictor, base_df, target_column)
+
+    df_test_final = df_test_external if df_test_external is not None else df_test_internal
+    raw_metrics, ag_by_split = evaluate_predictor_all_splits(
+        predictor=predictor,
+        df_train=df_train,
+        df_val=df_val,
+        df_test=df_test_final,
+        label_col=target_column,
+        problem_type=problem_type,
+        eval_metric=eval_metric,
+        threshold_test=threshold,
+        df_test_external=df_test_external,
+    )
+
+    summary = fit_summary_safely(predictor)
+
+    result = {
+        "problem_type": problem_type,
+        "raw_metrics": raw_metrics,
+        "ag_eval": ag_by_split,
+        "fit_summary": summary,
+    }
+    logger.info("Evaluation complete; splits: %s", list(raw_metrics.keys()))
+    return result
```
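The sketch below shows how `run_autogluon_test_experiment` might be called once a predictor has been trained. It is a minimal illustration, not part of this changeset: the CSV paths, the target column name, the `roc_auc` metric, and the use of AutoGluon's `TabularPredictor` as a stand-in training step are all assumptions; only `run_autogluon_test_experiment`, the `data_ctx` keys, and the keys of the returned dict come from `test_pipeline.py` above.

```python
# Hypothetical usage sketch (paths, column names, and the TabularPredictor
# stand-in are assumptions; the actual tool builds data_ctx and the predictor
# via run_autogluon_experiment in training_pipeline.py).
import pandas as pd
from autogluon.tabular import TabularPredictor  # assumed stand-in predictor

from test_pipeline import run_autogluon_test_experiment

# Assumed prepared splits, one DataFrame per split.
df_train = pd.read_csv("train.csv")
df_val = pd.read_csv("val.csv")
df_test = pd.read_csv("test_external.csv")

# Assumed training step; in the tool this is done by run_autogluon_experiment.
predictor = TabularPredictor(label="label").fit(df_train)

data_ctx = {
    "train": df_train,
    "val": df_val,
    "test_internal": None,      # optional; the external split takes precedence if present
    "test_external": df_test,
    "threshold": 0.5,           # overrides any threshold carried in ag_config
}

result = run_autogluon_test_experiment(
    predictor=predictor,
    data_ctx=data_ctx,
    target_column="label",      # assumed target column name
    eval_metric="roc_auc",      # assumed metric
)

print(result["problem_type"])
print(result["raw_metrics"])    # per-split metrics from evaluate_predictor_all_splits
```

Note the precedence rules visible in the function body: when both test splits are supplied, `test_external` is used as the final test set, and a threshold found in `data_ctx` overrides one passed via `ag_config`.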
