view multimodal_learner.py @ 0:375c36923da1 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000
parents
children
line wrap: on
line source

#!/usr/bin/env python
"""
Main entrypoint for AutoGluon multimodal training wrapper.
"""

import argparse
import logging
import os
import sys
from typing import List, Optional

import pandas as pd
from metrics_logic import aggregate_metrics
from plot_logic import infer_problem_type
from report_utils import write_outputs
from sklearn.model_selection import KFold, StratifiedKFold
from split_logic import split_dataset
from test_pipeline import run_autogluon_test_experiment
from training_pipeline import autogluon_hyperparameters, handle_missing_images, run_autogluon_experiment
# ------------------------------------------------------------------
# Local utility helpers (path expansion, seeding, file loading, flag parsing)
# ------------------------------------------------------------------
from utils import (
    absolute_path_expander,
    enable_deterministic_mode,
    enable_tensor_cores_if_available,
    ensure_local_tmp,
    load_file,
    prepare_image_search_dirs,
    set_seeds,
    str2bool,
)

# ------------------------------------------------------------------
# Logger setup
# ------------------------------------------------------------------
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Argument parsing and validation
# ------------------------------------------------------------------
def parse_args(argv=None):
    """
    Parse and validate the command-line options of the training wrapper.

    Args:
        argv: Optional list of CLI tokens. When ``None`` argparse falls back
            to ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. Note that ``--input_csv_train``
        and ``--input_csv_test`` are stored as ``train_dataset`` and
        ``test_dataset`` respectively.

    Raises:
        SystemExit: via ``parser.error`` when a required option is missing or
            a value fails the validation rules below.
    """
    parser = argparse.ArgumentParser(description="Train & report an AutoGluon model")

    parser.add_argument("--input_csv_train", dest="train_dataset", required=True)
    parser.add_argument("--input_csv_test", dest="test_dataset", default=None)
    parser.add_argument("--target_column", required=True)
    parser.add_argument("--output_json", default="results.json")
    parser.add_argument("--output_html", default="report.html")
    parser.add_argument("--output_config", default=None)
    parser.add_argument("--images_zip", nargs="*", default=None,
                        help="One or more ZIP files that contain image assets")
    parser.add_argument("--missing_image_strategy", default="false",
                        help="true/false: remove rows with missing images or use placeholder")
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--time_limit", type=int, default=None)
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Enable deterministic algorithms to reduce run-to-run variance")
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--cross_validation", type=str, default="false")
    parser.add_argument("--num_folds", type=int, default=5)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
    parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
    parser.add_argument("--validation_size", type=float, default=0.2)
    parser.add_argument("--split_probabilities", type=float, nargs=3,
                        default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
    parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
                        default="medium_quality")
    parser.add_argument("--eval_metric", default="roc_auc")
    parser.add_argument("--hyperparameters", default=None)

    # Unknown tokens are tolerated (and logged) rather than treated as errors.
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        logger.warning("Ignoring unknown CLI tokens: %s", unknown)

    # -------------------------- Validation --------------------------
    if not (0.0 <= args.validation_size <= 1.0):
        parser.error("--validation_size must be in [0, 1]")
    # Check each value on its own first: a sum-only test lets through inputs
    # such as "1.5 -0.5 0.0", and NaN slips past it too because every
    # comparison against NaN is False.
    if not all(0.0 <= p <= 1.0 for p in args.split_probabilities):
        parser.error("--split_probabilities values must each be in [0, 1]")
    if len(args.split_probabilities) != 3 or abs(sum(args.split_probabilities) - 1.0) > 1e-6:
        parser.error("--split_probabilities must be three numbers summing to 1.0")
    # NOTE(review): only the literal "true" (any case) triggers this guard,
    # whereas main() interprets the flag with str2bool, which may accept other
    # spellings — confirm the two agree.
    if args.cross_validation.lower() == "true" and (args.num_folds < 2):
        parser.error("--num_folds must be >= 2 when --cross_validation is true")

    return args


def run_cross_validation(
    args,
    df_full: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    image_cols: List[str],
    ag_config: dict,
):
    """
    Train and evaluate one predictor per cross-validation fold.

    Any pre-existing "split" column is discarded; every fold is relabelled
    with fresh "train"/"val" markers before being passed to
    run_autogluon_experiment. StratifiedKFold is used when the target has
    object dtype or at most 20 distinct values, otherwise plain KFold.

    Returns:
        A 9-tuple, in this order: predictor of the last fold, the first value
        returned by aggregate_metrics for the raw metrics, the same for the
        AutoGluon evaluation, the per-fold raw metrics list, the per-fold
        AutoGluon evaluation list, the second aggregate_metrics value for raw
        metrics, the same for the AutoGluon evaluation, a list of per-fold
        info dicts, and the data context of the last fold.
    """
    df_full = df_full.drop(columns=["split"], errors="ignore")
    y = df_full[args.target_column]

    # Inspecting the target is best-effort: any failure falls back to KFold.
    try:
        use_stratified = y.dtype == object or y.nunique() <= 20
    except Exception:
        use_stratified = False

    splitter_cls = StratifiedKFold if use_stratified else KFold
    kf = splitter_cls(
        n_splits=int(args.num_folds),
        shuffle=True,
        random_state=int(args.random_seed),
    )
    # Only the stratified splitter is given the labels.
    fold_indices = kf.split(df_full, y if use_stratified else None)

    raw_folds = []
    ag_folds = []
    folds_info = []
    last_predictor = None
    last_data_ctx = None

    for fold_idx, (train_idx, val_idx) in enumerate(fold_indices, start=1):
        logger.info(f"CV fold {fold_idx}/{args.num_folds}")

        # Tag the rows of this fold and stack them into one frame.
        df_tr = df_full.iloc[train_idx].assign(split="train")
        df_va = df_full.iloc[val_idx].assign(split="val")
        fold_dataset = pd.concat([df_tr, df_va], ignore_index=True)

        last_predictor, last_data_ctx = run_autogluon_experiment(
            train_dataset=fold_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )

        eval_results = run_autogluon_test_experiment(
            predictor=last_predictor,
            data_ctx=last_data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=infer_problem_type(last_predictor, df_tr, args.target_column),
        )

        fold_raw = eval_results.get("raw_metrics", {})
        fold_ag = eval_results.get("ag_eval", {})
        raw_folds.append(fold_raw)
        ag_folds.append(fold_ag)
        folds_info.append(
            {
                "fold": int(fold_idx),
                "predictor_path": getattr(last_predictor, "path", None),
                "raw_metrics": fold_raw,
                "ag_eval": fold_ag,
            }
        )

    raw_metrics_mean, raw_metrics_std = aggregate_metrics(raw_folds)
    ag_by_split_mean, ag_by_split_std = aggregate_metrics(ag_folds)

    return (
        last_predictor,
        raw_metrics_mean,
        ag_by_split_mean,
        raw_folds,
        ag_folds,
        raw_metrics_std,
        ag_by_split_std,
        folds_info,
        last_data_ctx,
    )


# ------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------
def _resolve_target_column(args, train_dataset, test_dataset):
    """
    Resolve ``args.target_column`` to a header name and verify it exists.

    If the value is not a column name but is all digits, it is treated as a
    1-based column index (Galaxy's data_column widget is 1-based) and replaced
    in place on ``args`` with the matching header. The process exits with
    status 1 when the index is out of range or when the resolved column is
    absent from the training data or from the optional test data.
    """
    if args.target_column not in train_dataset.columns and str(args.target_column).isdigit():
        idx = int(args.target_column) - 1
        if 0 <= idx < len(train_dataset.columns):
            resolved = train_dataset.columns[idx]
            logger.info(f"Target column '{args.target_column}' not found; using column #{idx + 1} header '{resolved}' instead.")
            args.target_column = resolved
        else:
            logger.error(f"Numeric target index '{args.target_column}' is out of range for dataset with {len(train_dataset.columns)} columns.")
            sys.exit(1)

    # Fail fast: every later step (splitting, training) reads this column.
    if args.target_column not in train_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in training data.")
        sys.exit(1)
    if test_dataset is not None and args.target_column not in test_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in test data.")
        sys.exit(1)


def _resolve_threshold(args, train_dataset):
    """
    Return the decision threshold to use, or ``None`` when it does not apply.

    A user-supplied threshold is kept only when the training target has
    exactly two distinct non-null labels; otherwise a warning is logged and
    the threshold is dropped.
    """
    threshold = args.threshold
    unique_labels = None
    target_looks_binary = False
    try:
        unique_labels = train_dataset[args.target_column].nunique(dropna=True)
        target_looks_binary = unique_labels == 2
    except Exception:
        logger.warning("Could not inspect target column '%s' for threshold validation; proceeding without binary check.", args.target_column)

    if threshold is None:
        return None

    if target_looks_binary:
        threshold = float(threshold)
        logger.info("Applying custom decision threshold %.4f for binary evaluation.", threshold)
        return threshold

    logger.warning(
        "Threshold %.3f provided but target '%s' does not appear binary (unique labels=%s); ignoring threshold.",
        threshold,
        args.target_column,
        unique_labels if unique_labels is not None else "unknown",
    )
    return None


def main():
    """
    Run the full pipeline: parse args, load and clean data, split, train
    (single run or cross-validation), evaluate, and write the outputs.

    Exits with status 1 when the target column cannot be resolved or when
    cross-validation yields no predictor.
    """
    args = parse_args()

    # ------------------------------------------------------------------
    # Debug output
    # ------------------------------------------------------------------
    logger.info("=== AutoGluon Training Wrapper Started ===")
    logger.info(f"Working directory: {os.getcwd()}")
    logger.info(f"Command line: {' '.join(sys.argv)}")
    logger.info(f"Parsed args: {vars(args)}")

    # ------------------------------------------------------------------
    # Reproducibility & performance
    # ------------------------------------------------------------------
    set_seeds(args.random_seed)
    if args.deterministic:
        enable_deterministic_mode(args.random_seed)
        logger.info("Deterministic mode enabled (seed=%s)", args.random_seed)
    ensure_local_tmp()
    enable_tensor_cores_if_available()

    # ------------------------------------------------------------------
    # Load datasets
    # ------------------------------------------------------------------
    train_dataset = load_file(args.train_dataset)
    test_dataset = load_file(args.test_dataset) if args.test_dataset else None

    logger.info(f"Train dataset loaded: {len(train_dataset)} rows")
    if test_dataset is not None:
        logger.info(f"Test dataset loaded: {len(test_dataset)} rows")

    # ------------------------------------------------------------------
    # Resolve and verify the target column up front, before any step that
    # consumes it (split_dataset below receives target_column).
    # ------------------------------------------------------------------
    _resolve_target_column(args, train_dataset, test_dataset)

    # ------------------------------------------------------------------
    # Image handling (ZIP extraction + absolute path expansion)
    # ------------------------------------------------------------------
    extracted_imgs_path = prepare_image_search_dirs(args)

    image_cols = absolute_path_expander(train_dataset, extracted_imgs_path, None)
    if test_dataset is not None:
        absolute_path_expander(test_dataset, extracted_imgs_path, image_cols)

    # ------------------------------------------------------------------
    # Handle missing images
    # ------------------------------------------------------------------
    train_dataset = handle_missing_images(
        train_dataset,
        image_columns=image_cols,
        strategy=args.missing_image_strategy,
    )
    if test_dataset is not None:
        test_dataset = handle_missing_images(
            test_dataset,
            image_columns=image_cols,
            strategy=args.missing_image_strategy,
        )

    logger.info(f"After cleanup → train: {len(train_dataset)}, test: {len(test_dataset) if test_dataset is not None else 0}")

    # ------------------------------------------------------------------
    # Dataset splitting logic (adds 'split' column to train_dataset)
    # NOTE(review): the return value is ignored, so this relies on
    # split_dataset mutating train_dataset in place — confirm.
    # ------------------------------------------------------------------
    split_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        target_column=args.target_column,
        split_probabilities=args.split_probabilities,
        validation_size=args.validation_size,
        random_seed=args.random_seed,
    )

    logger.info("Preprocessing complete — ready for AutoGluon training!")
    logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}")

    # Threshold is only meaningful for binary classification; ignore otherwise.
    args.threshold = _resolve_threshold(args, train_dataset)

    # Image columns are auto-inferred; image_cols already resolved to absolute paths.
    # ------------------------------------------------------------------
    # Build AutoGluon configuration from CLI knobs
    # ------------------------------------------------------------------
    ag_config = autogluon_hyperparameters(
        threshold=args.threshold,
        time_limit=args.time_limit,
        random_seed=args.random_seed,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        backbone_image=args.backbone_image,
        backbone_text=args.backbone_text,
        preset=args.preset,
        eval_metric=args.eval_metric,
        hyperparameters=args.hyperparameters,
    )
    logger.info(f"AutoGluon config prepared: fit={ag_config.get('fit')}, hyperparameters keys={list(ag_config.get('hyperparameters', {}).keys())}")

    cv_enabled = str2bool(args.cross_validation)
    if cv_enabled:
        (
            predictor,
            raw_metrics,
            ag_by_split,
            raw_folds,
            ag_folds,
            raw_metrics_std,
            ag_by_split_std,
            _folds_info,  # per-fold details; not consumed by the report below
            data_ctx,
        ) = run_cross_validation(
            args=args,
            df_full=train_dataset,
            test_dataset=test_dataset,
            image_cols=image_cols,
            ag_config=ag_config,
        )
        if predictor is None:
            logger.error("All CV folds failed. Exiting.")
            sys.exit(1)
        eval_results = {
            "raw_metrics": raw_metrics,
            "ag_eval": ag_by_split,
            "fit_summary": None,
        }
    else:
        predictor, data_ctx = run_autogluon_experiment(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        logger.info("AutoGluon training finished. Model path: %s", getattr(predictor, "path", None))

        # Evaluate predictor on Train/Val/Test splits
        problem_type = infer_problem_type(predictor, train_dataset, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )
        raw_metrics = eval_results.get("raw_metrics", {})
        ag_by_split = eval_results.get("ag_eval", {})
        # No per-fold data exists outside cross-validation.
        raw_folds = ag_folds = raw_metrics_std = ag_by_split_std = None

    logger.info("Transparent metrics by split: %s", eval_results["raw_metrics"])
    logger.info("AutoGluon evaluate() by split: %s", eval_results["ag_eval"])

    # Prefer the problem type reported by evaluation; otherwise infer it.
    if "problem_type" in eval_results:
        problem_type_final = eval_results["problem_type"]
    else:
        problem_type_final = infer_problem_type(predictor, train_dataset, args.target_column)

    write_outputs(
        args=args,
        predictor=predictor,
        problem_type=problem_type_final,
        eval_results=eval_results,
        data_ctx=data_ctx,
        raw_folds=raw_folds,
        ag_folds=ag_folds,
        raw_metrics_std=raw_metrics_std,
        ag_by_split_std=ag_by_split_std,
    )


if __name__ == "__main__":
    # Root logger: INFO level, timestamped "time | level | message" lines.
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S",
        level=logging.INFO,
    )
    # Quiet noisy image parsing logs (e.g., PIL.PngImagePlugin debug streams)
    for noisy_logger_name in ("PIL", "PIL.PngImagePlugin"):
        logging.getLogger(noisy_logger_name).setLevel(logging.WARNING)
    main()