multimodal_learner.py @ 0:375c36923da1 (draft, default tip)
| field | value |
|---|---|
| description | planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0 |
| author | goeckslab |
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | (none) |
| children | (none) |
```python
#!/usr/bin/env python
"""
Main entrypoint for AutoGluon multimodal training wrapper.
"""

import argparse
import logging
import os
import sys
from typing import List, Optional

import pandas as pd
from metrics_logic import aggregate_metrics
from plot_logic import infer_problem_type
from report_utils import write_outputs
from sklearn.model_selection import KFold, StratifiedKFold
from split_logic import split_dataset
from test_pipeline import run_autogluon_test_experiment
from training_pipeline import (
    autogluon_hyperparameters,
    handle_missing_images,
    run_autogluon_experiment,
)
# ------------------------------------------------------------------
# Local imports (split and preprocessing utilities)
# ------------------------------------------------------------------
from utils import (
    absolute_path_expander,
    enable_deterministic_mode,
    enable_tensor_cores_if_available,
    ensure_local_tmp,
    load_file,
    prepare_image_search_dirs,
    set_seeds,
    str2bool,
)

# ------------------------------------------------------------------
# Logger setup
# ------------------------------------------------------------------
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Argument parsing
# ------------------------------------------------------------------
def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Train & report an AutoGluon model")

    parser.add_argument("--input_csv_train", dest="train_dataset", required=True)
    parser.add_argument("--input_csv_test", dest="test_dataset", default=None)
    parser.add_argument("--target_column", required=True)
    parser.add_argument("--output_json", default="results.json")
    parser.add_argument("--output_html", default="report.html")
    parser.add_argument("--output_config", default=None)
    parser.add_argument("--images_zip", nargs="*", default=None,
                        help="One or more ZIP files that contain image assets")
    parser.add_argument("--missing_image_strategy", default="false",
                        help="true: drop rows with missing images; false: keep them and use a placeholder")
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--time_limit", type=int, default=None)
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Enable deterministic algorithms to reduce run-to-run variance")
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--cross_validation", type=str, default="false")
    parser.add_argument("--num_folds", type=int, default=5)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
    parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
    parser.add_argument("--validation_size", type=float, default=0.2)
    parser.add_argument("--split_probabilities", type=float, nargs=3,
                        default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
    parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
                        default="medium_quality")
    parser.add_argument("--eval_metric", default="roc_auc")
    parser.add_argument("--hyperparameters", default=None)

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        logger.warning("Ignoring unknown CLI tokens: %s", unknown)

    # -------------------------- Validation --------------------------
    if not (0.0 <= args.validation_size <= 1.0):
        parser.error("--validation_size must be in [0, 1]")
    if len(args.split_probabilities) != 3 or abs(sum(args.split_probabilities) - 1.0) > 1e-6:
        parser.error("--split_probabilities must be three numbers summing to 1.0")
    if str2bool(args.cross_validation) and args.num_folds < 2:
        parser.error("--num_folds must be >= 2 when --cross_validation is true")

    return args
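
# Example invocation (an illustrative sketch; the file names and the label
# column below are placeholders, not shipped defaults):
#
#   python multimodal_learner.py \
#       --input_csv_train train.csv \
#       --input_csv_test test.csv \
#       --target_column label \
#       --images_zip images.zip \
#       --cross_validation true --num_folds 5 \
#       --preset medium_quality --eval_metric roc_auc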


def run_cross_validation(
    args,
    df_full: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    image_cols: List[str],
    ag_config: dict,
):
    """Cross-validation loop returning aggregated metrics and last predictor."""
    df_full = df_full.drop(columns=["split"], errors="ignore")
    y = df_full[args.target_column]
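    # Heuristic, as coded: treat the target as categorical (and stratify the
    # folds) when it has object dtype or at most 20 distinct values;
    # otherwise, or if inspection fails, fall back to plain KFold.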
    try:
        use_stratified = y.dtype == object or y.nunique() <= 20
    except Exception:
        use_stratified = False

    splitter_cls = StratifiedKFold if use_stratified else KFold
    kf = splitter_cls(
        n_splits=int(args.num_folds),
        shuffle=True,
        random_state=int(args.random_seed),
    )

    raw_folds = []
    ag_folds = []
    folds_info = []
    last_predictor = None
    last_data_ctx = None

    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1):
        logger.info(f"CV fold {fold_idx}/{args.num_folds}")
        df_tr = df_full.iloc[train_idx].copy()
        df_va = df_full.iloc[val_idx].copy()

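        # Tag each row so the training pipeline can recover this fold's
        # train/validation partition from the single concatenated frame.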
        df_tr["split"] = "train"
        df_va["split"] = "val"
        fold_dataset = pd.concat([df_tr, df_va], ignore_index=True)

        predictor_fold, data_ctx = run_autogluon_experiment(
            train_dataset=fold_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        last_predictor = predictor_fold
        last_data_ctx = data_ctx
        problem_type = infer_problem_type(predictor_fold, df_tr, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor_fold,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )

        raw_metrics_fold = eval_results.get("raw_metrics", {})
        ag_by_split_fold = eval_results.get("ag_eval", {})
        raw_folds.append(raw_metrics_fold)
        ag_folds.append(ag_by_split_fold)
        folds_info.append(
            {
                "fold": int(fold_idx),
                "predictor_path": getattr(predictor_fold, "path", None),
                "raw_metrics": raw_metrics_fold,
                "ag_eval": ag_by_split_fold,
            }
        )

    raw_metrics_mean, raw_metrics_std = aggregate_metrics(raw_folds)
    ag_by_split_mean, ag_by_split_std = aggregate_metrics(ag_folds)
    return (
        last_predictor,
        raw_metrics_mean,
        ag_by_split_mean,
        raw_folds,
        ag_folds,
        raw_metrics_std,
        ag_by_split_std,
        folds_info,
        last_data_ctx,
    )

# ------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------
def main():
    args = parse_args()

    # ------------------------------------------------------------------
    # Debug output
    # ------------------------------------------------------------------
    logger.info("=== AutoGluon Training Wrapper Started ===")
    logger.info(f"Working directory: {os.getcwd()}")
    logger.info(f"Command line: {' '.join(sys.argv)}")
    logger.info(f"Parsed args: {vars(args)}")

    # ------------------------------------------------------------------
    # Reproducibility & performance
    # ------------------------------------------------------------------
    set_seeds(args.random_seed)
    if args.deterministic:
        enable_deterministic_mode(args.random_seed)
        logger.info("Deterministic mode enabled (seed=%s)", args.random_seed)
    ensure_local_tmp()
    enable_tensor_cores_if_available()

    # ------------------------------------------------------------------
    # Load datasets
    # ------------------------------------------------------------------
    train_dataset = load_file(args.train_dataset)
    test_dataset = load_file(args.test_dataset) if args.test_dataset else None

    logger.info(f"Train dataset loaded: {len(train_dataset)} rows")
    if test_dataset is not None:
        logger.info(f"Test dataset loaded: {len(test_dataset)} rows")

    # ------------------------------------------------------------------
    # Resolve target column by name; if Galaxy passed a numeric index,
    # translate it to the corresponding header so downstream checks pass.
    # Galaxy's data_column widget is 1-based.
    # ------------------------------------------------------------------
    if args.target_column not in train_dataset.columns and str(args.target_column).isdigit():
        idx = int(args.target_column) - 1
        if 0 <= idx < len(train_dataset.columns):
            resolved = train_dataset.columns[idx]
            logger.info(f"Target column '{args.target_column}' not found; using column #{idx + 1} header '{resolved}' instead.")
            args.target_column = resolved
        else:
            logger.error(f"Numeric target index '{args.target_column}' is out of range for dataset with {len(train_dataset.columns)} columns.")
            sys.exit(1)
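    # Example of the translation above: with headers ['id', 'text', 'label']
    # (hypothetical), a Galaxy value of --target_column 3 resolves to 'label'.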

    # ------------------------------------------------------------------
    # Image handling (ZIP extraction + absolute path expansion)
    # ------------------------------------------------------------------
    extracted_imgs_path = prepare_image_search_dirs(args)

    image_cols = absolute_path_expander(train_dataset, extracted_imgs_path, None)
    if test_dataset is not None:
        absolute_path_expander(test_dataset, extracted_imgs_path, image_cols)

    # ------------------------------------------------------------------
    # Handle missing images
    # ------------------------------------------------------------------
    train_dataset = handle_missing_images(
        train_dataset,
        image_columns=image_cols,
        strategy=args.missing_image_strategy,
    )
    if test_dataset is not None:
        test_dataset = handle_missing_images(
            test_dataset,
            image_columns=image_cols,
            strategy=args.missing_image_strategy,
        )

    logger.info(f"After cleanup → train: {len(train_dataset)}, test: {len(test_dataset) if test_dataset is not None else 0}")

    # ------------------------------------------------------------------
    # Dataset splitting logic (adds 'split' column to train_dataset)
    # ------------------------------------------------------------------
    split_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        target_column=args.target_column,
        split_probabilities=args.split_probabilities,
        validation_size=args.validation_size,
        random_seed=args.random_seed,
    )
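    # Each train_dataset row now carries a 'split' label (presumably
    # 'train'/'val'/'test', matching the metavar names above); the
    # value_counts log below reports the resulting sizes.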

    logger.info("Preprocessing complete — ready for AutoGluon training!")
    logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}")

    # Verify the target column exists in train (and test, if provided)
    if args.target_column not in train_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in training data.")
        sys.exit(1)
    if test_dataset is not None and args.target_column not in test_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in test data.")
        sys.exit(1)

    # Threshold is only meaningful for binary classification; ignore otherwise.
    threshold_for_run = args.threshold
    unique_labels = None
    target_looks_binary = False
    try:
        unique_labels = train_dataset[args.target_column].nunique(dropna=True)
        target_looks_binary = unique_labels == 2
    except Exception:
        logger.warning("Could not inspect target column '%s' for threshold validation; proceeding without binary check.", args.target_column)

    if threshold_for_run is not None:
        if target_looks_binary:
            threshold_for_run = float(threshold_for_run)
            logger.info("Applying custom decision threshold %.4f for binary evaluation.", threshold_for_run)
        else:
            logger.warning(
                "Threshold %.3f provided but target '%s' does not appear binary (unique labels=%s); ignoring threshold.",
                threshold_for_run,
                args.target_column,
                unique_labels if unique_labels is not None else "unknown",
            )
            threshold_for_run = None
    args.threshold = threshold_for_run
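    # A kept threshold travels to evaluation via args/ag_config; presumably it
    # follows the usual binary convention there (predict the positive class
    # when P(positive) >= threshold).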

    # Image columns are auto-inferred; image_cols already resolved to absolute paths.

    # ------------------------------------------------------------------
    # Build AutoGluon configuration from CLI knobs
    # ------------------------------------------------------------------
    ag_config = autogluon_hyperparameters(
        threshold=args.threshold,
        time_limit=args.time_limit,
        random_seed=args.random_seed,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        backbone_image=args.backbone_image,
        backbone_text=args.backbone_text,
        preset=args.preset,
        eval_metric=args.eval_metric,
        hyperparameters=args.hyperparameters,
    )
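    # As the log line below shows, ag_config carries at least a 'fit' section
    # and a 'hyperparameters' mapping that the pipeline helpers consume.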
    logger.info(f"AutoGluon config prepared: fit={ag_config.get('fit')}, hyperparameters keys={list(ag_config.get('hyperparameters', {}).keys())}")

    cv_enabled = str2bool(args.cross_validation)
    if cv_enabled:
        (
            predictor,
            raw_metrics,
            ag_by_split,
            raw_folds,
            ag_folds,
            raw_metrics_std,
            ag_by_split_std,
            folds_info,
            data_ctx,
        ) = run_cross_validation(
            args=args,
            df_full=train_dataset,
            test_dataset=test_dataset,
            image_cols=image_cols,
            ag_config=ag_config,
        )
        if predictor is None:
            logger.error("All CV folds failed. Exiting.")
            sys.exit(1)
        eval_results = {
            "raw_metrics": raw_metrics,
            "ag_eval": ag_by_split,
            "fit_summary": None,
        }
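        # fit_summary is left as None on the CV path: each fold trains its
        # own predictor, so there is no single fit summary to report.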
    else:
        predictor, data_ctx = run_autogluon_experiment(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        logger.info("AutoGluon training finished. Model path: %s", getattr(predictor, "path", None))

        # Evaluate predictor on Train/Val/Test splits
        problem_type = infer_problem_type(predictor, train_dataset, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )
        raw_metrics = eval_results.get("raw_metrics", {})
        ag_by_split = eval_results.get("ag_eval", {})
        raw_folds = ag_folds = raw_metrics_std = ag_by_split_std = None

    logger.info("Transparent metrics by split: %s", eval_results["raw_metrics"])
    logger.info("AutoGluon evaluate() by split: %s", eval_results["ag_eval"])

    if "problem_type" in eval_results:
        problem_type_final = eval_results["problem_type"]
    else:
        problem_type_final = infer_problem_type(predictor, train_dataset, args.target_column)

    write_outputs(
        args=args,
        predictor=predictor,
        problem_type=problem_type_final,
        eval_results=eval_results,
        data_ctx=data_ctx,
        raw_folds=raw_folds,
        ag_folds=ag_folds,
        raw_metrics_std=raw_metrics_std,
        ag_by_split_std=ag_by_split_std,
    )


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S",
    )
    # Quiet noisy image parsing logs (e.g., PIL.PngImagePlugin debug streams)
    logging.getLogger("PIL").setLevel(logging.WARNING)
    logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
    main()
```
