# training_pipeline.py
# planemo upload for repository https://github.com/goeckslab/gleam.git
# commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0 (goeckslab, 2025-12-09)

from __future__ import annotations

import contextlib
import importlib
import io
import json
import logging
import os
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from autogluon.multimodal import MultiModalPredictor
from metrics_logic import compute_metrics_for_split, evaluate_all_transparency
from packaging.version import Version

logger = logging.getLogger(__name__)

# ---------------------- small utilities ----------------------


def load_user_hparams(hp_arg: Optional[str]) -> dict:
    """Parse --hyperparameters (inline JSON or path to .json)."""
    if not hp_arg:
        return {}
    try:
        s = hp_arg.strip()
        if s.startswith("{"):
            return json.loads(s)
        with open(s, "r") as f:
            return json.load(f)
    except Exception as e:
        logger.warning(f"Could not parse --hyperparameters: {e}. Ignoring.")
        return {}
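
# Both accepted input forms resolve to the same dict; a minimal sketch with
# illustrative values (the overrides file name is hypothetical):
#
#   >>> load_user_hparams('{"optimization": {"max_epochs": 3}}')
#   {'optimization': {'max_epochs': 3}}
#   >>> load_user_hparams("overrides.json")  # file containing the same JSON
#   {'optimization': {'max_epochs': 3}}
#
# Malformed input logs a warning and yields {} instead of aborting the run.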


def deep_update(dst: dict, src: dict) -> dict:
    """Recursive dict update (src overrides dst)."""
    for k, v in (src or {}).items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            deep_update(dst[k], v)
        else:
            dst[k] = v
    return dst
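
# deep_update merges nested dicts leaf-by-leaf, so user overrides replace only
# the keys they name (illustrative values):
#
#   >>> base = {"optimization": {"max_epochs": 5, "learning_rate": 1e-4}}
#   >>> deep_update(base, {"optimization": {"max_epochs": 10}})
#   {'optimization': {'max_epochs': 10, 'learning_rate': 0.0001}}
#
# Note that ``dst`` is mutated in place; pass a copy if the original dict must
# survive unchanged.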


@contextlib.contextmanager
def suppress_stdout_stderr():
    """Silence noisy prints from AG internals (fit_summary)."""
    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
        yield


def ag_evaluate_safely(predictor, df: Optional[pd.DataFrame], metrics: Optional[List[str]] = None) -> Dict[str, float]:
    """
    Call predictor.evaluate and normalize the output to a dict.
    """
    if df is None or len(df) == 0:
        return {}
    try:
        res = predictor.evaluate(df, metrics=metrics)
    except TypeError:
        if metrics and len(metrics) == 1:
            res = predictor.evaluate(df, metrics[0])
        else:
            res = predictor.evaluate(df)
    if isinstance(res, (int, float, np.floating)):
        name = (metrics[0] if metrics else "metric")
        return {name: float(res)}
    if isinstance(res, dict):
        return {k: float(v) for k, v in res.items()}
    return {"metric": float(res)}


# ---------------------- hparams & training ----------------------
def build_mm_hparams(args, df_train: pd.DataFrame, image_columns: Optional[List[str]]) -> dict:
    """
    Build hyperparameters for MultiModalPredictor.
    Handles text checkpoints for torch<2.6 and merges user overrides.
    """
    text_cols = [
        c for c in df_train.columns
        if c != args.label_column
        and str(df_train[c].dtype) == "object"
        and df_train[c].notna().any()
    ]

    ag_version = None
    try:
        ag_ver = getattr(importlib.import_module("autogluon"), "__version__", None)
        if ag_ver:
            ag_version = Version(str(ag_ver))
    except Exception:
        pass

    hp = {}

    # Setup environment
    hp["env"] = {
        "seed": int(args.random_seed)
    }

    # Set eval metric through model config
    model_block = hp.setdefault("model", {})
    if args.eval_metric:
        model_block.setdefault("metric_learning", {})["metric"] = str(args.eval_metric)

    if text_cols and Version(torch.__version__) < Version("2.6"):
        safe_ckpt = "distilbert-base-uncased"
        logger.warning(
            f"torch {torch.__version__} < 2.6: forcing HF text checkpoint with safetensors weights: {safe_ckpt}"
        )
        hp["model.hf_text.checkpoint_name"] = safe_ckpt
        hp.setdefault(
            "model.names",
            ["hf_text", "timm_image", "numerical_mlp", "categorical_mlp", "fusion_mlp"],
        )

    def _is_valid_hp_dict(d) -> bool:
        if not isinstance(d, dict):
            logger.warning("User-supplied hyperparameters must be a dict; received %s", type(d).__name__)
            return False
        return True

    user_hp = args.hyperparameters if isinstance(args.hyperparameters, dict) else load_user_hparams(args.hyperparameters)
    if user_hp and _is_valid_hp_dict(user_hp):
        hp = deep_update(hp, user_hp)

    # Map CLI knobs into AutoMM optimization hyperparameters when provided.
    # We set both 'optim.*' and 'optimization.*' dotted flat keys to maximize
    # compatibility across AutoMM/AutoGluon versions.
    try:
        if getattr(args, "epochs", None) is not None:
            epochs = int(args.epochs)
            for key in ("optim.max_epochs", "optim.epochs", "optimization.max_epochs", "optimization.epochs"):
                hp[key] = epochs
        if getattr(args, "learning_rate", None) is not None:
            lr = float(args.learning_rate)
            for key in ("optim.learning_rate", "optim.lr", "optimization.learning_rate", "optimization.lr"):
                hp[key] = lr
        if getattr(args, "batch_size", None) is not None:
            bs = int(args.batch_size)
            for key in (
                "optim.batch_size",
                "optim.per_device_train_batch_size",
                "optimization.batch_size",
                "optimization.per_device_train_batch_size",
            ):
                hp[key] = bs
    except Exception:
        logger.warning("Failed to attach epochs/learning_rate/batch_size to mm_hparams; continuing without them.")

    # Map backbone selections into mm_hparams if provided.
    try:
        has_text_cols = bool(text_cols)
        has_image_cols = bool(image_columns)

        if has_text_cols and getattr(args, "backbone_text", None):
            text_choice = str(args.backbone_text)
            model_block.setdefault("hf_text", {})["checkpoint_name"] = text_choice
            hp["model.hf_text.checkpoint_name"] = text_choice
        if has_image_cols and getattr(args, "backbone_image", None):
            image_choice = str(args.backbone_image)
            model_block.setdefault("timm_image", {})["checkpoint_name"] = image_choice
            hp["model.timm_image.checkpoint_name"] = image_choice
    except Exception:
        logger.warning("Failed to attach backbone selections to mm_hparams; continuing without them.")

    if ag_version:
        logger.info(f"Detected AutoGluon version: {ag_version}; applied robust hyperparameter mappings.")

    return hp
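
# The returned mapping mixes nested blocks with dotted aliases; a sketch of
# what --epochs 10 --learning_rate 1e-4 --random_seed 42 might produce
# (abridged; exact keys depend on the CLI values supplied):
#
#   {
#       "env": {"seed": 42},
#       "model": {...},
#       "optim.max_epochs": 10,
#       "optimization.max_epochs": 10,
#       "optim.learning_rate": 0.0001,
#       "optimization.learning_rate": 0.0001,
#       ...
#   }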


def train_predictor(
    args,
    df_train: pd.DataFrame,
    df_val: pd.DataFrame,
    image_columns: Optional[List[str]],
    mm_hparams: dict,
):
    """
    Train a MultiModalPredictor, honoring common knobs (presets, eval_metric, etc.).
    """
    logger.info("Starting AutoGluon MultiModal training...")
    predictor = MultiModalPredictor(label=args.label_column, path=None)
    # Mark image columns explicitly so AutoGluon treats them as image paths.
    column_types = {c: "image_path" for c in (image_columns or [])}

    mm_fit_kwargs = dict(
        train_data=df_train,
        time_limit=args.time_limit,
        seed=int(args.random_seed),
        hyperparameters=mm_hparams,
    )
    if df_val is not None and not df_val.empty:
        mm_fit_kwargs["tuning_data"] = df_val
    if column_types:
        mm_fit_kwargs["column_types"] = column_types

    preset_mm = getattr(args, "presets", None)
    if preset_mm is None:
        preset_mm = getattr(args, "preset", None)
    if preset_mm is not None:
        mm_fit_kwargs["presets"] = preset_mm

    predictor.fit(**mm_fit_kwargs)
    return predictor


# ---------------------- evaluation ----------------------
def evaluate_predictor_all_splits(
    predictor,
    df_train: Optional[pd.DataFrame],
    df_val: Optional[pd.DataFrame],
    df_test: Optional[pd.DataFrame],
    label_col: str,
    problem_type: str,
    eval_metric: Optional[str],
    threshold_test: Optional[float],
    df_test_external: Optional[pd.DataFrame] = None,
) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
    """
    Returns (raw_metrics, ag_scores_by_split)
      - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic)
      - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default)
    """
    metrics_req = None if (eval_metric is None or str(eval_metric).lower() == "auto") else [eval_metric]
    ag_by_split: Dict[str, Dict[str, float]] = {}

    if df_train is not None and len(df_train):
        ag_by_split["Train"] = ag_evaluate_safely(predictor, df_train, metrics=metrics_req)
    if df_val is not None and len(df_val):
        ag_by_split["Validation"] = ag_evaluate_safely(predictor, df_val, metrics=metrics_req)

    df_test_effective = df_test_external if df_test_external is not None else df_test
    if df_test_effective is not None and len(df_test_effective):
        ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)

    # Transparent suite (threshold on Test handled inside metrics_logic)
    _, raw_metrics = evaluate_all_transparency(
        predictor=predictor,
        train_df=df_train,
        val_df=df_val,
        test_df=df_test_effective,
        target_col=label_col,
        problem_type=problem_type,
        threshold=threshold_test,
    )

    if df_test_external is not None and df_test_external is not df_test and len(df_test_external):
        raw_metrics["Test (external)"] = compute_metrics_for_split(
            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
        )
        ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)

    return raw_metrics, ag_by_split
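
# Shape of the two returned mappings (scores are hypothetical):
#
#   raw_metrics = {"Train": {...}, "Validation": {...}, "Test": {...}}
#   ag_by_split = {"Train": {"roc_auc": 0.93}, "Validation": {"roc_auc": 0.90},
#                  "Test": {"roc_auc": 0.88}}
#
# When an external test set is supplied, both mappings also gain a
# "Test (external)" entry.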


def fit_summary_safely(predictor) -> Optional[dict]:
    """Get fit summary without printing misleading one-liners."""
    with suppress_stdout_stderr():
        try:
            return predictor.fit_summary()
        except Exception:
            return None


# ---------------------- image helpers ----------------------
_PLACEHOLDER_PATH = None


def _create_placeholder() -> str:
    global _PLACEHOLDER_PATH
    if _PLACEHOLDER_PATH and os.path.exists(_PLACEHOLDER_PATH):
        return _PLACEHOLDER_PATH

    dir_ = Path(tempfile.mkdtemp(prefix="ag_placeholder_"))
    file_ = dir_ / f"placeholder_{uuid.uuid4().hex}.png"

    try:
        from PIL import Image
        Image.new("RGB", (64, 64), (180, 180, 180)).save(file_)
    except Exception:
        # Pillow is preferred; fall back to matplotlib if it is unavailable.
        import matplotlib.pyplot as plt
        plt.imsave(file_, np.full((64, 64, 3), 180, dtype=np.uint8))
        plt.close("all")

    _PLACEHOLDER_PATH = str(file_)
    logger.info(f"Placeholder image created: {file_}")
    return _PLACEHOLDER_PATH


def _is_valid_path(val) -> bool:
    if pd.isna(val):
        return False
    s = str(val).strip()
    return bool(s) and os.path.isfile(s)


def handle_missing_images(
    df: pd.DataFrame,
    image_columns: List[str],
    strategy: str = "false",
) -> pd.DataFrame:
    """Drop rows with invalid image paths (strategy="true") or replace them
    with a generated placeholder image (strategy="false")."""
    if not image_columns or df.empty:
        return df

    remove = str(strategy).lower() == "true"
    masks = [~df[col].apply(_is_valid_path) for col in image_columns if col in df.columns]
    if not masks:
        return df

    any_missing = pd.concat(masks, axis=1).any(axis=1)
    n_missing = int(any_missing.sum())

    if n_missing == 0:
        return df

    if remove:
        result = df[~any_missing].reset_index(drop=True)
        logger.info(f"Dropped {n_missing} rows with missing images → {len(result)} remain")
    else:
        placeholder = _create_placeholder()
        result = df.copy()
        for col in image_columns:
            if col in result.columns:
                result.loc[~result[col].apply(_is_valid_path), col] = placeholder
        logger.info(f"Filled {n_missing} missing images with placeholder")

    return result
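
# A minimal sketch of both strategies (paths are hypothetical):
#
#   df = pd.DataFrame({"img": ["/data/cat.png", None], "y": [0, 1]})
#   handle_missing_images(df, ["img"], strategy="true")   # drops row 1
#   handle_missing_images(df, ["img"], strategy="false")  # fills row 1 with a
#                                                         # placeholder PNG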


# ---------------------- AutoGluon config helpers ----------------------
def autogluon_hyperparameters(
    threshold,
    time_limit,
    random_seed,
    epochs,
    learning_rate,
    batch_size,
    backbone_image,
    backbone_text,
    preset,
    eval_metric,
    hyperparameters,
):
    """
    Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs.
    The returned dict separates what should be passed to predictor.fit (under ``fit``)
    from the model/optimization configuration (under ``hyperparameters``). Threshold is
    preserved for downstream evaluation but not passed into AutoGluon directly.
    """

    def _prune_empty(d: dict) -> dict:
        cleaned = {}
        for k, v in (d or {}).items():
            if isinstance(v, dict):
                nested = _prune_empty(v)
                if nested:
                    cleaned[k] = nested
            elif v is not None:
                cleaned[k] = v
        return cleaned

    # Base hyperparameters following the structure described in the AutoGluon
    # customization guide (env / optimization / model).
    env_cfg = {}
    if random_seed is not None:
        env_cfg["seed"] = int(random_seed)
    if batch_size is not None:
        env_cfg["per_gpu_batch_size"] = int(batch_size)

    optim_cfg = {}
    if epochs is not None:
        optim_cfg["max_epochs"] = int(epochs)
    if learning_rate is not None:
        optim_cfg["learning_rate"] = float(learning_rate)
    if batch_size is not None:
        bs = int(batch_size)
        optim_cfg["per_device_train_batch_size"] = bs
        optim_cfg["train_batch_size"] = bs

    model_cfg = {}
    if eval_metric:
        model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric)
    if backbone_image:
        model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image)
    if backbone_text:
        model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text)

    hp = {
        "env": env_cfg,
        "optimization": optim_cfg,
        "model": model_cfg,
    }

    # Also expose the most common dotted aliases for robustness across AG versions.
    if epochs is not None:
        hp["optimization.max_epochs"] = int(epochs)
        hp["optim.max_epochs"] = int(epochs)
    if learning_rate is not None:
        lr_val = float(learning_rate)
        hp["optimization.learning_rate"] = lr_val
        hp["optimization.lr"] = lr_val
        hp["optim.learning_rate"] = lr_val
        hp["optim.lr"] = lr_val
    if batch_size is not None:
        bs_val = int(batch_size)
        hp["optimization.per_device_train_batch_size"] = bs_val
        hp["optimization.batch_size"] = bs_val
        hp["optim.per_device_train_batch_size"] = bs_val
        hp["optim.batch_size"] = bs_val
        hp["env.per_gpu_batch_size"] = bs_val
    if backbone_image:
        hp["model.timm_image.checkpoint_name"] = str(backbone_image)
    if backbone_text:
        hp["model.hf_text.checkpoint_name"] = str(backbone_text)

    # Merge user-provided hyperparameters (inline JSON or path) last so they win.
    if isinstance(hyperparameters, dict):
        user_hp = hyperparameters
    else:
        user_hp = load_user_hparams(hyperparameters)
    hp = deep_update(hp, user_hp)
    hp = _prune_empty(hp)

    fit_cfg = {}
    if time_limit is not None:
        fit_cfg["time_limit"] = time_limit
    if random_seed is not None:
        fit_cfg["seed"] = int(random_seed)
    if preset:
        fit_cfg["presets"] = preset

    config = {
        "fit": fit_cfg,
        "hyperparameters": hp,
    }
    if threshold is not None:
        config["threshold"] = float(threshold)

    return config
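
# A sketch of the returned config for a hypothetical call (abridged; only the
# non-None knobs appear in the output):
#
#   autogluon_hyperparameters(
#       threshold=0.5, time_limit=3600, random_seed=42, epochs=10,
#       learning_rate=1e-4, batch_size=None, backbone_image=None,
#       backbone_text=None, preset="medium_quality", eval_metric=None,
#       hyperparameters=None,
#   )
#   # -> {"fit": {"time_limit": 3600, "seed": 42, "presets": "medium_quality"},
#   #     "hyperparameters": {"env": {"seed": 42},
#   #                         "optimization": {"max_epochs": 10, ...},
#   #                         "optimization.max_epochs": 10, ...},
#   #     "threshold": 0.5}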


def run_autogluon_experiment(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    image_columns: Optional[List[str]],
    ag_config: dict,
):
    """
    Launch an AutoGluon MultiModal training run using the config from
    autogluon_hyperparameters(). Returns (predictor, context dict) so callers
    can evaluate downstream with the chosen threshold.
    """
    if ag_config is None:
        raise ValueError("ag_config is required to launch AutoGluon training.")

    hyperparameters = ag_config.get("hyperparameters") or {}
    fit_cfg = dict(ag_config.get("fit") or {})
    threshold = ag_config.get("threshold")

    if "split" not in train_dataset.columns:
        raise ValueError("train_dataset must contain a 'split' column. Did you call split_dataset?")

    df_train = train_dataset[train_dataset["split"] == "train"].copy()
    df_val = train_dataset[train_dataset["split"].isin(["val", "validation"])].copy()
    df_test_internal = train_dataset[train_dataset["split"] == "test"].copy()

    predictor = MultiModalPredictor(label=target_column, path=None)
    column_types = {c: "image_path" for c in (image_columns or [])}

    fit_kwargs = {
        "train_data": df_train,
        "hyperparameters": hyperparameters,
    }
    fit_kwargs.update(fit_cfg)
    if not df_val.empty:
        fit_kwargs.setdefault("tuning_data", df_val)
    if column_types:
        fit_kwargs.setdefault("column_types", column_types)

    logger.info(
        "Fitting AutoGluon with %d train / %d val rows (internal test rows: %d, external test provided: %s)",
        len(df_train),
        len(df_val),
        len(df_test_internal),
        (test_dataset is not None and not test_dataset.empty),
    )
    predictor.fit(**fit_kwargs)

    return predictor, {
        "train": df_train,
        "val": df_val,
        "test_internal": df_test_internal,
        "test_external": test_dataset,
        "threshold": threshold,
    }
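
# Typical end-to-end wiring (a sketch; the dataset, column names, and
# problem_type are illustrative assumptions about the caller):
#
#   config = autogluon_hyperparameters(
#       threshold=0.5, time_limit=600, random_seed=42, epochs=5,
#       learning_rate=1e-4, batch_size=None, backbone_image=None,
#       backbone_text=None, preset=None, eval_metric=None, hyperparameters=None,
#   )
#   predictor, ctx = run_autogluon_experiment(
#       train_dataset=df_with_split_column,  # must carry a 'split' column
#       test_dataset=None,
#       target_column="label",
#       image_columns=["image_path"],
#       ag_config=config,
#   )
#   raw, ag_scores = evaluate_predictor_all_splits(
#       predictor, ctx["train"], ctx["val"], ctx["test_internal"],
#       label_col="label", problem_type="binary", eval_metric=None,
#       threshold_test=ctx["threshold"],
#   )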