multimodal_learner.py @ 0:375c36923da1 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000

#!/usr/bin/env python
"""
Main entrypoint for AutoGluon multimodal training wrapper.
"""

import argparse
import logging
import os
import sys
from typing import List, Optional

import pandas as pd
from metrics_logic import aggregate_metrics
from plot_logic import infer_problem_type
from report_utils import write_outputs
from sklearn.model_selection import KFold, StratifiedKFold
from split_logic import split_dataset
from test_pipeline import run_autogluon_test_experiment
from training_pipeline import autogluon_hyperparameters, handle_missing_images, run_autogluon_experiment

# ------------------------------------------------------------------
# Local utility imports (split helpers, reproducibility, file loading)
# ------------------------------------------------------------------
from utils import (
    absolute_path_expander,
    enable_deterministic_mode,
    enable_tensor_cores_if_available,
    ensure_local_tmp,
    load_file,
    prepare_image_search_dirs,
    set_seeds,
    str2bool,
)

# ------------------------------------------------------------------
# Logger setup
# ------------------------------------------------------------------
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Argument parsing
# ------------------------------------------------------------------
def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Train & report an AutoGluon model")

    parser.add_argument("--input_csv_train", dest="train_dataset", required=True)
    parser.add_argument("--input_csv_test", dest="test_dataset", default=None)
    parser.add_argument("--target_column", required=True)
    parser.add_argument("--output_json", default="results.json")
    parser.add_argument("--output_html", default="report.html")
    parser.add_argument("--output_config", default=None)
    parser.add_argument("--images_zip", nargs="*", default=None,
                        help="One or more ZIP files that contain image assets")
    parser.add_argument("--missing_image_strategy", default="false",
                        help="true/false: remove rows with missing images or use a placeholder")
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--time_limit", type=int, default=None)
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Enable deterministic algorithms to reduce run-to-run variance")
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--cross_validation", type=str, default="false")
    parser.add_argument("--num_folds", type=int, default=5)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
    parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
    parser.add_argument("--validation_size", type=float, default=0.2)
    parser.add_argument("--split_probabilities", type=float, nargs=3,
                        default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
    parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
                        default="medium_quality")
    parser.add_argument("--eval_metric", default="roc_auc")
    parser.add_argument("--hyperparameters", default=None)

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        logger.warning("Ignoring unknown CLI tokens: %s", unknown)

    # -------------------------- Validation --------------------------
    if not (0.0 <= args.validation_size <= 1.0):
        parser.error("--validation_size must be in [0, 1]")
    if len(args.split_probabilities) != 3 or abs(sum(args.split_probabilities) - 1.0) > 1e-6:
        parser.error("--split_probabilities must be three numbers summing to 1.0")
    if args.cross_validation.lower() == "true" and args.num_folds < 2:
        parser.error("--num_folds must be >= 2 when --cross_validation is true")

    return args

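# Example invocation (illustrative; dataset and column names are placeholders):
#   python multimodal_learner.py \
#       --input_csv_train train.csv --input_csv_test test.csv \
#       --target_column label --images_zip images.zip \
#       --cross_validation true --num_folds 5 \
#       --eval_metric roc_auc --preset medium_quality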

def run_cross_validation(
    args,
    df_full: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    image_cols: List[str],
    ag_config: dict,
):
    """Cross-validation loop returning aggregated metrics and last predictor."""
    df_full = df_full.drop(columns=["split"], errors="ignore")
    y = df_full[args.target_column]
    try:
        use_stratified = y.dtype == object or y.nunique() <= 20
    except Exception:
        use_stratified = False

    splitter_cls = StratifiedKFold if use_stratified else KFold
    kf = splitter_cls(
        n_splits=int(args.num_folds),
        shuffle=True,
        random_state=int(args.random_seed),
    )
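    # Note: StratifiedKFold requires discrete labels, hence the nunique() <= 20
    # heuristic above; continuous regression targets fall back to plain KFold.
    # With shuffle=True, the fixed random_state makes the fold assignment
    # reproducible across runs that use the same seed.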

    raw_folds = []
    ag_folds = []
    folds_info = []
    last_predictor = None
    last_data_ctx = None

    for fold_idx, (train_idx, val_idx) in enumerate(
        kf.split(df_full, y if use_stratified else None), start=1
    ):
        logger.info(f"CV fold {fold_idx}/{args.num_folds}")
        df_tr = df_full.iloc[train_idx].copy()
        df_va = df_full.iloc[val_idx].copy()

        df_tr["split"] = "train"
        df_va["split"] = "val"
        fold_dataset = pd.concat([df_tr, df_va], ignore_index=True)
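        # fold_dataset now carries an explicit "split" column ("train"/"val"),
        # mirroring the layout split_dataset() is expected to produce for the
        # non-CV path below.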

        predictor_fold, data_ctx = run_autogluon_experiment(
            train_dataset=fold_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        last_predictor = predictor_fold
        last_data_ctx = data_ctx
        problem_type = infer_problem_type(predictor_fold, df_tr, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor_fold,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )

        raw_metrics_fold = eval_results.get("raw_metrics", {})
        ag_by_split_fold = eval_results.get("ag_eval", {})
        raw_folds.append(raw_metrics_fold)
        ag_folds.append(ag_by_split_fold)
        folds_info.append(
            {
                "fold": int(fold_idx),
                "predictor_path": getattr(predictor_fold, "path", None),
                "raw_metrics": raw_metrics_fold,
                "ag_eval": ag_by_split_fold,
            }
        )

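    # aggregate_metrics is expected to reduce the list of per-fold metric dicts
    # into per-metric (mean, std) dicts; a sketch (the real logic lives in
    # metrics_logic.py):
    #   [{"val": {"roc_auc": 0.91}}, {"val": {"roc_auc": 0.89}}]
    #     -> ({"val": {"roc_auc": 0.90}}, {"val": {"roc_auc": 0.01}})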
    raw_metrics_mean, raw_metrics_std = aggregate_metrics(raw_folds)
    ag_by_split_mean, ag_by_split_std = aggregate_metrics(ag_folds)
    return (
        last_predictor,
        raw_metrics_mean,
        ag_by_split_mean,
        raw_folds,
        ag_folds,
        raw_metrics_std,
        ag_by_split_std,
        folds_info,
        last_data_ctx,
    )


# ------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------
def main():
    args = parse_args()

    # ------------------------------------------------------------------
    # Debug output
    # ------------------------------------------------------------------
    logger.info("=== AutoGluon Training Wrapper Started ===")
    logger.info(f"Working directory: {os.getcwd()}")
    logger.info(f"Command line: {' '.join(sys.argv)}")
    logger.info(f"Parsed args: {vars(args)}")

    # ------------------------------------------------------------------
    # Reproducibility & performance
    # ------------------------------------------------------------------
    set_seeds(args.random_seed)
    if args.deterministic:
        enable_deterministic_mode(args.random_seed)
        logger.info("Deterministic mode enabled (seed=%s)", args.random_seed)
    ensure_local_tmp()
    enable_tensor_cores_if_available()

    # ------------------------------------------------------------------
    # Load datasets
    # ------------------------------------------------------------------
    train_dataset = load_file(args.train_dataset)
    test_dataset = load_file(args.test_dataset) if args.test_dataset else None

    logger.info(f"Train dataset loaded: {len(train_dataset)} rows")
    if test_dataset is not None:
        logger.info(f"Test dataset loaded: {len(test_dataset)} rows")

    # ------------------------------------------------------------------
    # Resolve target column by name; if Galaxy passed a numeric index,
    # translate it to the corresponding header so downstream checks pass.
    # Galaxy's data_column widget is 1-based.
    # ------------------------------------------------------------------
    if args.target_column not in train_dataset.columns and str(args.target_column).isdigit():
        idx = int(args.target_column) - 1
        if 0 <= idx < len(train_dataset.columns):
            resolved = train_dataset.columns[idx]
            logger.info(
                f"Target column '{args.target_column}' not found; using column "
                f"#{idx + 1} header '{resolved}' instead."
            )
            args.target_column = resolved
        else:
            logger.error(
                f"Numeric target index '{args.target_column}' is out of range "
                f"for dataset with {len(train_dataset.columns)} columns."
            )
            sys.exit(1)
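    # Illustrative example (hypothetical headers): with columns
    # ["id", "image", "label"], passing --target_column 3 resolves to "label",
    # because Galaxy's data_column widget counts columns from 1.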

    # ------------------------------------------------------------------
    # Image handling (ZIP extraction + absolute path expansion)
    # ------------------------------------------------------------------
    extracted_imgs_path = prepare_image_search_dirs(args)

    image_cols = absolute_path_expander(train_dataset, extracted_imgs_path, None)
    if test_dataset is not None:
        absolute_path_expander(test_dataset, extracted_imgs_path, image_cols)

    # ------------------------------------------------------------------
    # Handle missing images
    # ------------------------------------------------------------------
    train_dataset = handle_missing_images(
        train_dataset,
        image_columns=image_cols,
        strategy=args.missing_image_strategy,
    )
    if test_dataset is not None:
        test_dataset = handle_missing_images(
            test_dataset,
            image_columns=image_cols,
            strategy=args.missing_image_strategy,
        )

    logger.info(
        f"After cleanup → train: {len(train_dataset)}, "
        f"test: {len(test_dataset) if test_dataset is not None else 0}"
    )

    # ------------------------------------------------------------------
    # Dataset splitting logic (adds 'split' column to train_dataset)
    # ------------------------------------------------------------------
    split_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        target_column=args.target_column,
        split_probabilities=args.split_probabilities,
        validation_size=args.validation_size,
        random_seed=args.random_seed,
    )
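    # split_dataset mutates train_dataset in place (its return value is unused
    # here); the added "split" column marks each row as train/val/test and
    # drives the downstream fit and evaluation.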

    logger.info("Preprocessing complete — ready for AutoGluon training!")
    logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}")

    # Verify target/image/text columns exist
    if args.target_column not in train_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in training data.")
        sys.exit(1)
    if test_dataset is not None and args.target_column not in test_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in test data.")
        sys.exit(1)

    # Threshold is only meaningful for binary classification; ignore otherwise.
    threshold_for_run = args.threshold
    unique_labels = None
    target_looks_binary = False
    try:
        unique_labels = train_dataset[args.target_column].nunique(dropna=True)
        target_looks_binary = unique_labels == 2
    except Exception:
        logger.warning(
            "Could not inspect target column '%s' for threshold validation; "
            "proceeding without binary check.",
            args.target_column,
        )

    if threshold_for_run is not None:
        if target_looks_binary:
            threshold_for_run = float(threshold_for_run)
            logger.info("Applying custom decision threshold %.4f for binary evaluation.", threshold_for_run)
        else:
            logger.warning(
                "Threshold %.3f provided but target '%s' does not appear binary (unique labels=%s); ignoring threshold.",
                threshold_for_run,
                args.target_column,
                unique_labels if unique_labels is not None else "unknown",
            )
            threshold_for_run = None
    args.threshold = threshold_for_run
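    # Interpretation sketch: a threshold of e.g. 0.35 moves the positive-class
    # cutoff so rows with P(positive) >= 0.35 are predicted positive; actually
    # applying it is left to the evaluation pipeline, while this block only
    # validates the value and drops it for non-binary targets.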
    # Image columns are auto-inferred; image_cols already resolved to absolute paths.

    # ------------------------------------------------------------------
    # Build AutoGluon configuration from CLI knobs
    # ------------------------------------------------------------------
    ag_config = autogluon_hyperparameters(
        threshold=args.threshold,
        time_limit=args.time_limit,
        random_seed=args.random_seed,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        backbone_image=args.backbone_image,
        backbone_text=args.backbone_text,
        preset=args.preset,
        eval_metric=args.eval_metric,
        hyperparameters=args.hyperparameters,
    )
    logger.info(
        f"AutoGluon config prepared: fit={ag_config.get('fit')}, "
        f"hyperparameters keys={list(ag_config.get('hyperparameters', {}).keys())}"
    )
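    # ag_config's exact schema is owned by training_pipeline.autogluon_hyperparameters;
    # from its use in this script it is assumed to carry at least a "fit" dict and
    # a "hyperparameters" dict, i.e. roughly {"fit": {...}, "hyperparameters": {...}}.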

    cv_enabled = str2bool(args.cross_validation)
    if cv_enabled:
        (
            predictor,
            raw_metrics,
            ag_by_split,
            raw_folds,
            ag_folds,
            raw_metrics_std,
            ag_by_split_std,
            folds_info,
            data_ctx,
        ) = run_cross_validation(
            args=args,
            df_full=train_dataset,
            test_dataset=test_dataset,
            image_cols=image_cols,
            ag_config=ag_config,
        )
        if predictor is None:
            logger.error("All CV folds failed. Exiting.")
            sys.exit(1)
        eval_results = {
            "raw_metrics": raw_metrics,
            "ag_eval": ag_by_split,
            "fit_summary": None,
        }
    else:
        predictor, data_ctx = run_autogluon_experiment(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        logger.info("AutoGluon training finished. Model path: %s", getattr(predictor, "path", None))

        # Evaluate predictor on Train/Val/Test splits
        problem_type = infer_problem_type(predictor, train_dataset, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )
        raw_metrics = eval_results.get("raw_metrics", {})
        ag_by_split = eval_results.get("ag_eval", {})
        raw_folds = ag_folds = raw_metrics_std = ag_by_split_std = None

    logger.info("Transparent metrics by split: %s", eval_results["raw_metrics"])
    logger.info("AutoGluon evaluate() by split: %s", eval_results["ag_eval"])

    if "problem_type" in eval_results:
        problem_type_final = eval_results["problem_type"]
    else:
        problem_type_final = infer_problem_type(predictor, train_dataset, args.target_column)

    write_outputs(
        args=args,
        predictor=predictor,
        problem_type=problem_type_final,
        eval_results=eval_results,
        data_ctx=data_ctx,
        raw_folds=raw_folds,
        ag_folds=ag_folds,
        raw_metrics_std=raw_metrics_std,
        ag_by_split_std=ag_by_split_std,
    )


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S",
    )
    # Quiet noisy image parsing logs (e.g., PIL.PngImagePlugin debug streams)
    logging.getLogger("PIL").setLevel(logging.WARNING)
    logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
    main()