comparison training_pipeline.py @ 8:a48e750cfd25 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c8a7fef0c54c269afd6c6bdf035af1a7574d11cb
author   goeckslab
date     Fri, 30 Jan 2026 14:20:49 +0000
parents  871957823d0c
children (none)
comparison
--- training_pipeline.py  7:ed2fefc8d892
+++ training_pipeline.py  8:a48e750cfd25
@@ -314,11 +314,11 @@
     label_col: str,
     problem_type: str,
     eval_metric: Optional[str],
     threshold_test: Optional[float],
     df_test_external: Optional[pd.DataFrame] = None,
-) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
+) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]], Dict[str, dict]]:
     """
     Returns (raw_metrics, ag_scores_by_split)
     - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic)
     - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default)
     """
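The widened return annotation in the hunk above means the helper now hands back three dictionaries instead of two. A type-level sketch for orientation only; the alias names below are illustrative and are not defined anywhere in the repository:

    from typing import Dict

    # Illustrative aliases (mine, not the module's), mirroring the annotation above.
    RawMetrics = Dict[str, Dict[str, float]]  # split name -> metric name -> value
    AgScores = Dict[str, Dict[str, float]]    # split name -> AutoGluon evaluate() scores per split
    RocCurves = Dict[str, dict]               # split name -> ROC-curve payload (keys not shown in this diff)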
@@ -333,27 +333,35 @@
     df_test_effective = df_test_external if df_test_external is not None else df_test
     if df_test_effective is not None and len(df_test_effective):
         ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)

     # Transparent suite (threshold on Test handled inside metrics_logic)
-    _, raw_metrics = evaluate_all_transparency(
+    _, raw_metrics, roc_curves = evaluate_all_transparency(
         predictor=predictor,
         train_df=df_train,
         val_df=df_val,
         test_df=df_test_effective,
         target_col=label_col,
         problem_type=problem_type,
         threshold=threshold_test,
     )

     if df_test_external is not None and df_test_external is not df_test and len(df_test_external):
-        raw_metrics["Test (external)"] = compute_metrics_for_split(
-            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
-        )
+        ext_metrics, ext_curve = compute_metrics_for_split(
+            predictor,
+            df_test_external,
+            label_col,
+            problem_type,
+            threshold=threshold_test,
+            return_curve=True,
+        )
+        raw_metrics["Test (external)"] = ext_metrics
+        if ext_curve:
+            roc_curves["Test (external)"] = ext_curve
         ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)

-    return raw_metrics, ag_by_split
+    return raw_metrics, ag_by_split, roc_curves


 def fit_summary_safely(predictor) -> Optional[dict]:
     """Get fit summary without printing misleading one-liners."""
     with suppress_stdout_stderr():
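Downstream of this change, the new roc_curves mapping is presumably what lets the pipeline report an ROC curve per split. A minimal consumption sketch, not taken from this repository: it assumes each curve payload carries point-wise rates under "fpr"/"tpr" keys, which the diff itself does not show.

    import matplotlib.pyplot as plt

    # Stand-in payloads with the assumed shape; real ones would come from
    # evaluate_all_transparency / compute_metrics_for_split(..., return_curve=True).
    roc_curves = {
        "Test": {"fpr": [0.0, 0.25, 1.0], "tpr": [0.0, 0.80, 1.0]},
        "Test (external)": {"fpr": [0.0, 0.30, 1.0], "tpr": [0.0, 0.70, 1.0]},
    }

    fig, ax = plt.subplots()
    for split_name, curve in roc_curves.items():
        ax.plot(curve["fpr"], curve["tpr"], label=split_name)
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1)  # chance diagonal
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.legend()
    fig.savefig("roc_curves.png")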