training_pipeline.py @ 0:375c36923da1 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author   goeckslab
date     Tue, 09 Dec 2025 23:49:47 +0000
from __future__ import annotations

import contextlib
import importlib
import io
import json
import logging
import os
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from autogluon.multimodal import MultiModalPredictor
from metrics_logic import compute_metrics_for_split, evaluate_all_transparency
from packaging.version import Version

logger = logging.getLogger(__name__)

# ---------------------- small utilities ----------------------


def load_user_hparams(hp_arg: Optional[str]) -> dict:
    """Parse --hyperparameters (inline JSON or a path to a .json file)."""
    if not hp_arg:
        return {}
    try:
        s = hp_arg.strip()
        if s.startswith("{"):
            return json.loads(s)
        with open(s, "r") as f:
            return json.load(f)
    except Exception as e:
        logger.warning(f"Could not parse --hyperparameters: {e}. Ignoring.")
        return {}


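# Illustrative usage (the file name below is hypothetical): both calls are
# expected to yield {"optimization.max_epochs": 5}.
#
#   load_user_hparams('{"optimization.max_epochs": 5}')   # inline JSON
#   load_user_hparams("my_hparams.json")                  # path to a file with the same JSON

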
def deep_update(dst: dict, src: dict) -> dict:
    """Recursive dict update (src overrides dst)."""
    for k, v in (src or {}).items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            deep_update(dst[k], v)
        else:
            dst[k] = v
    return dst


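# Illustrative behavior: nested dicts are merged key-by-key rather than
# replaced wholesale, so sibling keys in dst survive.
#
#   deep_update({"optim": {"lr": 1e-4, "max_epochs": 5}}, {"optim": {"lr": 3e-4}})
#   -> {"optim": {"lr": 3e-4, "max_epochs": 5}}

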
@contextlib.contextmanager
def suppress_stdout_stderr():
    """Silence noisy prints from AG internals (fit_summary)."""
    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
        yield


def ag_evaluate_safely(predictor, df: Optional[pd.DataFrame], metrics: Optional[List[str]] = None) -> Dict[str, float]:
    """Call predictor.evaluate and normalize the output to a dict."""
    if df is None or len(df) == 0:
        return {}
    try:
        res = predictor.evaluate(df, metrics=metrics)
    except TypeError:
        # Some AutoGluon versions do not accept the ``metrics`` kwarg; retry
        # with a positional metric (or no metric at all).
        if metrics and len(metrics) == 1:
            res = predictor.evaluate(df, metrics[0])
        else:
            res = predictor.evaluate(df)
    if isinstance(res, (int, float, np.floating)):
        name = metrics[0] if metrics else "metric"
        return {name: float(res)}
    if isinstance(res, dict):
        return {k: float(v) for k, v in res.items()}
    return {"metric": float(res)}


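# Illustrative behavior (predictor and df_val are assumed to exist): whatever
# predictor.evaluate returns -- a bare float, a dict, or a single-metric
# positional overload on older versions -- the caller always sees a dict.
#
#   ag_evaluate_safely(predictor, df_val, metrics=["roc_auc"])
#   -> {"roc_auc": 0.91}  # example value

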
# ---------------------- hparams & training ----------------------
def build_mm_hparams(args, df_train: pd.DataFrame, image_columns: Optional[List[str]]) -> dict:
    """
    Build hyperparameters for MultiModalPredictor.
    Handles text checkpoints for torch<2.6 and merges user overrides.
    """
    inferred_text_cols = [
        c for c in df_train.columns
        if c != args.label_column
        and str(df_train[c].dtype) == "object"
        and df_train[c].notna().any()
    ]
    text_cols = inferred_text_cols

    ag_version = None
    try:
        ag_mod = importlib.import_module("autogluon")
        ag_ver = getattr(ag_mod, "__version__", None)
        if ag_ver:
            ag_version = Version(str(ag_ver))
    except Exception:
        ag_mod = None

    def _log_missing_support(key: str) -> None:
        logger.info(
            "AutoGluon version %s does not expose '%s'; skipping override.",
            ag_version or "unknown",
            key,
        )

    hp = {}

    # Set up the environment.
    hp["env"] = {
        "seed": int(args.random_seed)
    }

    # Set the eval metric through the model config.
    model_block = hp.setdefault("model", {})
    if args.eval_metric:
        model_block.setdefault("metric_learning", {})["metric"] = str(args.eval_metric)

    if text_cols and Version(torch.__version__) < Version("2.6"):
        safe_ckpt = "distilbert-base-uncased"
        logger.warning(f"Forcing HF text checkpoint with safetensors: {safe_ckpt}")
        hp["model.hf_text.checkpoint_name"] = safe_ckpt
        hp.setdefault(
            "model.names",
            ["hf_text", "timm_image", "numerical_mlp", "categorical_mlp", "fusion_mlp"],
        )

    def _is_valid_hp_dict(d) -> bool:
        if not isinstance(d, dict):
            logger.warning("User-supplied hyperparameters must be a dict; received %s", type(d).__name__)
            return False
        return True

    user_hp = args.hyperparameters if isinstance(args.hyperparameters, dict) else load_user_hparams(args.hyperparameters)
    if user_hp and _is_valid_hp_dict(user_hp):
        hp = deep_update(hp, user_hp)

    # Map CLI knobs into AutoMM optimization hyperparameters when provided.
    # We set multiple common key names (nested dicts and dotted flat keys) to
    # maximize compatibility across AutoMM/AutoGluon versions.
    try:
        if any(getattr(args, param, None) is not None for param in ["epochs", "learning_rate", "batch_size"]):
            if getattr(args, "epochs", None) is not None:
                hp["optim.max_epochs"] = int(args.epochs)
                hp["optim.epochs"] = int(args.epochs)
            if getattr(args, "learning_rate", None) is not None:
                hp["optim.learning_rate"] = float(args.learning_rate)
                hp["optim.lr"] = float(args.learning_rate)
            if getattr(args, "batch_size", None) is not None:
                hp["optim.batch_size"] = int(args.batch_size)
                hp["optim.per_device_train_batch_size"] = int(args.batch_size)

            # Also set dotted flat keys for max compatibility (e.g., 'optimization.max_epochs')
            if getattr(args, "epochs", None) is not None:
                hp["optimization.max_epochs"] = int(args.epochs)
                hp["optimization.epochs"] = int(args.epochs)
            if getattr(args, "learning_rate", None) is not None:
                hp["optimization.learning_rate"] = float(args.learning_rate)
                hp["optimization.lr"] = float(args.learning_rate)
            if getattr(args, "batch_size", None) is not None:
                hp["optimization.batch_size"] = int(args.batch_size)
                hp["optimization.per_device_train_batch_size"] = int(args.batch_size)
    except Exception:
        logger.warning("Failed to attach epochs/learning_rate/batch_size to mm_hparams; continuing without them.")

    # Map backbone selections into mm_hparams if provided.
    try:
        has_text_cols = bool(text_cols)
        # Derive the image flag from the image_columns argument (it was
        # previously hard-coded to False, which silently disabled the
        # backbone_image override below).
        has_image_cols = bool(image_columns)
        model_names_cache: Optional[List[str]] = None
        model_names_modified = False

        def _dedupe_preserve(seq: List[str]) -> List[str]:
            seen = set()
            ordered = []
            for item in seq:
                if item in seen:
                    continue
                seen.add(item)
                ordered.append(item)
            return ordered

        def _get_model_names() -> List[str]:
            nonlocal model_names_cache
            if model_names_cache is not None:
                return model_names_cache
            names = model_block.get("names")
            if isinstance(names, list):
                model_names_cache = list(names)
            else:
                # Build a default name list from the detected modalities.
                model_names_cache = []
                if has_text_cols:
                    model_names_cache.append("hf_text")
                if has_image_cols:
                    model_names_cache.append("timm_image")
                model_names_cache.extend(["numerical_mlp", "categorical_mlp"])
                model_names_cache.append("fusion_mlp")
            return model_names_cache

        def _set_model_names(new_names: List[str]) -> None:
            nonlocal model_names_cache, model_names_modified
            model_names_cache = new_names
            model_names_modified = True

        if has_text_cols and getattr(args, "backbone_text", None):
            text_choice = str(args.backbone_text)
            model_block.setdefault("hf_text", {})["checkpoint_name"] = text_choice
            hp["model.hf_text.checkpoint_name"] = text_choice
        if has_image_cols and getattr(args, "backbone_image", None):
            image_choice = str(args.backbone_image)
            model_block.setdefault("timm_image", {})["checkpoint_name"] = image_choice
            hp["model.timm_image.checkpoint_name"] = image_choice
        if model_names_modified and model_names_cache is not None:
            model_block["names"] = model_names_cache
    except Exception:
        logger.warning("Failed to attach backbone selections to mm_hparams; continuing without them.")

    if ag_version:
        logger.info(f"Detected AutoGluon version: {ag_version}; applied robust hyperparameter mappings.")

    return hp


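# Illustrative result shape (values are examples, not defaults): the returned
# dict mixes nested blocks with dotted flat aliases on purpose, so that at
# least one spelling is recognized by whichever AutoGluon version is installed.
#
#   build_mm_hparams(args, df_train, image_columns=None)
#   -> {
#          "env": {"seed": 42},
#          "model": {...},
#          "optim.max_epochs": 10,
#          "optimization.max_epochs": 10,
#          ...
#      }

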
def train_predictor(
    args,
    df_train: pd.DataFrame,
    df_val: pd.DataFrame,
    image_columns: Optional[List[str]],
    mm_hparams: dict,
):
    """
    Train a MultiModalPredictor, honoring common knobs (presets, eval_metric, etc.).
    """
    logger.info("Starting AutoGluon MultiModal training...")
    predictor = MultiModalPredictor(label=args.label_column, path=None)
    # Mark image columns explicitly so AutoGluon treats them as image paths
    # (previously left as an empty dict, which ignored image_columns and made
    # the column_types branch below unreachable).
    column_types = {c: "image_path" for c in (image_columns or [])}

    mm_fit_kwargs = dict(
        train_data=df_train,
        time_limit=args.time_limit,
        seed=int(args.random_seed),
        hyperparameters=mm_hparams,
    )
    if df_val is not None and not df_val.empty:
        mm_fit_kwargs["tuning_data"] = df_val
    if column_types:
        mm_fit_kwargs["column_types"] = column_types

    preset_mm = getattr(args, "presets", None)
    if preset_mm is None:
        preset_mm = getattr(args, "preset", None)
    if preset_mm is not None:
        mm_fit_kwargs["presets"] = preset_mm

    predictor.fit(**mm_fit_kwargs)
    return predictor


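# Minimal usage sketch (args, df_train, and df_val are assumed to come from
# the surrounding pipeline):
#
#   mm_hparams = build_mm_hparams(args, df_train, image_columns=None)
#   predictor = train_predictor(args, df_train, df_val, None, mm_hparams)

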
# ---------------------- evaluation ----------------------
def evaluate_predictor_all_splits(
    predictor,
    df_train: Optional[pd.DataFrame],
    df_val: Optional[pd.DataFrame],
    df_test: Optional[pd.DataFrame],
    label_col: str,
    problem_type: str,
    eval_metric: Optional[str],
    threshold_test: Optional[float],
    df_test_external: Optional[pd.DataFrame] = None,
) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
    """
    Returns (raw_metrics, ag_scores_by_split):
    - raw_metrics: our transparent suite (threshold applied to Test/External Test only, inside metrics_logic)
    - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or its default)
    """
    metrics_req = None if (eval_metric is None or str(eval_metric).lower() == "auto") else [eval_metric]
    ag_by_split: Dict[str, Dict[str, float]] = {}

    if df_train is not None and len(df_train):
        ag_by_split["Train"] = ag_evaluate_safely(predictor, df_train, metrics=metrics_req)
    if df_val is not None and len(df_val):
        ag_by_split["Validation"] = ag_evaluate_safely(predictor, df_val, metrics=metrics_req)

    df_test_effective = df_test_external if df_test_external is not None else df_test
    if df_test_effective is not None and len(df_test_effective):
        ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)

    # Transparent suite (threshold on Test handled inside metrics_logic).
    _, raw_metrics = evaluate_all_transparency(
        predictor=predictor,
        train_df=df_train,
        val_df=df_val,
        test_df=df_test_effective,
        target_col=label_col,
        problem_type=problem_type,
        threshold=threshold_test,
    )

    if df_test_external is not None and df_test_external is not df_test and len(df_test_external):
        raw_metrics["Test (external)"] = compute_metrics_for_split(
            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
        )
        ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)

    return raw_metrics, ag_by_split


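# Illustrative call (the split DataFrames, label column, and metric are
# assumptions); raw_metrics holds the transparent suite per split, ag_scores
# the normalized AutoGluon evaluate() output:
#
#   raw_metrics, ag_scores = evaluate_predictor_all_splits(
#       predictor, df_train, df_val, df_test,
#       label_col="target", problem_type="binary",
#       eval_metric="roc_auc", threshold_test=0.5,
#   )

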
def fit_summary_safely(predictor) -> Optional[dict]:
    """Get the fit summary without printing misleading one-liners."""
    with suppress_stdout_stderr():
        try:
            return predictor.fit_summary()
        except Exception:
            return None


# ---------------------- image helpers ----------------------
_PLACEHOLDER_PATH = None


def _create_placeholder() -> str:
    global _PLACEHOLDER_PATH
    if _PLACEHOLDER_PATH and os.path.exists(_PLACEHOLDER_PATH):
        return _PLACEHOLDER_PATH

    dir_ = Path(tempfile.mkdtemp(prefix="ag_placeholder_"))
    file_ = dir_ / f"placeholder_{uuid.uuid4().hex}.png"

    try:
        from PIL import Image
        Image.new("RGB", (64, 64), (180, 180, 180)).save(file_)
    except Exception:
        # Fall back to matplotlib if Pillow is unavailable (np is already
        # imported at module level).
        import matplotlib.pyplot as plt
        plt.imsave(file_, np.full((64, 64, 3), 180, dtype=np.uint8))
        plt.close("all")

    _PLACEHOLDER_PATH = str(file_)
    logger.info(f"Placeholder image created: {file_}")
    return _PLACEHOLDER_PATH


def _is_valid_path(val) -> bool:
    if pd.isna(val):
        return False
    s = str(val).strip()
    # bool() keeps the annotated return type honest when s is empty.
    return bool(s) and os.path.isfile(s)


def handle_missing_images(
    df: pd.DataFrame,
    image_columns: List[str],
    strategy: str = "false",
) -> pd.DataFrame:
    """Drop rows with invalid image paths (strategy="true") or fill them with a placeholder image."""
    if not image_columns or df.empty:
        return df

    remove = str(strategy).lower() == "true"
    masks = [~df[col].apply(_is_valid_path) for col in image_columns if col in df.columns]
    if not masks:
        return df

    any_missing = pd.concat(masks, axis=1).any(axis=1)
    n_missing = int(any_missing.sum())

    if n_missing == 0:
        return df

    if remove:
        result = df[~any_missing].reset_index(drop=True)
        logger.info(f"Dropped {n_missing} rows with missing images → {len(result)} remain")
    else:
        placeholder = _create_placeholder()
        result = df.copy()
        for col in image_columns:
            if col in result.columns:
                result.loc[~result[col].apply(_is_valid_path), col] = placeholder
        logger.info(f"Filled {n_missing} missing images with placeholder")

    return result


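# Illustrative behavior for a DataFrame with one unreadable image path (the
# "img" column is hypothetical):
#
#   handle_missing_images(df, ["img"], strategy="true")   # drops the bad row
#   handle_missing_images(df, ["img"], strategy="false")  # rewrites the path
#                                                         # to a placeholder

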
# ---------------------- AutoGluon config helpers ----------------------
def autogluon_hyperparameters(
    threshold,
    time_limit,
    random_seed,
    epochs,
    learning_rate,
    batch_size,
    backbone_image,
    backbone_text,
    preset,
    eval_metric,
    hyperparameters,
):
    """
    Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs.
    The returned dict separates what should be passed to predictor.fit (under ``fit``)
    from the model/optimization configuration (under ``hyperparameters``). Threshold is
    preserved for downstream evaluation but not passed into AutoGluon directly.
    """

    def _prune_empty(d: dict) -> dict:
        cleaned = {}
        for k, v in (d or {}).items():
            if isinstance(v, dict):
                nested = _prune_empty(v)
                if nested:
                    cleaned[k] = nested
            elif v is not None:
                cleaned[k] = v
        return cleaned

    # Base hyperparameters following the structure described in the AutoGluon
    # customization guide (env / optimization / model).
    env_cfg = {}
    if random_seed is not None:
        env_cfg["seed"] = int(random_seed)
    if batch_size is not None:
        env_cfg["per_gpu_batch_size"] = int(batch_size)

    optim_cfg = {}
    if epochs is not None:
        optim_cfg["max_epochs"] = int(epochs)
    if learning_rate is not None:
        optim_cfg["learning_rate"] = float(learning_rate)
    if batch_size is not None:
        bs = int(batch_size)
        optim_cfg["per_device_train_batch_size"] = bs
        optim_cfg["train_batch_size"] = bs

    model_cfg = {}
    if eval_metric:
        model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric)
    if backbone_image:
        model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image)
    if backbone_text:
        model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text)

    hp = {
        "env": env_cfg,
        "optimization": optim_cfg,
        "model": model_cfg,
    }

    # Also expose the most common dotted aliases for robustness across AG versions.
    if epochs is not None:
        hp["optimization.max_epochs"] = int(epochs)
        hp["optim.max_epochs"] = int(epochs)
    if learning_rate is not None:
        lr_val = float(learning_rate)
        hp["optimization.learning_rate"] = lr_val
        hp["optimization.lr"] = lr_val
        hp["optim.learning_rate"] = lr_val
        hp["optim.lr"] = lr_val
    if batch_size is not None:
        bs_val = int(batch_size)
        hp["optimization.per_device_train_batch_size"] = bs_val
        hp["optimization.batch_size"] = bs_val
        hp["optim.per_device_train_batch_size"] = bs_val
        hp["optim.batch_size"] = bs_val
        hp["env.per_gpu_batch_size"] = bs_val
    if backbone_image:
        hp["model.timm_image.checkpoint_name"] = str(backbone_image)
    if backbone_text:
        hp["model.hf_text.checkpoint_name"] = str(backbone_text)

    # Merge user-provided hyperparameters (inline JSON or path) last so they win.
    if isinstance(hyperparameters, dict):
        user_hp = hyperparameters
    else:
        user_hp = load_user_hparams(hyperparameters)
    hp = deep_update(hp, user_hp)
    hp = _prune_empty(hp)

    fit_cfg = {}
    if time_limit is not None:
        fit_cfg["time_limit"] = time_limit
    if random_seed is not None:
        fit_cfg["seed"] = int(random_seed)
    if preset:
        fit_cfg["presets"] = preset

    config = {
        "fit": fit_cfg,
        "hyperparameters": hp,
    }
    if threshold is not None:
        config["threshold"] = float(threshold)

    return config


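# Illustrative result (values are examples): fit kwargs and model/optimization
# hyperparameters come back separated, with the threshold carried alongside
# for downstream evaluation only.
#
#   autogluon_hyperparameters(
#       threshold=0.5, time_limit=600, random_seed=42, epochs=10,
#       learning_rate=1e-4, batch_size=32, backbone_image=None,
#       backbone_text=None, preset="medium_quality", eval_metric="roc_auc",
#       hyperparameters=None,
#   )
#   -> {"fit": {"time_limit": 600, "seed": 42, "presets": "medium_quality"},
#       "hyperparameters": {...},
#       "threshold": 0.5}

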
def run_autogluon_experiment(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    image_columns: Optional[List[str]],
    ag_config: dict,
):
    """
    Launch an AutoGluon MultiModal training run using the config from
    autogluon_hyperparameters(). Returns (predictor, context dict) so callers
    can evaluate downstream with the chosen threshold.
    """
    if ag_config is None:
        raise ValueError("ag_config is required to launch AutoGluon training.")

    hyperparameters = ag_config.get("hyperparameters") or {}
    fit_cfg = dict(ag_config.get("fit") or {})
    threshold = ag_config.get("threshold")

    if "split" not in train_dataset.columns:
        raise ValueError("train_dataset must contain a 'split' column. Did you call split_dataset?")

    df_train = train_dataset[train_dataset["split"] == "train"].copy()
    df_val = train_dataset[train_dataset["split"].isin(["val", "validation"])].copy()
    df_test_internal = train_dataset[train_dataset["split"] == "test"].copy()

    predictor = MultiModalPredictor(label=target_column, path=None)
    column_types = {c: "image_path" for c in (image_columns or [])}

    fit_kwargs = {
        "train_data": df_train,
        "hyperparameters": hyperparameters,
    }
    fit_kwargs.update(fit_cfg)
    if not df_val.empty:
        fit_kwargs.setdefault("tuning_data", df_val)
    if column_types:
        fit_kwargs.setdefault("column_types", column_types)

    logger.info(
        "Fitting AutoGluon with %d train / %d val rows (internal test rows: %d, external test provided: %s)",
        len(df_train),
        len(df_val),
        len(df_test_internal),
        (test_dataset is not None and not test_dataset.empty),
    )
    predictor.fit(**fit_kwargs)

    return predictor, {
        "train": df_train,
        "val": df_val,
        "test_internal": df_test_internal,
        "test_external": test_dataset,
        "threshold": threshold,
    }
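

# End-to-end usage sketch (train_df/test_df, the column names, and the CLI
# values are assumptions; train_df must already carry a 'split' column with
# "train"/"val"/"test" labels):
#
#   ag_config = autogluon_hyperparameters(
#       threshold=0.5, time_limit=600, random_seed=42, epochs=10,
#       learning_rate=1e-4, batch_size=32, backbone_image=None,
#       backbone_text=None, preset=None, eval_metric="roc_auc",
#       hyperparameters=None,
#   )
#   predictor, ctx = run_autogluon_experiment(
#       train_df, test_df, target_column="target",
#       image_columns=["img"], ag_config=ag_config,
#   )
#   raw_metrics, ag_scores = evaluate_predictor_all_splits(
#       predictor, ctx["train"], ctx["val"], ctx["test_internal"],
#       label_col="target", problem_type="binary",
#       eval_metric="roc_auc", threshold_test=ctx["threshold"],
#       df_test_external=ctx["test_external"],
#   )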