Mercurial > repos > goeckslab > multimodal_learner
comparison training_pipeline.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:375c36923da1 |
|---|---|
| 1 from __future__ import annotations | |
| 2 | |
| 3 import contextlib | |
| 4 import importlib | |
| 5 import io | |
| 6 import json | |
| 7 import logging | |
| 8 import os | |
| 9 import tempfile | |
| 10 import uuid | |
| 11 from pathlib import Path | |
| 12 from typing import Dict, List, Optional, Tuple | |
| 13 | |
| 14 import numpy as np | |
| 15 import pandas as pd | |
| 16 import torch | |
| 17 from autogluon.multimodal import MultiModalPredictor | |
| 18 from metrics_logic import compute_metrics_for_split, evaluate_all_transparency | |
| 19 from packaging.version import Version | |
| 20 | |
| 21 logger = logging.getLogger(__name__) | |
| 22 | |
| 23 # ---------------------- small utilities ---------------------- | |
| 24 | |
| 25 | |
def load_user_hparams(hp_arg: Optional[str]) -> dict:
    """Parse the ``--hyperparameters`` value (inline JSON or path to a .json file).

    A value whose first non-blank character is ``{`` is decoded as inline
    JSON; anything else is treated as a path to a JSON file.

    Args:
        hp_arg: Raw CLI value. ``None`` or an empty string means "not given".

    Returns:
        The decoded JSON object. ``{}`` is returned when nothing was given,
        when the value cannot be read or parsed, or when the decoded JSON is
        not an object. Failures are logged as warnings, never raised.
    """
    if not hp_arg:
        return {}
    try:
        s = hp_arg.strip()
        if s.startswith("{"):
            parsed = json.loads(s)
        else:
            # Explicit encoding so the result does not depend on the locale.
            with open(s, "r", encoding="utf-8") as f:
                parsed = json.load(f)
    except Exception as e:
        # Deliberately best-effort: bad user input must not abort the run.
        logger.warning(f"Could not parse --hyperparameters: {e}. Ignoring.")
        return {}
    if not isinstance(parsed, dict):
        # A JSON file may hold a list/scalar; callers merge the result as a dict.
        logger.warning(
            "Could not parse --hyperparameters: expected a JSON object, got %s. Ignoring.",
            type(parsed).__name__,
        )
        return {}
    return parsed
| 39 | |
| 40 | |
def deep_update(dst: dict, src: dict) -> dict:
    """Merge ``src`` into ``dst`` in place and return ``dst``.

    When both sides hold a dict under the same key the merge recurses;
    in every other case the value from ``src`` replaces the one in ``dst``.
    A falsy ``src`` (``None`` or empty) leaves ``dst`` untouched.
    """
    if not src:
        return dst
    for key, incoming in src.items():
        both_are_dicts = isinstance(incoming, dict) and isinstance(dst.get(key), dict)
        if not both_are_dicts:
            dst[key] = incoming
            continue
        deep_update(dst[key], incoming)
    return dst
| 49 | |
| 50 | |
@contextlib.contextmanager
def suppress_stdout_stderr():
    """Context manager that discards everything written to stdout and stderr.

    Both streams are redirected to throw-away in-memory buffers for the
    duration of the ``with`` body and restored afterwards (used around
    ``fit_summary``).
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
        stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
        yield
| 56 | |
| 57 | |
def ag_evaluate_safely(predictor, df: Optional[pd.DataFrame], metrics: Optional[List[str]] = None) -> Dict[str, float]:
    """
    Call predictor.evaluate and normalize the output to a dict of floats.

    ``evaluate`` is first called with ``metrics=`` as a keyword. If that
    raises ``TypeError`` it is retried with the single metric passed
    positionally (when exactly one was requested) or with no metric at all.

    Args:
        predictor: Object exposing ``evaluate(data, ...)``.
        df: Data to evaluate on; ``None`` or an empty frame yields ``{}``.
        metrics: Optional list of metric names to request.

    Returns:
        ``{metric_name: value}``. A dict result is passed through with its
        values cast to ``float``. Any other result is cast to ``float`` and
        keyed by the first requested metric, or ``"metric"`` when none was
        requested.
    """
    if df is None or len(df) == 0:
        return {}
    try:
        res = predictor.evaluate(df, metrics=metrics)
    except TypeError:
        if metrics and len(metrics) == 1:
            res = predictor.evaluate(df, metrics[0])
        else:
            res = predictor.evaluate(df)
    if isinstance(res, dict):
        return {k: float(v) for k, v in res.items()}
    # Every scalar flavour (int, float, numpy integer/floating, ...) is keyed
    # by the requested metric so the name is never lost.
    name = metrics[0] if metrics else "metric"
    return {name: float(res)}
| 77 | |
| 78 | |
| 79 # ---------------------- hparams & training ---------------------- | |
| 80 def build_mm_hparams(args, df_train: pd.DataFrame, image_columns: Optional[List[str]]) -> dict: | |
| 81 """ | |
| 82 Build hyperparameters for MultiModalPredictor. | |
| 83 Handles text checkpoints for torch<2.6 and merges user overrides. | |
| 84 """ | |
| 85 inferred_text_cols = [ | |
| 86 c for c in df_train.columns | |
| 87 if c != args.label_column | |
| 88 and str(df_train[c].dtype) == "object" | |
| 89 and df_train[c].notna().any() | |
| 90 ] | |
| 91 text_cols = inferred_text_cols | |
| 92 | |
| 93 ag_version = None | |
| 94 try: | |
| 95 ag_mod = importlib.import_module("autogluon") | |
| 96 ag_ver = getattr(ag_mod, "__version__", None) | |
| 97 if ag_ver: | |
| 98 ag_version = Version(str(ag_ver)) | |
| 99 except Exception: | |
| 100 ag_mod = None | |
| 101 | |
| 102 def _log_missing_support(key: str) -> None: | |
| 103 logger.info( | |
| 104 "AutoGluon version %s does not expose '%s'; skipping override.", | |
| 105 ag_version or "unknown", | |
| 106 key, | |
| 107 ) | |
| 108 | |
| 109 hp = {} | |
| 110 | |
| 111 # Setup environment | |
| 112 hp["env"] = { | |
| 113 "seed": int(args.random_seed) | |
| 114 } | |
| 115 | |
| 116 # Set eval metric through model config | |
| 117 model_block = hp.setdefault("model", {}) | |
| 118 if args.eval_metric: | |
| 119 model_block.setdefault("metric_learning", {})["metric"] = str(args.eval_metric) | |
| 120 | |
| 121 if text_cols and Version(torch.__version__) < Version("2.6"): | |
| 122 safe_ckpt = "distilbert-base-uncased" | |
| 123 logger.warning(f"Forcing HF text checkpoint with safetensors: {safe_ckpt}") | |
| 124 hp["model.hf_text.checkpoint_name"] = safe_ckpt | |
| 125 hp.setdefault( | |
| 126 "model.names", | |
| 127 ["hf_text", "timm_image", "numerical_mlp", "categorical_mlp", "fusion_mlp"], | |
| 128 ) | |
| 129 | |
| 130 def _is_valid_hp_dict(d) -> bool: | |
| 131 if not isinstance(d, dict): | |
| 132 logger.warning("User-supplied hyperparameters must be a dict; received %s", type(d).__name__) | |
| 133 return False | |
| 134 return True | |
| 135 | |
| 136 user_hp = args.hyperparameters if isinstance(args.hyperparameters, dict) else load_user_hparams(args.hyperparameters) | |
| 137 if user_hp and _is_valid_hp_dict(user_hp): | |
| 138 hp = deep_update(hp, user_hp) | |
| 139 | |
| 140 # Map CLI knobs into AutoMM optimization hyperparameters when provided. | |
| 141 # We set multiple common key names (nested dicts and dotted flat keys) to | |
| 142 # maximize compatibility across AutoMM/AutoGluon versions. | |
| 143 try: | |
| 144 if any(getattr(args, param, None) is not None for param in ["epochs", "learning_rate", "batch_size"]): | |
| 145 if getattr(args, "epochs", None) is not None: | |
| 146 hp["optim.max_epochs"] = int(args.epochs) | |
| 147 hp["optim.epochs"] = int(args.epochs) | |
| 148 if getattr(args, "learning_rate", None) is not None: | |
| 149 hp["optim.learning_rate"] = float(args.learning_rate) | |
| 150 hp["optim.lr"] = float(args.learning_rate) | |
| 151 if getattr(args, "batch_size", None) is not None: | |
| 152 hp["optim.batch_size"] = int(args.batch_size) | |
| 153 hp["optim.per_device_train_batch_size"] = int(args.batch_size) | |
| 154 | |
| 155 # Also set dotted flat keys for max compatibility (e.g., 'optimization.max_epochs') | |
| 156 if getattr(args, "epochs", None) is not None: | |
| 157 hp["optimization.max_epochs"] = int(args.epochs) | |
| 158 hp["optimization.epochs"] = int(args.epochs) | |
| 159 if getattr(args, "learning_rate", None) is not None: | |
| 160 hp["optimization.learning_rate"] = float(args.learning_rate) | |
| 161 hp["optimization.lr"] = float(args.learning_rate) | |
| 162 if getattr(args, "batch_size", None) is not None: | |
| 163 hp["optimization.batch_size"] = int(args.batch_size) | |
| 164 hp["optimization.per_device_train_batch_size"] = int(args.batch_size) | |
| 165 except Exception: | |
| 166 logger.warning("Failed to attach epochs/learning_rate/batch_size to mm_hparams; continuing without them.") | |
| 167 | |
| 168 # Map backbone selections into mm_hparams if provided | |
| 169 try: | |
| 170 has_text_cols = bool(text_cols) | |
| 171 has_image_cols = False | |
| 172 model_names_cache: Optional[List[str]] = None | |
| 173 model_names_modified = False | |
| 174 | |
| 175 def _dedupe_preserve(seq: List[str]) -> List[str]: | |
| 176 seen = set() | |
| 177 ordered = [] | |
| 178 for item in seq: | |
| 179 if item in seen: | |
| 180 continue | |
| 181 seen.add(item) | |
| 182 ordered.append(item) | |
| 183 return ordered | |
| 184 | |
| 185 def _get_model_names() -> List[str]: | |
| 186 nonlocal model_names_cache | |
| 187 if model_names_cache is not None: | |
| 188 return model_names_cache | |
| 189 names = model_block.get("names") | |
| 190 if isinstance(names, list): | |
| 191 model_names_cache = list(names) | |
| 192 else: | |
| 193 model_names_cache = [] | |
| 194 if has_text_cols: | |
| 195 model_names_cache.append("hf_text") | |
| 196 if has_image_cols: | |
| 197 model_names_cache.append("timm_image") | |
| 198 model_names_cache.extend(["numerical_mlp", "categorical_mlp"]) | |
| 199 model_names_cache.append("fusion_mlp") | |
| 200 return model_names_cache | |
| 201 | |
| 202 def _set_model_names(new_names: List[str]) -> None: | |
| 203 nonlocal model_names_cache, model_names_modified | |
| 204 model_names_cache = new_names | |
| 205 model_names_modified = True | |
| 206 | |
| 207 if has_text_cols and getattr(args, "backbone_text", None): | |
| 208 text_choice = str(args.backbone_text) | |
| 209 model_block.setdefault("hf_text", {})["checkpoint_name"] = text_choice | |
| 210 hp["model.hf_text.checkpoint_name"] = text_choice | |
| 211 if has_image_cols and getattr(args, "backbone_image", None): | |
| 212 image_choice = str(args.backbone_image) | |
| 213 model_block.setdefault("timm_image", {})["checkpoint_name"] = image_choice | |
| 214 hp["model.timm_image.checkpoint_name"] = image_choice | |
| 215 if model_names_modified and model_names_cache is not None: | |
| 216 model_block["names"] = model_names_cache | |
| 217 except Exception: | |
| 218 logger.warning("Failed to attach backbone selections to mm_hparams; continuing without them.") | |
| 219 | |
| 220 if ag_version: | |
| 221 logger.info(f"Detected AutoGluon version: {ag_version}; applied robust hyperparameter mappings.") | |
| 222 | |
| 223 return hp | |
| 224 | |
| 225 | |
def train_predictor(
    args,
    df_train: pd.DataFrame,
    df_val: pd.DataFrame,
    image_columns: Optional[List[str]],
    mm_hparams: dict,
):
    """
    Train a MultiModalPredictor, honoring common knobs (presets, time limit, seed).

    Args:
        args: Parsed CLI namespace; ``label_column``, ``time_limit`` and
            ``random_seed`` are read directly, ``presets`` (or ``preset``)
            optionally.
        df_train: Training rows.
        df_val: Validation rows; passed as ``tuning_data`` when not empty.
        image_columns: Image-path column names. Those present in
            ``df_train`` are declared to AutoGluon as ``"image_path"``,
            matching what ``run_autogluon_experiment`` does.
        mm_hparams: Hyperparameters passed to ``fit`` unchanged.

    Returns:
        The fitted predictor.
    """
    logger.info("Starting AutoGluon MultiModal training...")
    predictor = MultiModalPredictor(label=args.label_column, path=None)
    # Only declare columns that exist in the training frame.
    column_types = {
        col: "image_path"
        for col in (image_columns or [])
        if col in df_train.columns
    }

    mm_fit_kwargs = dict(
        train_data=df_train,
        time_limit=args.time_limit,
        seed=int(args.random_seed),
        hyperparameters=mm_hparams,
    )
    if df_val is not None and not df_val.empty:
        mm_fit_kwargs["tuning_data"] = df_val
    if column_types:
        mm_fit_kwargs["column_types"] = column_types

    # Accept either spelling of the preset attribute; "presets" wins.
    preset_mm = getattr(args, "presets", None)
    if preset_mm is None:
        preset_mm = getattr(args, "preset", None)
    if preset_mm is not None:
        mm_fit_kwargs["presets"] = preset_mm

    predictor.fit(**mm_fit_kwargs)
    return predictor
| 259 | |
| 260 | |
| 261 # ---------------------- evaluation ---------------------- | |
def evaluate_predictor_all_splits(
    predictor,
    df_train: Optional[pd.DataFrame],
    df_val: Optional[pd.DataFrame],
    df_test: Optional[pd.DataFrame],
    label_col: str,
    problem_type: str,
    eval_metric: Optional[str],
    threshold_test: Optional[float],
    df_test_external: Optional[pd.DataFrame] = None,
) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]:
    """
    Evaluate ``predictor`` on every split that has rows.

    Returns (raw_metrics, ag_scores_by_split)
    - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic)
    - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default)

    The "Test" entry uses the external test set when one with rows is given,
    otherwise the internal ``df_test``. A non-empty external set that is a
    different object from ``df_test`` is also reported under "Test (external)".
    An ``eval_metric`` of ``None`` or ``"auto"`` requests AutoGluon's default.
    """
    def _has_rows(frame) -> bool:
        return frame is not None and len(frame) > 0

    metrics_req = None if (eval_metric is None or str(eval_metric).lower() == "auto") else [eval_metric]
    ag_by_split: Dict[str, Dict[str, float]] = {}

    if _has_rows(df_train):
        ag_by_split["Train"] = ag_evaluate_safely(predictor, df_train, metrics=metrics_req)
    if _has_rows(df_val):
        ag_by_split["Validation"] = ag_evaluate_safely(predictor, df_val, metrics=metrics_req)

    # An empty external frame must not shadow a usable internal test split.
    df_test_effective = df_test_external if _has_rows(df_test_external) else df_test
    if _has_rows(df_test_effective):
        ag_by_split["Test"] = ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req)

    # Transparent suite (threshold on Test handled inside metrics_logic)
    _, raw_metrics = evaluate_all_transparency(
        predictor=predictor,
        train_df=df_train,
        val_df=df_val,
        test_df=df_test_effective,
        target_col=label_col,
        problem_type=problem_type,
        threshold=threshold_test,
    )

    if _has_rows(df_test_external) and df_test_external is not df_test:
        raw_metrics["Test (external)"] = compute_metrics_for_split(
            predictor, df_test_external, label_col, problem_type, threshold=threshold_test
        )
        ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req)

    return raw_metrics, ag_by_split
| 308 | |
| 309 | |
def fit_summary_safely(predictor) -> Optional[dict]:
    """Return ``predictor.fit_summary()`` with its console output silenced.

    Any exception raised by ``fit_summary`` is swallowed and ``None`` is
    returned instead.
    """
    with suppress_stdout_stderr(), contextlib.suppress(Exception):
        return predictor.fit_summary()
    # Reached only when fit_summary() raised and the error was suppressed.
    return None
| 317 | |
| 318 | |
| 319 # ---------------------- image helpers ---------------------- | |
# Path of the placeholder image created by _create_placeholder (None until first use).
_PLACEHOLDER_PATH = None


def _create_placeholder() -> str:
    """Return the path of a 64x64 gray PNG, creating it on first use.

    The path is cached in the module-level ``_PLACEHOLDER_PATH`` and reused
    while the file still exists. The image is written with PIL; if that
    fails for any reason, matplotlib is used instead.
    """
    global _PLACEHOLDER_PATH
    cached = _PLACEHOLDER_PATH
    if cached and os.path.exists(cached):
        return cached

    folder = Path(tempfile.mkdtemp(prefix="ag_placeholder_"))
    target = folder / f"placeholder_{uuid.uuid4().hex}.png"

    try:
        from PIL import Image
        gray_image = Image.new("RGB", (64, 64), (180, 180, 180))
        gray_image.save(target)
    except Exception:
        import matplotlib.pyplot as plt
        import numpy as np
        pixels = np.full((64, 64, 3), 180, dtype=np.uint8)
        plt.imsave(target, pixels)
        plt.close("all")

    _PLACEHOLDER_PATH = str(target)
    logger.info(f"Placeholder image created: {target}")
    return _PLACEHOLDER_PATH
| 343 | |
| 344 | |
def _is_valid_path(val) -> bool:
    """Return True when ``val`` names an existing file.

    Missing values (as judged by ``pd.isna``) and blank strings are invalid.
    Surrounding whitespace is stripped before the filesystem check.
    """
    if pd.isna(val):
        return False
    s = str(val).strip()
    # bool() matters: a bare `s and ...` would return "" for blank input, and
    # a mask containing "" cannot be inverted with `~`.
    return bool(s) and os.path.isfile(s)
| 350 | |
| 351 | |
def handle_missing_images(
    df: pd.DataFrame,
    image_columns: List[str],
    strategy: str = "false",
) -> pd.DataFrame:
    """Drop or patch rows whose image paths do not point at a file.

    Args:
        df: Input frame.
        image_columns: Columns holding image paths; names absent from ``df``
            are ignored.
        strategy: ``"true"`` (case-insensitive) drops every row with at
            least one invalid image path; any other value replaces the
            invalid paths with a generated placeholder image.

    Returns:
        ``df`` itself when there is nothing to do (no image columns, empty
        frame, or no invalid paths). Otherwise a new frame: the surviving
        rows with a reset index, or a copy with placeholders filled in.
    """
    if not image_columns or df.empty:
        return df

    remove = str(strategy).lower() == "true"
    # One validity pass per column (each check touches the filesystem); the
    # masks are reused below. astype(bool) keeps the mask invertible even if
    # the checker returns a non-bool falsy value.
    invalid_by_col = {
        col: ~df[col].apply(_is_valid_path).astype(bool)
        for col in image_columns
        if col in df.columns
    }
    if not invalid_by_col:
        return df

    any_missing = pd.concat(list(invalid_by_col.values()), axis=1).any(axis=1)
    n_missing = int(any_missing.sum())

    if n_missing == 0:
        return df

    if remove:
        result = df[~any_missing].reset_index(drop=True)
        logger.info(f"Dropped {n_missing} rows with missing images → {len(result)} remain")
    else:
        placeholder = _create_placeholder()
        result = df.copy()
        for col, invalid in invalid_by_col.items():
            result.loc[invalid, col] = placeholder
        logger.info(f"Filled {n_missing} missing images with placeholder")

    return result
| 383 | |
| 384 | |
| 385 # ---------------------- AutoGluon config helpers ---------------------- | |
def autogluon_hyperparameters(
    threshold,
    time_limit,
    random_seed,
    epochs,
    learning_rate,
    batch_size,
    backbone_image,
    backbone_text,
    preset,
    eval_metric,
    hyperparameters,
):
    """
    Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs.

    The returned dict separates what should be passed to predictor.fit (under ``fit``)
    from the model/optimization configuration (under ``hyperparameters``). Threshold is
    preserved for downstream evaluation but not passed into AutoGluon directly.

    Args:
        threshold: Optional decision threshold; stored as ``float`` under
            ``"threshold"`` when given.
        time_limit, random_seed, preset: Copied into the ``fit`` section
            when set (``preset`` under the key ``"presets"``).
        epochs, learning_rate, batch_size, backbone_image, backbone_text,
        eval_metric: Written into the hyperparameters, both as nested
            ``env`` / ``optimization`` / ``model`` sections and as dotted
            alias keys.
        hyperparameters: User overrides (dict, inline JSON string or JSON
            file path). They are merged last so they win. A non-dict
            payload is ignored with a warning.

    Returns:
        ``{"fit": {...}, "hyperparameters": {...}}`` plus ``"threshold"``
        when one was supplied. ``None`` values and empty sections are pruned
        from the hyperparameters.
    """

    def _prune_empty(d: dict) -> dict:
        # Drop None values and dict sections that end up empty, recursively.
        cleaned = {}
        for k, v in (d or {}).items():
            if isinstance(v, dict):
                nested = _prune_empty(v)
                if nested:
                    cleaned[k] = nested
            elif v is not None:
                cleaned[k] = v
        return cleaned

    # Base hyperparameters following the structure described in the AutoGluon
    # customization guide (env / optimization / model).
    env_cfg = {}
    if random_seed is not None:
        env_cfg["seed"] = int(random_seed)
    if batch_size is not None:
        env_cfg["per_gpu_batch_size"] = int(batch_size)

    optim_cfg = {}
    if epochs is not None:
        optim_cfg["max_epochs"] = int(epochs)
    if learning_rate is not None:
        optim_cfg["learning_rate"] = float(learning_rate)
    if batch_size is not None:
        bs = int(batch_size)
        optim_cfg["per_device_train_batch_size"] = bs
        optim_cfg["train_batch_size"] = bs

    model_cfg = {}
    if eval_metric:
        model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric)
    if backbone_image:
        model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image)
    if backbone_text:
        model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text)

    hp = {
        "env": env_cfg,
        "optimization": optim_cfg,
        "model": model_cfg,
    }

    # Also expose the most common dotted aliases for robustness across AG versions.
    if epochs is not None:
        hp["optimization.max_epochs"] = int(epochs)
        hp["optim.max_epochs"] = int(epochs)
    if learning_rate is not None:
        lr_val = float(learning_rate)
        hp["optimization.learning_rate"] = lr_val
        hp["optimization.lr"] = lr_val
        hp["optim.learning_rate"] = lr_val
        hp["optim.lr"] = lr_val
    if batch_size is not None:
        bs_val = int(batch_size)
        hp["optimization.per_device_train_batch_size"] = bs_val
        hp["optimization.batch_size"] = bs_val
        hp["optim.per_device_train_batch_size"] = bs_val
        hp["optim.batch_size"] = bs_val
        hp["env.per_gpu_batch_size"] = bs_val
    if backbone_image:
        hp["model.timm_image.checkpoint_name"] = str(backbone_image)
    if backbone_text:
        hp["model.hf_text.checkpoint_name"] = str(backbone_text)

    # Merge user-provided hyperparameters (inline JSON or path) last so they win.
    if isinstance(hyperparameters, dict):
        user_hp = hyperparameters
    else:
        user_hp = load_user_hparams(hyperparameters)
    if user_hp:
        if isinstance(user_hp, dict):
            hp = deep_update(hp, user_hp)
        else:
            # Same policy as build_mm_hparams: ignore non-dict payloads
            # instead of crashing inside the merge.
            logger.warning("User-supplied hyperparameters must be a dict; received %s", type(user_hp).__name__)
    hp = _prune_empty(hp)

    fit_cfg = {}
    if time_limit is not None:
        fit_cfg["time_limit"] = time_limit
    if random_seed is not None:
        fit_cfg["seed"] = int(random_seed)
    if preset:
        fit_cfg["presets"] = preset

    config = {
        "fit": fit_cfg,
        "hyperparameters": hp,
    }
    if threshold is not None:
        config["threshold"] = float(threshold)

    return config
| 495 | |
| 496 | |
def run_autogluon_experiment(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    image_columns: Optional[List[str]],
    ag_config: dict,
):
    """
    Fit a MultiModalPredictor from a config built by autogluon_hyperparameters().

    ``train_dataset`` must carry a ``split`` column; rows are partitioned
    into train ("train"), validation ("val" / "validation") and internal
    test ("test"). Entries under ``ag_config["fit"]`` are forwarded to
    ``fit`` and take precedence over the defaults assembled here.

    Returns (predictor, context), where context holds the train / val /
    internal-test frames, the external test set as passed in, and the
    threshold from the config, so callers can evaluate downstream.

    Raises ValueError when ``ag_config`` is None or ``split`` is missing.
    """
    if ag_config is None:
        raise ValueError("ag_config is required to launch AutoGluon training.")

    hyperparameters = ag_config.get("hyperparameters") or {}
    fit_cfg = dict(ag_config.get("fit") or {})
    threshold = ag_config.get("threshold")

    if "split" not in train_dataset.columns:
        raise ValueError("train_dataset must contain a 'split' column. Did you call split_dataset?")

    split_labels = train_dataset["split"]
    df_train = train_dataset[split_labels == "train"].copy()
    df_val = train_dataset[split_labels.isin(["val", "validation"])].copy()
    df_test_internal = train_dataset[split_labels == "test"].copy()

    predictor = MultiModalPredictor(label=target_column, path=None)

    # Config-supplied fit options are spread last so they override the base keys.
    fit_kwargs = {
        "train_data": df_train,
        "hyperparameters": hyperparameters,
        **fit_cfg,
    }
    if not df_val.empty:
        fit_kwargs.setdefault("tuning_data", df_val)
    column_types = dict.fromkeys(image_columns or [], "image_path")
    if column_types:
        fit_kwargs.setdefault("column_types", column_types)

    has_external_test = test_dataset is not None and not test_dataset.empty
    logger.info(
        "Fitting AutoGluon with %d train / %d val rows (internal test rows: %d, external test provided: %s)",
        len(df_train),
        len(df_val),
        len(df_test_internal),
        has_external_test,
    )
    predictor.fit(**fit_kwargs)

    context = {
        "train": df_train,
        "val": df_val,
        "test_internal": df_test_internal,
        "test_external": test_dataset,
        "threshold": threshold,
    }
    return predictor, context
