Mercurial > repos > goeckslab > multimodal_learner
comparison training_pipeline.py @ 8:a48e750cfd25 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
| author | goeckslab |
|---|---|
| date | Fri, 30 Jan 2026 14:20:49 +0000 |
| parents | 871957823d0c |
| children | |
comparison
equal
deleted
inserted
replaced
| 7:ed2fefc8d892 | 8:a48e750cfd25 |
|---|---|
| 314 label_col: str, | 314 label_col: str, |
| 315 problem_type: str, | 315 problem_type: str, |
| 316 eval_metric: Optional[str], | 316 eval_metric: Optional[str], |
| 317 threshold_test: Optional[float], | 317 threshold_test: Optional[float], |
| 318 df_test_external: Optional[pd.DataFrame] = None, | 318 df_test_external: Optional[pd.DataFrame] = None, |
| 319 ) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]: | 319 ) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]], Dict[str, dict]]: |
| 320 """ | 320 """ |
| 321 Returns (raw_metrics, ag_scores_by_split) | 321 Returns (raw_metrics, ag_scores_by_split) |
| 322 - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic) | 322 - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic) |
| 323 - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default) | 323 - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default) |
| 324 """ | 324 """ |
| 333 df_test_effective = df_test_external if df_test_external is not None else df_test | 333 df_test_effective = df_test_external if df_test_external is not None else df_test |
| 334 if df_test_effective is not None and len(df_test_effective): | 334 if df_test_effective is not None and len(df_test_effective): |
| 335 ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req) | 335 ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req) |
| 336 | 336 |
| 337 # Transparent suite (threshold on Test handled inside metrics_logic) | 337 # Transparent suite (threshold on Test handled inside metrics_logic) |
| 338 _, raw_metrics = evaluate_all_transparency( | 338 _, raw_metrics, roc_curves = evaluate_all_transparency( |
| 339 predictor=predictor, | 339 predictor=predictor, |
| 340 train_df=df_train, | 340 train_df=df_train, |
| 341 val_df=df_val, | 341 val_df=df_val, |
| 342 test_df=df_test_effective, | 342 test_df=df_test_effective, |
| 343 target_col=label_col, | 343 target_col=label_col, |
| 344 problem_type=problem_type, | 344 problem_type=problem_type, |
| 345 threshold=threshold_test, | 345 threshold=threshold_test, |
| 346 ) | 346 ) |
| 347 | 347 |
| 348 if df_test_external is not None and df_test_external is not df_test and len(df_test_external): | 348 if df_test_external is not None and df_test_external is not df_test and len(df_test_external): |
| 349 raw_metrics["Test (external)"] = compute_metrics_for_split( | 349 ext_metrics, ext_curve = compute_metrics_for_split( |
| 350 predictor, df_test_external, label_col, problem_type, threshold=threshold_test | 350 predictor, |
| 351 df_test_external, | |
| 352 label_col, | |
| 353 problem_type, | |
| 354 threshold=threshold_test, | |
| 355 return_curve=True, | |
| 351 ) | 356 ) |
| 357 raw_metrics["Test (external)"] = ext_metrics | |
| 358 if ext_curve: | |
| 359 roc_curves["Test (external)"] = ext_curve | |
| 352 ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req) | 360 ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req) |
| 353 | 361 |
| 354 return raw_metrics, ag_by_split | 362 return raw_metrics, ag_by_split, roc_curves |
| 355 | 363 |
| 356 | 364 |
| 357 def fit_summary_safely(predictor) -> Optional[dict]: | 365 def fit_summary_safely(predictor) -> Optional[dict]: |
| 358 """Get fit summary without printing misleading one-liners.""" | 366 """Get fit summary without printing misleading one-liners.""" |
| 359 with suppress_stdout_stderr(): | 367 with suppress_stdout_stderr(): |
