diff training_pipeline.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/training_pipeline.py	Tue Dec 09 23:49:47 2025 +0000
@@ -0,0 +1,551 @@
+from __future__ import annotations
+
+import contextlib
+import importlib
+import io
+import json
+import logging
+import os
+import tempfile
+import uuid
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from autogluon.multimodal import MultiModalPredictor
+from metrics_logic import compute_metrics_for_split, evaluate_all_transparency
+from packaging.version import Version
+
+logger = logging.getLogger(__name__)
+
+# ---------------------- small utilities ----------------------
+
+
+def load_user_hparams(hp_arg: Optional[str]) -> dict:
+    """Parse --hyperparameters (inline JSON or path to .json)."""
+    if not hp_arg:
+        return {}
+    try:
+        s = hp_arg.strip()
+        if s.startswith("{"):
+            return json.loads(s)
+        with open(s, "r") as f:
+            return json.load(f)
+    except Exception as e:
+        logger.warning(f"Could not parse --hyperparameters: {e}. Ignoring.")
+        return {}
+
+
+def deep_update(dst: dict, src: dict) -> dict:
+    """Recursive dict update (src overrides dst)."""
+    for k, v in (src or {}).items():
+        if isinstance(v, dict) and isinstance(dst.get(k), dict):
+            deep_update(dst[k], v)
+        else:
+            dst[k] = v
+    return dst
+
+
+@contextlib.contextmanager
+def suppress_stdout_stderr():
+    """Silence noisy prints from AG internals (fit_summary)."""
+    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
+        yield
+
+
+def ag_evaluate_safely(predictor, df: Optional[pd.DataFrame], metrics: Optional[List[str]] = None) -> Dict[str, float]:
+    """
+    Call predictor.evaluate and normalize the output to a dict.
+    """
+    if df is None or len(df) == 0:
+        return {}
+    try:
+        res = predictor.evaluate(df, metrics=metrics)
+    except TypeError:
+        if metrics and len(metrics) == 1:
+            res = predictor.evaluate(df, metrics[0])
+        else:
+            res = predictor.evaluate(df)
+    if isinstance(res, (int, float, np.floating)):
+        name = (metrics[0] if metrics else "metric")
+        return {name: float(res)}
+    if isinstance(res, dict):
+        return {k: float(v) for k, v in res.items()}
+    return {"metric": float(res)}
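Worth pinning down, since every later override path funnels through it: `deep_update` merges nested dicts key-by-key and lets the source win on conflicts. A minimal sketch (editorial illustration, not part of the commit):

```python
# deep_update mutates and returns its first argument; nested dicts merge,
# while scalars and lists are overwritten wholesale by the override.
base = {"optimization": {"max_epochs": 10, "learning_rate": 1e-4}}
override = {"optimization": {"max_epochs": 5}}
assert deep_update(base, override) == {
    "optimization": {"max_epochs": 5, "learning_rate": 1e-4}
}
```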
+ """ + inferred_text_cols = [ + c for c in df_train.columns + if c != args.label_column + and str(df_train[c].dtype) == "object" + and df_train[c].notna().any() + ] + text_cols = inferred_text_cols + + ag_version = None + try: + ag_mod = importlib.import_module("autogluon") + ag_ver = getattr(ag_mod, "__version__", None) + if ag_ver: + ag_version = Version(str(ag_ver)) + except Exception: + ag_mod = None + + def _log_missing_support(key: str) -> None: + logger.info( + "AutoGluon version %s does not expose '%s'; skipping override.", + ag_version or "unknown", + key, + ) + + hp = {} + + # Setup environment + hp["env"] = { + "seed": int(args.random_seed) + } + + # Set eval metric through model config + model_block = hp.setdefault("model", {}) + if args.eval_metric: + model_block.setdefault("metric_learning", {})["metric"] = str(args.eval_metric) + + if text_cols and Version(torch.__version__) < Version("2.6"): + safe_ckpt = "distilbert-base-uncased" + logger.warning(f"Forcing HF text checkpoint with safetensors: {safe_ckpt}") + hp["model.hf_text.checkpoint_name"] = safe_ckpt + hp.setdefault( + "model.names", + ["hf_text", "timm_image", "numerical_mlp", "categorical_mlp", "fusion_mlp"], + ) + + def _is_valid_hp_dict(d) -> bool: + if not isinstance(d, dict): + logger.warning("User-supplied hyperparameters must be a dict; received %s", type(d).__name__) + return False + return True + + user_hp = args.hyperparameters if isinstance(args.hyperparameters, dict) else load_user_hparams(args.hyperparameters) + if user_hp and _is_valid_hp_dict(user_hp): + hp = deep_update(hp, user_hp) + + # Map CLI knobs into AutoMM optimization hyperparameters when provided. + # We set multiple common key names (nested dicts and dotted flat keys) to + # maximize compatibility across AutoMM/AutoGluon versions. 
+
+
+def train_predictor(
+    args,
+    df_train: pd.DataFrame,
+    df_val: pd.DataFrame,
+    image_columns: Optional[List[str]],
+    mm_hparams: dict,
+):
+    """
+    Train a MultiModalPredictor, honoring common knobs (presets, eval_metric, etc.).
+    """
+    logger.info("Starting AutoGluon MultiModal training...")
+    predictor = MultiModalPredictor(label=args.label_column, path=None)
+    column_types = {c: "image_path" for c in (image_columns or [])}
+
+    mm_fit_kwargs = dict(
+        train_data=df_train,
+        time_limit=args.time_limit,
+        seed=int(args.random_seed),
+        hyperparameters=mm_hparams,
+    )
+    if df_val is not None and not df_val.empty:
+        mm_fit_kwargs["tuning_data"] = df_val
+    if column_types:
+        mm_fit_kwargs["column_types"] = column_types
+
+    preset_mm = getattr(args, "presets", None)
+    if preset_mm is None:
+        preset_mm = getattr(args, "preset", None)
+    if preset_mm is not None:
+        mm_fit_kwargs["presets"] = preset_mm
+
+    predictor.fit(**mm_fit_kwargs)
+    return predictor
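A hypothetical end-to-end call of the two functions above; `args` must additionally provide `time_limit` (seconds or `None`), and may provide `presets`/`preset` (e.g. `"medium_quality"`), which `train_predictor` reads via `getattr`:

```python
# df_train / df_val are pandas DataFrames containing the label column.
mm_hparams = build_mm_hparams(args, df_train, image_columns=None)
predictor = train_predictor(
    args, df_train, df_val,
    image_columns=None, mm_hparams=mm_hparams,
)
```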
+ """ + logger.info("Starting AutoGluon MultiModal training...") + predictor = MultiModalPredictor(label=args.label_column, path=None) + column_types = {} + + mm_fit_kwargs = dict( + train_data=df_train, + time_limit=args.time_limit, + seed=int(args.random_seed), + hyperparameters=mm_hparams, + ) + if df_val is not None and not df_val.empty: + mm_fit_kwargs["tuning_data"] = df_val + if column_types: + mm_fit_kwargs["column_types"] = column_types + + preset_mm = getattr(args, "presets", None) + if preset_mm is None: + preset_mm = getattr(args, "preset", None) + if preset_mm is not None: + mm_fit_kwargs["presets"] = preset_mm + + predictor.fit(**mm_fit_kwargs) + return predictor + + +# ---------------------- evaluation ---------------------- +def evaluate_predictor_all_splits( + predictor, + df_train: Optional[pd.DataFrame], + df_val: Optional[pd.DataFrame], + df_test: Optional[pd.DataFrame], + label_col: str, + problem_type: str, + eval_metric: Optional[str], + threshold_test: Optional[float], + df_test_external: Optional[pd.DataFrame] = None, +) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]: + """ + Returns (raw_metrics, ag_scores_by_split) + - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic) + - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default) + """ + metrics_req = None if (eval_metric is None or str(eval_metric).lower() == "auto") else [eval_metric] + ag_by_split: Dict[str, Dict[str, float]] = {} + + if df_train is not None and len(df_train): + ag_by_split["Train"] = ag_evaluate_safely(predictor, df_train, metrics=metrics_req) + if df_val is not None and len(df_val): + ag_by_split["Validation"] = ag_evaluate_safely(predictor, df_val, metrics=metrics_req) + + df_test_effective = df_test_external if df_test_external is not None else df_test + if df_test_effective is not None and len(df_test_effective): + ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req) + + # Transparent suite (threshold on Test handled inside metrics_logic) + _, raw_metrics = evaluate_all_transparency( + predictor=predictor, + train_df=df_train, + val_df=df_val, + test_df=df_test_effective, + target_col=label_col, + problem_type=problem_type, + threshold=threshold_test, + ) + + if df_test_external is not None and df_test_external is not df_test and len(df_test_external): + raw_metrics["Test (external)"] = compute_metrics_for_split( + predictor, df_test_external, label_col, problem_type, threshold=threshold_test + ) + ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req) + + return raw_metrics, ag_by_split + + +def fit_summary_safely(predictor) -> Optional[dict]: + """Get fit summary without printing misleading one-liners.""" + with suppress_stdout_stderr(): + try: + return predictor.fit_summary() + except Exception: + return None + + +# ---------------------- image helpers ---------------------- +_PLACEHOLDER_PATH = None + + +def _create_placeholder() -> str: + global _PLACEHOLDER_PATH + if _PLACEHOLDER_PATH and os.path.exists(_PLACEHOLDER_PATH): + return _PLACEHOLDER_PATH + + dir_ = Path(tempfile.mkdtemp(prefix="ag_placeholder_")) + file_ = dir_ / f"placeholder_{uuid.uuid4().hex}.png" + + try: + from PIL import Image + Image.new("RGB", (64, 64), (180, 180, 180)).save(file_) + except Exception: + import matplotlib.pyplot as plt + import numpy as np + plt.imsave(file_, np.full((64, 64, 3), 180, dtype=np.uint8)) + 
plt.close("all") + + _PLACEHOLDER_PATH = str(file_) + logger.info(f"Placeholder image created: {file_}") + return _PLACEHOLDER_PATH + + +def _is_valid_path(val) -> bool: + if pd.isna(val): + return False + s = str(val).strip() + return s and os.path.isfile(s) + + +def handle_missing_images( + df: pd.DataFrame, + image_columns: List[str], + strategy: str = "false", +) -> pd.DataFrame: + if not image_columns or df.empty: + return df + + remove = str(strategy).lower() == "true" + masks = [~df[col].apply(_is_valid_path) for col in image_columns if col in df.columns] + if not masks: + return df + + any_missing = pd.concat(masks, axis=1).any(axis=1) + n_missing = int(any_missing.sum()) + + if n_missing == 0: + return df + + if remove: + result = df[~any_missing].reset_index(drop=True) + logger.info(f"Dropped {n_missing} rows with missing images → {len(result)} remain") + else: + placeholder = _create_placeholder() + result = df.copy() + for col in image_columns: + if col in result.columns: + result.loc[~result[col].apply(_is_valid_path), col] = placeholder + logger.info(f"Filled {n_missing} missing images with placeholder") + + return result + + +# ---------------------- AutoGluon config helpers ---------------------- +def autogluon_hyperparameters( + threshold, + time_limit, + random_seed, + epochs, + learning_rate, + batch_size, + backbone_image, + backbone_text, + preset, + eval_metric, + hyperparameters, +): + """ + Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs. + The returned dict separates what should be passed to predictor.fit (under ``fit``) + from the model/optimization configuration (under ``hyperparameters``). Threshold is + preserved for downstream evaluation but not passed into AutoGluon directly. + """ + + def _prune_empty(d: dict) -> dict: + cleaned = {} + for k, v in (d or {}).items(): + if isinstance(v, dict): + nested = _prune_empty(v) + if nested: + cleaned[k] = nested + elif v is not None: + cleaned[k] = v + return cleaned + + # Base hyperparameters following the structure described in the AutoGluon + # customization guide (env / optimization / model). + env_cfg = {} + if random_seed is not None: + env_cfg["seed"] = int(random_seed) + if batch_size is not None: + env_cfg["per_gpu_batch_size"] = int(batch_size) + + optim_cfg = {} + if epochs is not None: + optim_cfg["max_epochs"] = int(epochs) + if learning_rate is not None: + optim_cfg["learning_rate"] = float(learning_rate) + if batch_size is not None: + bs = int(batch_size) + optim_cfg["per_device_train_batch_size"] = bs + optim_cfg["train_batch_size"] = bs + + model_cfg = {} + if eval_metric: + model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric) + if backbone_image: + model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image) + if backbone_text: + model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text) + + hp = { + "env": env_cfg, + "optimization": optim_cfg, + "model": model_cfg, + } + + # Also expose the most common dotted aliases for robustness across AG versions. 
+
+
+# ---------------------- AutoGluon config helpers ----------------------
+def autogluon_hyperparameters(
+    threshold,
+    time_limit,
+    random_seed,
+    epochs,
+    learning_rate,
+    batch_size,
+    backbone_image,
+    backbone_text,
+    preset,
+    eval_metric,
+    hyperparameters,
+):
+    """
+    Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs.
+    The returned dict separates what should be passed to predictor.fit (under ``fit``)
+    from the model/optimization configuration (under ``hyperparameters``). Threshold is
+    preserved for downstream evaluation but not passed into AutoGluon directly.
+    """
+
+    def _prune_empty(d: dict) -> dict:
+        cleaned = {}
+        for k, v in (d or {}).items():
+            if isinstance(v, dict):
+                nested = _prune_empty(v)
+                if nested:
+                    cleaned[k] = nested
+            elif v is not None:
+                cleaned[k] = v
+        return cleaned
+
+    # Base hyperparameters following the structure described in the AutoGluon
+    # customization guide (env / optimization / model).
+    env_cfg = {}
+    if random_seed is not None:
+        env_cfg["seed"] = int(random_seed)
+    if batch_size is not None:
+        env_cfg["per_gpu_batch_size"] = int(batch_size)
+
+    optim_cfg = {}
+    if epochs is not None:
+        optim_cfg["max_epochs"] = int(epochs)
+    if learning_rate is not None:
+        optim_cfg["learning_rate"] = float(learning_rate)
+    if batch_size is not None:
+        bs = int(batch_size)
+        optim_cfg["per_device_train_batch_size"] = bs
+        optim_cfg["train_batch_size"] = bs
+
+    model_cfg = {}
+    if eval_metric:
+        model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric)
+    if backbone_image:
+        model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image)
+    if backbone_text:
+        model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text)
+
+    hp = {
+        "env": env_cfg,
+        "optimization": optim_cfg,
+        "model": model_cfg,
+    }
+
+    # Also expose the most common dotted aliases for robustness across AG versions.
+    if epochs is not None:
+        hp["optimization.max_epochs"] = int(epochs)
+        hp["optim.max_epochs"] = int(epochs)
+    if learning_rate is not None:
+        lr_val = float(learning_rate)
+        hp["optimization.learning_rate"] = lr_val
+        hp["optimization.lr"] = lr_val
+        hp["optim.learning_rate"] = lr_val
+        hp["optim.lr"] = lr_val
+    if batch_size is not None:
+        bs_val = int(batch_size)
+        hp["optimization.per_device_train_batch_size"] = bs_val
+        hp["optimization.batch_size"] = bs_val
+        hp["optim.per_device_train_batch_size"] = bs_val
+        hp["optim.batch_size"] = bs_val
+        hp["env.per_gpu_batch_size"] = bs_val
+    if backbone_image:
+        hp["model.timm_image.checkpoint_name"] = str(backbone_image)
+    if backbone_text:
+        hp["model.hf_text.checkpoint_name"] = str(backbone_text)
+
+    # Merge user-provided hyperparameters (inline JSON or path) last so they win.
+    if isinstance(hyperparameters, dict):
+        user_hp = hyperparameters
+    else:
+        user_hp = load_user_hparams(hyperparameters)
+    hp = deep_update(hp, user_hp)
+    hp = _prune_empty(hp)
+
+    fit_cfg = {}
+    if time_limit is not None:
+        fit_cfg["time_limit"] = time_limit
+    if random_seed is not None:
+        fit_cfg["seed"] = int(random_seed)
+    if preset:
+        fit_cfg["presets"] = preset
+
+    config = {
+        "fit": fit_cfg,
+        "hyperparameters": hp,
+    }
+    if threshold is not None:
+        config["threshold"] = float(threshold)
+
+    return config
+
+
+def run_autogluon_experiment(
+    train_dataset: pd.DataFrame,
+    test_dataset: Optional[pd.DataFrame],
+    target_column: str,
+    image_columns: Optional[List[str]],
+    ag_config: dict,
+):
+    """
+    Launch an AutoGluon MultiModal training run using the config from
+    autogluon_hyperparameters(). Returns (predictor, context dict) so callers
+    can evaluate downstream with the chosen threshold.
+    """
+    if ag_config is None:
+        raise ValueError("ag_config is required to launch AutoGluon training.")
+
+    hyperparameters = ag_config.get("hyperparameters") or {}
+    fit_cfg = dict(ag_config.get("fit") or {})
+    threshold = ag_config.get("threshold")
+
+    if "split" not in train_dataset.columns:
+        raise ValueError("train_dataset must contain a 'split' column. Did you call split_dataset?")
+
+    df_train = train_dataset[train_dataset["split"] == "train"].copy()
+    df_val = train_dataset[train_dataset["split"].isin(["val", "validation"])].copy()
+    df_test_internal = train_dataset[train_dataset["split"] == "test"].copy()
+
+    predictor = MultiModalPredictor(label=target_column, path=None)
+    column_types = {c: "image_path" for c in (image_columns or [])}
+
+    fit_kwargs = {
+        "train_data": df_train,
+        "hyperparameters": hyperparameters,
+    }
+    fit_kwargs.update(fit_cfg)
+    if not df_val.empty:
+        fit_kwargs.setdefault("tuning_data", df_val)
+    if column_types:
+        fit_kwargs.setdefault("column_types", column_types)
+
+    logger.info(
+        "Fitting AutoGluon with %d train / %d val rows (internal test rows: %d, external test provided: %s)",
+        len(df_train),
+        len(df_val),
+        len(df_test_internal),
+        (test_dataset is not None and not test_dataset.empty),
+    )
+    predictor.fit(**fit_kwargs)
+
+    return predictor, {
+        "train": df_train,
+        "val": df_val,
+        "test_internal": df_test_internal,
+        "test_external": test_dataset,
+        "threshold": threshold,
+    }
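Taken together, the two config helpers give callers a two-step API: build the config once, then launch. An end-to-end sketch, assuming `train_dataset` already carries a `split` column with train/val/test labels (all values illustrative):

```python
cfg = autogluon_hyperparameters(
    threshold=0.5, time_limit=3600, random_seed=42,
    epochs=5, learning_rate=1e-4, batch_size=8,
    backbone_image=None, backbone_text=None,
    preset="medium_quality", eval_metric="roc_auc",
    hyperparameters=None,
)
predictor, ctx = run_autogluon_experiment(
    train_dataset, test_dataset=None, target_column="label",
    image_columns=["image_path"], ag_config=cfg,
)
# ctx["threshold"] is carried through for downstream evaluation; it is
# never passed to predictor.fit().
```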
