diff multimodal_learner.py @ 0:375c36923da1 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/multimodal_learner.py	Tue Dec 09 23:49:47 2025 +0000
@@ -0,0 +1,391 @@
+#!/usr/bin/env python
+"""
+Main entrypoint for AutoGluon multimodal training wrapper.
+"""
+
+import argparse
+import logging
+import os
+import sys
+from typing import List, Optional
+
+import pandas as pd
+from metrics_logic import aggregate_metrics
+from plot_logic import infer_problem_type
+from report_utils import write_outputs
+from sklearn.model_selection import KFold, StratifiedKFold
+from split_logic import split_dataset
+from test_pipeline import run_autogluon_test_experiment
+from training_pipeline import (
+    autogluon_hyperparameters,
+    handle_missing_images,
+    run_autogluon_experiment,
+)
+# ------------------------------------------------------------------
+# Local imports (split and preprocessing utilities)
+# ------------------------------------------------------------------
+from utils import (
+    absolute_path_expander,
+    enable_deterministic_mode,
+    enable_tensor_cores_if_available,
+    ensure_local_tmp,
+    load_file,
+    prepare_image_search_dirs,
+    set_seeds,
+    str2bool,
+)
+
+# ------------------------------------------------------------------
+# Logger setup
+# ------------------------------------------------------------------
+logger = logging.getLogger(__name__)
+
+
+# ------------------------------------------------------------------
+# Argument parsing and validation
+# ------------------------------------------------------------------
+def parse_args(argv=None):
+    parser = argparse.ArgumentParser(description="Train & report an AutoGluon model")
+
+    parser.add_argument("--input_csv_train", dest="train_dataset", required=True)
+    parser.add_argument("--input_csv_test", dest="test_dataset", default=None)
+    parser.add_argument("--target_column", required=True)
+    parser.add_argument("--output_json", default="results.json")
+    parser.add_argument("--output_html", default="report.html")
+    parser.add_argument("--output_config", default=None)
+    parser.add_argument("--images_zip", nargs="*", default=None,
+                        help="One or more ZIP files that contain image assets")
+    parser.add_argument("--missing_image_strategy", default="false",
+                        help="true/false: remove rows with missing images or use placeholder")
+    parser.add_argument("--threshold", type=float, default=None)
+    parser.add_argument("--time_limit", type=int, default=None)
+    parser.add_argument("--deterministic", action="store_true", default=False,
+                        help="Enable deterministic algorithms to reduce run-to-run variance")
+    parser.add_argument("--random_seed", type=int, default=42)
+    parser.add_argument("--cross_validation", type=str, default="false")
+    parser.add_argument("--num_folds", type=int, default=5)
+    parser.add_argument("--epochs", type=int, default=None)
+    parser.add_argument("--learning_rate", type=float, default=None)
+    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
+    parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
+    parser.add_argument("--validation_size", type=float, default=0.2)
+    parser.add_argument("--split_probabilities", type=float, nargs=3,
+                        default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
+    parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
+                        default="medium_quality")
+    parser.add_argument("--eval_metric", default="roc_auc")
+    parser.add_argument("--hyperparameters", default=None)
+
+    args, unknown = parser.parse_known_args(argv)
+    if unknown:
+        logger.warning("Ignoring unknown CLI tokens: %s", unknown)
+
+    # -------------------------- Validation --------------------------
+    if not (0.0 <= args.validation_size <= 1.0):
+        parser.error("--validation_size must be in [0, 1]")
+    if len(args.split_probabilities) != 3 or abs(sum(args.split_probabilities) - 1.0) > 1e-6:
+        parser.error("--split_probabilities must be three numbers summing to 1.0")
+    if str2bool(args.cross_validation) and args.num_folds < 2:
+        parser.error("--num_folds must be >= 2 when --cross_validation is true")
+
+    return args
+
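+# Example invocation (illustrative; the file names below are hypothetical):
+#   python multimodal_learner.py \
+#       --input_csv_train train.csv --input_csv_test test.csv \
+#       --target_column label --images_zip images.zip \
+#       --cross_validation true --num_folds 5 --eval_metric roc_auc
+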
+
+def run_cross_validation(
+    args,
+    df_full: pd.DataFrame,
+    test_dataset: Optional[pd.DataFrame],
+    image_cols: List[str],
+    ag_config: dict,
+):
+    """Cross-validation loop returning aggregated metrics and last predictor."""
+    df_full = df_full.drop(columns=["split"], errors="ignore")
+    y = df_full[args.target_column]
+    # Stratify when the target looks categorical (object dtype or few distinct values).
+    try:
+        use_stratified = y.dtype == object or y.nunique() <= 20
+    except Exception:
+        use_stratified = False
+
+    splitter_cls = StratifiedKFold if use_stratified else KFold
+    kf = splitter_cls(
+        n_splits=int(args.num_folds),
+        shuffle=True,
+        random_state=int(args.random_seed),
+    )
+
+    raw_folds = []
+    ag_folds = []
+    folds_info = []
+    last_predictor = None
+    last_data_ctx = None
+
+    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1):
+        logger.info(f"CV fold {fold_idx}/{args.num_folds}")
+        df_tr = df_full.iloc[train_idx].copy()
+        df_va = df_full.iloc[val_idx].copy()
+
+        df_tr["split"] = "train"
+        df_va["split"] = "val"
+        fold_dataset = pd.concat([df_tr, df_va], ignore_index=True)
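+        # The concatenated frame carries an explicit "split" column; the
+        # training pipeline is expected to use it to separate fit rows from
+        # validation rows.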
+
+        predictor_fold, data_ctx = run_autogluon_experiment(
+            train_dataset=fold_dataset,
+            test_dataset=test_dataset,
+            target_column=args.target_column,
+            image_columns=image_cols,
+            ag_config=ag_config,
+        )
+        last_predictor = predictor_fold
+        last_data_ctx = data_ctx
+        problem_type = infer_problem_type(predictor_fold, df_tr, args.target_column)
+        eval_results = run_autogluon_test_experiment(
+            predictor=predictor_fold,
+            data_ctx=data_ctx,
+            target_column=args.target_column,
+            eval_metric=args.eval_metric,
+            ag_config=ag_config,
+            problem_type=problem_type,
+        )
+
+        raw_metrics_fold = eval_results.get("raw_metrics", {})
+        ag_by_split_fold = eval_results.get("ag_eval", {})
+        raw_folds.append(raw_metrics_fold)
+        ag_folds.append(ag_by_split_fold)
+        folds_info.append(
+            {
+                "fold": int(fold_idx),
+                "predictor_path": getattr(predictor_fold, "path", None),
+                "raw_metrics": raw_metrics_fold,
+                "ag_eval": ag_by_split_fold,
+            }
+        )
+
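+    # aggregate_metrics (from metrics_logic) is assumed to reduce a list of
+    # per-fold metric dicts into a (mean, std) pair of dicts keyed by metric name.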
+    raw_metrics_mean, raw_metrics_std = aggregate_metrics(raw_folds)
+    ag_by_split_mean, ag_by_split_std = aggregate_metrics(ag_folds)
+    return (
+        last_predictor,
+        raw_metrics_mean,
+        ag_by_split_mean,
+        raw_folds,
+        ag_folds,
+        raw_metrics_std,
+        ag_by_split_std,
+        folds_info,
+        last_data_ctx,
+    )
+
+
+# ------------------------------------------------------------------
+# Main execution
+# ------------------------------------------------------------------
+def main():
+    args = parse_args()
+
+    # ------------------------------------------------------------------
+    # Debug output
+    # ------------------------------------------------------------------
+    logger.info("=== AutoGluon Training Wrapper Started ===")
+    logger.info(f"Working directory: {os.getcwd()}")
+    logger.info(f"Command line: {' '.join(sys.argv)}")
+    logger.info(f"Parsed args: {vars(args)}")
+
+    # ------------------------------------------------------------------
+    # Reproducibility & performance
+    # ------------------------------------------------------------------
+    set_seeds(args.random_seed)
+    if args.deterministic:
+        enable_deterministic_mode(args.random_seed)
+        logger.info("Deterministic mode enabled (seed=%s)", args.random_seed)
+    ensure_local_tmp()
+    enable_tensor_cores_if_available()
+
+    # ------------------------------------------------------------------
+    # Load datasets
+    # ------------------------------------------------------------------
+    train_dataset = load_file(args.train_dataset)
+    test_dataset = load_file(args.test_dataset) if args.test_dataset else None
+
+    logger.info(f"Train dataset loaded: {len(train_dataset)} rows")
+    if test_dataset is not None:
+        logger.info(f"Test dataset loaded: {len(test_dataset)} rows")
+
+    # ------------------------------------------------------------------
+    # Resolve target column by name; if Galaxy passed a numeric index,
+    # translate it to the corresponding header so downstream checks pass.
+    # Galaxy's data_column widget is 1-based.
+    # ------------------------------------------------------------------
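+    # For example, "--target_column 2" against columns ["image", "label"]
+    # resolves to the header "label".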
+    if args.target_column not in train_dataset.columns and str(args.target_column).isdigit():
+        idx = int(args.target_column) - 1
+        if 0 <= idx < len(train_dataset.columns):
+            resolved = train_dataset.columns[idx]
+            logger.info(f"Target column '{args.target_column}' not found; using column #{idx + 1} header '{resolved}' instead.")
+            args.target_column = resolved
+        else:
+            logger.error(f"Numeric target index '{args.target_column}' is out of range for dataset with {len(train_dataset.columns)} columns.")
+            sys.exit(1)
+
+    # ------------------------------------------------------------------
+    # Image handling (ZIP extraction + absolute path expansion)
+    # ------------------------------------------------------------------
+    extracted_imgs_path = prepare_image_search_dirs(args)
+
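+    # absolute_path_expander is assumed to rewrite relative image file names to
+    # absolute paths under the extracted directories; the first call (third
+    # argument None) auto-detects the image columns, and the second call reuses
+    # that list so train and test resolve identically.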
+    image_cols = absolute_path_expander(train_dataset, extracted_imgs_path, None)
+    if test_dataset is not None:
+        absolute_path_expander(test_dataset, extracted_imgs_path, image_cols)
+
+    # ------------------------------------------------------------------
+    # Handle missing images
+    # ------------------------------------------------------------------
+    train_dataset = handle_missing_images(
+        train_dataset,
+        image_columns=image_cols,
+        strategy=args.missing_image_strategy,
+    )
+    if test_dataset is not None:
+        test_dataset = handle_missing_images(
+            test_dataset,
+            image_columns=image_cols,
+            strategy=args.missing_image_strategy,
+        )
+
+    logger.info(f"After cleanup → train: {len(train_dataset)}, test: {len(test_dataset) if test_dataset is not None else 0}")
+
+    # ------------------------------------------------------------------
+    # Dataset splitting logic (adds 'split' column to train_dataset)
+    # ------------------------------------------------------------------
+    split_dataset(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        target_column=args.target_column,
+        split_probabilities=args.split_probabilities,
+        validation_size=args.validation_size,
+        random_seed=args.random_seed,
+    )
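+    # split_dataset is expected to mutate train_dataset in place, tagging each
+    # row with a "split" label that the counts logged below summarise.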
+
+    logger.info("Preprocessing complete — ready for AutoGluon training!")
+    logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}")
+
+    # Verify target/image/text columns exist
+    if args.target_column not in train_dataset.columns:
+        logger.error(f"Target column '{args.target_column}' not found in training data.")
+        sys.exit(1)
+    if test_dataset is not None and args.target_column not in test_dataset.columns:
+        logger.error(f"Target column '{args.target_column}' not found in test data.")
+        sys.exit(1)
+
+    # Threshold is only meaningful for binary classification; ignore otherwise.
+    threshold_for_run = args.threshold
+    unique_labels = None
+    target_looks_binary = False
+    try:
+        unique_labels = train_dataset[args.target_column].nunique(dropna=True)
+        target_looks_binary = unique_labels == 2
+    except Exception:
+        logger.warning("Could not inspect target column '%s' for threshold validation; proceeding without binary check.", args.target_column)
+
+    if threshold_for_run is not None:
+        if target_looks_binary:
+            threshold_for_run = float(threshold_for_run)
+            logger.info("Applying custom decision threshold %.4f for binary evaluation.", threshold_for_run)
+        else:
+            logger.warning(
+                "Threshold %.3f provided but target '%s' does not appear binary (unique labels=%s); ignoring threshold.",
+                threshold_for_run,
+                args.target_column,
+                unique_labels if unique_labels is not None else "unknown",
+            )
+            threshold_for_run = None
+    args.threshold = threshold_for_run
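+    # For example, "--threshold 0.3" on a two-class target lowers the positive
+    # class cutoff from the usual 0.5 (assuming downstream evaluation applies
+    # the threshold to predicted probabilities).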
+    # Image columns are auto-inferred; image_cols was already resolved to absolute paths.
+
+    # ------------------------------------------------------------------
+    # Build AutoGluon configuration from CLI knobs
+    # ------------------------------------------------------------------
+    ag_config = autogluon_hyperparameters(
+        threshold=args.threshold,
+        time_limit=args.time_limit,
+        random_seed=args.random_seed,
+        epochs=args.epochs,
+        learning_rate=args.learning_rate,
+        batch_size=args.batch_size,
+        backbone_image=args.backbone_image,
+        backbone_text=args.backbone_text,
+        preset=args.preset,
+        eval_metric=args.eval_metric,
+        hyperparameters=args.hyperparameters,
+    )
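+    # ag_config is a plain dict; judging by the accesses below, it carries at
+    # least a "fit" section and a "hyperparameters" mapping, roughly
+    # {"fit": {...}, "hyperparameters": {...}}.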
+    logger.info(f"AutoGluon config prepared: fit={ag_config.get('fit')}, hyperparameters keys={list(ag_config.get('hyperparameters', {}).keys())}")
+
+    cv_enabled = str2bool(args.cross_validation)
+    if cv_enabled:
+        (
+            predictor,
+            raw_metrics,
+            ag_by_split,
+            raw_folds,
+            ag_folds,
+            raw_metrics_std,
+            ag_by_split_std,
+            folds_info,
+            data_ctx,
+        ) = run_cross_validation(
+            args=args,
+            df_full=train_dataset,
+            test_dataset=test_dataset,
+            image_cols=image_cols,
+            ag_config=ag_config,
+        )
+        if predictor is None:
+            logger.error("All CV folds failed. Exiting.")
+            sys.exit(1)
+        eval_results = {
+            "raw_metrics": raw_metrics,
+            "ag_eval": ag_by_split,
+            "fit_summary": None,
+        }
+    else:
+        predictor, data_ctx = run_autogluon_experiment(
+            train_dataset=train_dataset,
+            test_dataset=test_dataset,
+            target_column=args.target_column,
+            image_columns=image_cols,
+            ag_config=ag_config,
+        )
+        logger.info("AutoGluon training finished. Model path: %s", getattr(predictor, "path", None))
+
+        # Evaluate predictor on Train/Val/Test splits
+        problem_type = infer_problem_type(predictor, train_dataset, args.target_column)
+        eval_results = run_autogluon_test_experiment(
+            predictor=predictor,
+            data_ctx=data_ctx,
+            target_column=args.target_column,
+            eval_metric=args.eval_metric,
+            ag_config=ag_config,
+            problem_type=problem_type,
+        )
+        raw_metrics = eval_results.get("raw_metrics", {})
+        ag_by_split = eval_results.get("ag_eval", {})
+        raw_folds = ag_folds = raw_metrics_std = ag_by_split_std = None
+
+    logger.info("Transparent metrics by split: %s", eval_results["raw_metrics"])
+    logger.info("AutoGluon evaluate() by split: %s", eval_results["ag_eval"])
+
+    if "problem_type" in eval_results:
+        problem_type_final = eval_results["problem_type"]
+    else:
+        problem_type_final = infer_problem_type(predictor, train_dataset, args.target_column)
+
+    write_outputs(
+        args=args,
+        predictor=predictor,
+        problem_type=problem_type_final,
+        eval_results=eval_results,
+        data_ctx=data_ctx,
+        raw_folds=raw_folds,
+        ag_folds=ag_folds,
+        raw_metrics_std=raw_metrics_std,
+        ag_by_split_std=ag_by_split_std,
+    )
+
+
+if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s | %(levelname)s | %(message)s",
+        datefmt="%H:%M:%S"
+    )
+    # Quiet noisy image parsing logs (e.g., PIL.PngImagePlugin debug streams)
+    logging.getLogger("PIL").setLevel(logging.WARNING)
+    logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
+    main()