diff training_pipeline.py @ 0:375c36923da1 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/training_pipeline.py	Tue Dec 09 23:49:47 2025 +0000
@@ -0,0 +1,551 @@
+from __future__ import annotations
+
+import contextlib
+import importlib
+import io
+import json
+import logging
+import os
+import tempfile
+import uuid
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from autogluon.multimodal import MultiModalPredictor
+from metrics_logic import compute_metrics_for_split, evaluate_all_transparency
+from packaging.version import Version
+
+logger = logging.getLogger(__name__)
+
+# ---------------------- small utilities ----------------------
+
+
+def load_user_hparams(hp_arg: Optional[str]) -> dict:
+    """Parse --hyperparameters (inline JSON or path to .json)."""
+    if not hp_arg:
+        return {}
+    try:
+        s = hp_arg.strip()
+        if s.startswith("{"):
+            return json.loads(s)
+        with open(s, "r") as f:
+            return json.load(f)
+    except Exception as e:
+        logger.warning(f"Could not parse --hyperparameters: {e}. Ignoring.")
+        return {}
+
+
+def deep_update(dst: dict, src: dict) -> dict:
+    """Recursive dict update (src overrides dst)."""
+    for k, v in (src or {}).items():
+        if isinstance(v, dict) and isinstance(dst.get(k), dict):
+            deep_update(dst[k], v)
+        else:
+            dst[k] = v
+    return dst
+
+
+@contextlib.contextmanager
+def suppress_stdout_stderr():
+    """Silence noisy prints from AG internals (fit_summary)."""
+    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
+        yield
+
+
+def ag_evaluate_safely(predictor, df: Optional[pd.DataFrame], metrics: Optional[List[str]] = None) -> Dict[str, float]:
+    """
+    Call predictor.evaluate and normalize the output to a dict.
+    """
+    if df is None or len(df) == 0:
+        return {}
+    try:
+        res = predictor.evaluate(df, metrics=metrics)
+    except TypeError:
+        if metrics and len(metrics) == 1:
+            res = predictor.evaluate(df, metrics[0])
+        else:
+            res = predictor.evaluate(df)
+    if isinstance(res, (int, float, np.floating)):
+        name = (metrics[0] if metrics else "metric")
+        return {name: float(res)}
+    if isinstance(res, dict):
+        return {k: float(v) for k, v in res.items()}
+    return {"metric": float(res)}
+
+
+# ---------------------- hparams & training ----------------------
+def build_mm_hparams(args, df_train: pd.DataFrame, image_columns: Optional[List[str]]) -> dict:
+    """
+    Build hyperparameters for MultiModalPredictor.
+    Handles text checkpoints for torch<2.6 and merges user overrides.
+    """
+    inferred_text_cols = [
+        c for c in df_train.columns
+        if c != args.label_column
+        and str(df_train[c].dtype) == "object"
+        and df_train[c].notna().any()
+    ]
+    text_cols = inferred_text_cols
+
+    ag_version = None
+    try:
+        ag_mod = importlib.import_module("autogluon")
+        ag_ver = getattr(ag_mod, "__version__", None)
+        if ag_ver:
+            ag_version = Version(str(ag_ver))
+    except Exception:
+        ag_mod = None
+
+    def _log_missing_support(key: str) -> None:
+        logger.info(
+            "AutoGluon version %s does not expose '%s'; skipping override.",
+            ag_version or "unknown",
+            key,
+        )
+
+    hp = {}
+
+    # Setup environment
+    hp["env"] = {
+        "seed": int(args.random_seed)
+    }
+
+    # Set eval metric through model config
+    model_block = hp.setdefault("model", {})
+    if args.eval_metric:
+        model_block.setdefault("metric_learning", {})["metric"] = str(args.eval_metric)
+
+    if text_cols and Version(torch.__version__) < Version("2.6"):
+        safe_ckpt = "distilbert-base-uncased"
+        logger.warning(f"Forcing HF text checkpoint with safetensors: {safe_ckpt}")
+        hp["model.hf_text.checkpoint_name"] = safe_ckpt
+        hp.setdefault(
+            "model.names",
+            ["hf_text", "timm_image", "numerical_mlp", "categorical_mlp", "fusion_mlp"],
+        )
+
+    def _is_valid_hp_dict(d) -> bool:
+        if not isinstance(d, dict):
+            logger.warning("User-supplied hyperparameters must be a dict; received %s", type(d).__name__)
+            return False
+        return True
+
+    user_hp = args.hyperparameters if isinstance(args.hyperparameters, dict) else load_user_hparams(args.hyperparameters)
+    if user_hp and _is_valid_hp_dict(user_hp):
+        hp = deep_update(hp, user_hp)
+
+    # Map CLI knobs into AutoMM optimization hyperparameters when provided.
+    # We set multiple common key names (nested dicts and dotted flat keys) to
+    # maximize compatibility across AutoMM/AutoGluon versions.
+    try:
+        if any(getattr(args, param, None) is not None for param in ["epochs", "learning_rate", "batch_size"]):
+            if getattr(args, "epochs", None) is not None:
+                hp["optim.max_epochs"] = int(args.epochs)
+                hp["optim.epochs"] = int(args.epochs)
+            if getattr(args, "learning_rate", None) is not None:
+                hp["optim.learning_rate"] = float(args.learning_rate)
+                hp["optim.lr"] = float(args.learning_rate)
+            if getattr(args, "batch_size", None) is not None:
+                hp["optim.batch_size"] = int(args.batch_size)
+                hp["optim.per_device_train_batch_size"] = int(args.batch_size)
+
+        # Also set dotted flat keys for max compatibility (e.g., 'optimization.max_epochs')
+        if getattr(args, "epochs", None) is not None:
+            hp["optimization.max_epochs"] = int(args.epochs)
+            hp["optimization.epochs"] = int(args.epochs)
+        if getattr(args, "learning_rate", None) is not None:
+            hp["optimization.learning_rate"] = float(args.learning_rate)
+            hp["optimization.lr"] = float(args.learning_rate)
+        if getattr(args, "batch_size", None) is not None:
+            hp["optimization.batch_size"] = int(args.batch_size)
+            hp["optimization.per_device_train_batch_size"] = int(args.batch_size)
+    except Exception:
+        logger.warning("Failed to attach epochs/learning_rate/batch_size to mm_hparams; continuing without them.")
+
+    # Map backbone selections into mm_hparams if provided
+    try:
+        has_text_cols = bool(text_cols)
+        has_image_cols = False
+        model_names_cache: Optional[List[str]] = None
+        model_names_modified = False
+
+        def _dedupe_preserve(seq: List[str]) -> List[str]:
+            seen = set()
+            ordered = []
+            for item in seq:
+                if item in seen:
+                    continue
+                seen.add(item)
+                ordered.append(item)
+            return ordered
+
+        def _get_model_names() -> List[str]:
+            nonlocal model_names_cache
+            if model_names_cache is not None:
+                return model_names_cache
+            names = model_block.get("names")
+            if isinstance(names, list):
+                model_names_cache = list(names)
+            else:
+                model_names_cache = []
+                if has_text_cols:
+                    model_names_cache.append("hf_text")
+                if has_image_cols:
+                    model_names_cache.append("timm_image")
+                model_names_cache.extend(["numerical_mlp", "categorical_mlp"])
+                model_names_cache.append("fusion_mlp")
+            return model_names_cache
+
+        def _set_model_names(new_names: List[str]) -> None:
+            nonlocal model_names_cache, model_names_modified
+            model_names_cache = new_names
+            model_names_modified = True
+
+        if has_text_cols and getattr(args, "backbone_text", None):
+            text_choice = str(args.backbone_text)
+            model_block.setdefault("hf_text", {})["checkpoint_name"] = text_choice
+            hp["model.hf_text.checkpoint_name"] = text_choice
+        if has_image_cols and getattr(args, "backbone_image", None):
+            image_choice = str(args.backbone_image)
+            model_block.setdefault("timm_image", {})["checkpoint_name"] = image_choice
+            hp["model.timm_image.checkpoint_name"] = image_choice
+        if model_names_modified and model_names_cache is not None:
+            model_block["names"] = model_names_cache
+    except Exception:
+        logger.warning("Failed to attach backbone selections to mm_hparams; continuing without them.")
+
+    if ag_version:
+        logger.info(f"Detected AutoGluon version: {ag_version}; applied robust hyperparameter mappings.")
+
+    return hp
+
+
+def train_predictor(
+    args,
+    df_train: pd.DataFrame,
+    df_val: pd.DataFrame,
+    image_columns: Optional[List[str]],
+    mm_hparams: dict,
+):
+    """
+    Train a MultiModalPredictor, honoring common knobs (presets, eval_metric, etc.).
+    """
+    logger.info("Starting AutoGluon MultiModal training...")
+    predictor = MultiModalPredictor(label=args.label_column, path=None)
+    column_types = {}
+
+    mm_fit_kwargs = dict(
+        train_data=df_train,
+        time_limit=args.time_limit,
+        seed=int(args.random_seed),
+        hyperparameters=mm_hparams,
+    )
+    if df_val is not None and not df_val.empty:
+        mm_fit_kwargs["tuning_data"] = df_val
+    if column_types:
+        mm_fit_kwargs["column_types"] = column_types
+
+    preset_mm = getattr(args, "presets", None)
+    if preset_mm is None:
+        preset_mm = getattr(args, "preset", None)
+    if preset_mm is not None:
+        mm_fit_kwargs["presets"] = preset_mm
+
+    predictor.fit(**mm_fit_kwargs)
+    return predictor
+
+
+# ---------------------- evaluation ----------------------
+def evaluate_predictor_all_splits(
+    predictor,
+    df_train: Optional[pd.DataFrame],
+    df_val: Optional[pd.DataFrame],
+    df_test: Optional[pd.DataFrame],
+    label_col: str,
+    problem_type: str,
+    eval_metric: Optional[str],
+    threshold_test: Optional[float],
+    df_test_external: Optional[pd.DataFrame] = None,
+) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
+    """
+    Returns (raw_metrics, ag_scores_by_split)
+      - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic)
+      - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default)
+    """
+    metrics_req = None if (eval_metric is None or str(eval_metric).lower() == "auto") else [eval_metric]
+    ag_by_split: Dict[str, Dict[str, float]] = {}
+
+    if df_train is not None and len(df_train):
+        ag_by_split["Train"] = ag_evaluate_safely(predictor, df_train, metrics=metrics_req)
+    if df_val is not None and len(df_val):
+        ag_by_split["Validation"] = ag_evaluate_safely(predictor, df_val, metrics=metrics_req)
+
+    df_test_effective = df_test_external if df_test_external is not None else df_test
+    if df_test_effective is not None and len(df_test_effective):
+        ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)
+
+    # Transparent suite (threshold on Test handled inside metrics_logic)
+    _, raw_metrics = evaluate_all_transparency(
+        predictor=predictor,
+        train_df=df_train,
+        val_df=df_val,
+        test_df=df_test_effective,
+        target_col=label_col,
+        problem_type=problem_type,
+        threshold=threshold_test,
+    )
+
+    if df_test_external is not None and df_test_external is not df_test and len(df_test_external):
+        raw_metrics["Test (external)"] = compute_metrics_for_split(
+            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
+        )
+        ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)
+
+    return raw_metrics, ag_by_split
+
+
+def fit_summary_safely(predictor) -> Optional[dict]:
+    """Get fit summary without printing misleading one-liners."""
+    with suppress_stdout_stderr():
+        try:
+            return predictor.fit_summary()
+        except Exception:
+            return None
+
+
+# ---------------------- image helpers ----------------------
+_PLACEHOLDER_PATH = None
+
+
+def _create_placeholder() -> str:
+    global _PLACEHOLDER_PATH
+    if _PLACEHOLDER_PATH and os.path.exists(_PLACEHOLDER_PATH):
+        return _PLACEHOLDER_PATH
+
+    dir_ = Path(tempfile.mkdtemp(prefix="ag_placeholder_"))
+    file_ = dir_ / f"placeholder_{uuid.uuid4().hex}.png"
+
+    try:
+        from PIL import Image
+        Image.new("RGB", (64, 64), (180, 180, 180)).save(file_)
+    except Exception:
+        import matplotlib.pyplot as plt
+        import numpy as np
+        plt.imsave(file_, np.full((64, 64, 3), 180, dtype=np.uint8))
+        plt.close("all")
+
+    _PLACEHOLDER_PATH = str(file_)
+    logger.info(f"Placeholder image created: {file_}")
+    return _PLACEHOLDER_PATH
+
+
+def _is_valid_path(val) -> bool:
+    if pd.isna(val):
+        return False
+    s = str(val).strip()
+    return s and os.path.isfile(s)
+
+
+def handle_missing_images(
+    df: pd.DataFrame,
+    image_columns: List[str],
+    strategy: str = "false",
+) -> pd.DataFrame:
+    if not image_columns or df.empty:
+        return df
+
+    remove = str(strategy).lower() == "true"
+    masks = [~df[col].apply(_is_valid_path) for col in image_columns if col in df.columns]
+    if not masks:
+        return df
+
+    any_missing = pd.concat(masks, axis=1).any(axis=1)
+    n_missing = int(any_missing.sum())
+
+    if n_missing == 0:
+        return df
+
+    if remove:
+        result = df[~any_missing].reset_index(drop=True)
+        logger.info(f"Dropped {n_missing} rows with missing images → {len(result)} remain")
+    else:
+        placeholder = _create_placeholder()
+        result = df.copy()
+        for col in image_columns:
+            if col in result.columns:
+                result.loc[~result[col].apply(_is_valid_path), col] = placeholder
+        logger.info(f"Filled {n_missing} missing images with placeholder")
+
+    return result
+
+
+# ---------------------- AutoGluon config helpers ----------------------
+def autogluon_hyperparameters(
+    threshold,
+    time_limit,
+    random_seed,
+    epochs,
+    learning_rate,
+    batch_size,
+    backbone_image,
+    backbone_text,
+    preset,
+    eval_metric,
+    hyperparameters,
+):
+    """
+    Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs.
+    The returned dict separates what should be passed to predictor.fit (under ``fit``)
+    from the model/optimization configuration (under ``hyperparameters``). Threshold is
+    preserved for downstream evaluation but not passed into AutoGluon directly.
+    """
+
+    def _prune_empty(d: dict) -> dict:
+        cleaned = {}
+        for k, v in (d or {}).items():
+            if isinstance(v, dict):
+                nested = _prune_empty(v)
+                if nested:
+                    cleaned[k] = nested
+            elif v is not None:
+                cleaned[k] = v
+        return cleaned
+
+    # Base hyperparameters following the structure described in the AutoGluon
+    # customization guide (env / optimization / model).
+    env_cfg = {}
+    if random_seed is not None:
+        env_cfg["seed"] = int(random_seed)
+    if batch_size is not None:
+        env_cfg["per_gpu_batch_size"] = int(batch_size)
+
+    optim_cfg = {}
+    if epochs is not None:
+        optim_cfg["max_epochs"] = int(epochs)
+    if learning_rate is not None:
+        optim_cfg["learning_rate"] = float(learning_rate)
+    if batch_size is not None:
+        bs = int(batch_size)
+        optim_cfg["per_device_train_batch_size"] = bs
+        optim_cfg["train_batch_size"] = bs
+
+    model_cfg = {}
+    if eval_metric:
+        model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric)
+    if backbone_image:
+        model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image)
+    if backbone_text:
+        model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text)
+
+    hp = {
+        "env": env_cfg,
+        "optimization": optim_cfg,
+        "model": model_cfg,
+    }
+
+    # Also expose the most common dotted aliases for robustness across AG versions.
+    if epochs is not None:
+        hp["optimization.max_epochs"] = int(epochs)
+        hp["optim.max_epochs"] = int(epochs)
+    if learning_rate is not None:
+        lr_val = float(learning_rate)
+        hp["optimization.learning_rate"] = lr_val
+        hp["optimization.lr"] = lr_val
+        hp["optim.learning_rate"] = lr_val
+        hp["optim.lr"] = lr_val
+    if batch_size is not None:
+        bs_val = int(batch_size)
+        hp["optimization.per_device_train_batch_size"] = bs_val
+        hp["optimization.batch_size"] = bs_val
+        hp["optim.per_device_train_batch_size"] = bs_val
+        hp["optim.batch_size"] = bs_val
+        hp["env.per_gpu_batch_size"] = bs_val
+    if backbone_image:
+        hp["model.timm_image.checkpoint_name"] = str(backbone_image)
+    if backbone_text:
+        hp["model.hf_text.checkpoint_name"] = str(backbone_text)
+
+    # Merge user-provided hyperparameters (inline JSON or path) last so they win.
+    if isinstance(hyperparameters, dict):
+        user_hp = hyperparameters
+    else:
+        user_hp = load_user_hparams(hyperparameters)
+    hp = deep_update(hp, user_hp)
+    hp = _prune_empty(hp)
+
+    fit_cfg = {}
+    if time_limit is not None:
+        fit_cfg["time_limit"] = time_limit
+    if random_seed is not None:
+        fit_cfg["seed"] = int(random_seed)
+    if preset:
+        fit_cfg["presets"] = preset
+
+    config = {
+        "fit": fit_cfg,
+        "hyperparameters": hp,
+    }
+    if threshold is not None:
+        config["threshold"] = float(threshold)
+
+    return config
+
+
+def run_autogluon_experiment(
+    train_dataset: pd.DataFrame,
+    test_dataset: Optional[pd.DataFrame],
+    target_column: str,
+    image_columns: Optional[List[str]],
+    ag_config: dict,
+):
+    """
+    Launch an AutoGluon MultiModal training run using the config from
+    autogluon_hyperparameters(). Returns (predictor, context dict) so callers
+    can evaluate downstream with the chosen threshold.
+    """
+    if ag_config is None:
+        raise ValueError("ag_config is required to launch AutoGluon training.")
+
+    hyperparameters = ag_config.get("hyperparameters") or {}
+    fit_cfg = dict(ag_config.get("fit") or {})
+    threshold = ag_config.get("threshold")
+
+    if "split" not in train_dataset.columns:
+        raise ValueError("train_dataset must contain a 'split' column. Did you call split_dataset?")
+
+    df_train = train_dataset[train_dataset["split"] == "train"].copy()
+    df_val = train_dataset[train_dataset["split"].isin(["val", "validation"])].copy()
+    df_test_internal = train_dataset[train_dataset["split"] == "test"].copy()
+
+    predictor = MultiModalPredictor(label=target_column, path=None)
+    column_types = {c: "image_path" for c in (image_columns or [])}
+
+    fit_kwargs = {
+        "train_data": df_train,
+        "hyperparameters": hyperparameters,
+    }
+    fit_kwargs.update(fit_cfg)
+    if not df_val.empty:
+        fit_kwargs.setdefault("tuning_data", df_val)
+    if column_types:
+        fit_kwargs.setdefault("column_types", column_types)
+
+    logger.info(
+        "Fitting AutoGluon with %d train / %d val rows (internal test rows: %d, external test provided: %s)",
+        len(df_train),
+        len(df_val),
+        len(df_test_internal),
+        (test_dataset is not None and not test_dataset.empty),
+    )
+    predictor.fit(**fit_kwargs)
+
+    return predictor, {
+        "train": df_train,
+        "val": df_val,
+        "test_internal": df_test_internal,
+        "test_external": test_dataset,
+        "threshold": threshold,
+    }