# report_utils.py @ 0:375c36923da1 (draft, default, tip)
# planemo upload for repository https://github.com/goeckslab/gleam.git
# commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
# author: goeckslab | date: Tue, 09 Dec 2025 23:49:47 +0000
import base64
import html
import json
import logging
import os
import platform
import shutil
import sys
import tempfile
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import yaml
from utils import verify_outputs

logger = logging.getLogger(__name__)


def _escape(s: Any) -> str:
    return html.escape(str(s))


def _write_predictor_path(predictor):
    try:
        pred_path = getattr(predictor, "path", None)
        if pred_path:
            with open("predictor_path.txt", "w") as pf:
                pf.write(str(pred_path))
            logger.info("Wrote predictor path → predictor_path.txt")
        return pred_path
    except Exception:
        logger.warning("Could not write predictor_path.txt")
        return None

def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]):
    if not output_config:
        return
    try:
        config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None
        if config_yaml_path and os.path.isfile(config_yaml_path):
            shutil.copy2(config_yaml_path, output_config)
            logger.info(f"Wrote AutoGluon config → {output_config}")
        else:
            with open(output_config, "w") as cfg_out:
                cfg_out.write("# config.yaml not found for this run\n")
            logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}")
    except Exception as e:
        logger.error(f"Failed to write config output '{output_config}': {e}")
        try:
            with open(output_config, "w") as cfg_out:
                cfg_out.write(f"# Failed to copy config.yaml: {e}\n")
        except Exception:
            pass


def _load_config_yaml(args, predictor) -> dict:
    """
    Load config.yaml either from the predictor path or the exported output_config.
    """
    candidates = []
    pred_path = getattr(predictor, "path", None)
    if pred_path:
        cfg_path = os.path.join(pred_path, "config.yaml")
        if os.path.isfile(cfg_path):
            candidates.append(cfg_path)
    if args.output_config and os.path.isfile(args.output_config):
        candidates.append(args.output_config)

    for p in candidates:
        try:
            with open(p, "r") as f:
                return yaml.safe_load(f) or {}
        except Exception:
            continue
    return {}

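
# Hedged usage sketch (not called anywhere in this module): _load_config_yaml
# prefers the config.yaml inside the predictor directory and only then falls
# back to the exported output_config. The SimpleNamespace stand-ins below are
# illustrative assumptions, not real AutoGluon objects.
def _demo_load_config_yaml() -> dict:
    import types

    workdir = tempfile.mkdtemp()
    with open(os.path.join(workdir, "config.yaml"), "w") as f:
        yaml.safe_dump({"model": {"names": ["demo_backbone"]}}, f)
    fake_predictor = types.SimpleNamespace(path=workdir)
    fake_args = types.SimpleNamespace(output_config=None)
    # Returns {"model": {"names": ["demo_backbone"]}} read from the predictor dir.
    return _load_config_yaml(fake_args, fake_predictor)
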

def _summarize_config(cfg: dict, args) -> List[tuple[str, Any]]:
    """
    Build rows describing model components and key hyperparameters from a loaded
    config.yaml. (`args` is accepted for future CLI fallbacks but is not
    currently consulted.)
    """
    rows: List[tuple[str, Any]] = []
    model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
    names = model_cfg.get("names") or []
    if names:
        rows.append(("Model components", ", ".join(names)))

    # Tabular backbone: the first component whose data_types include categorical/numerical
    tabular_val = "—"
    for k, v in model_cfg.items():
        if k in ("names", "hf_text", "timm_image"):
            continue
        if isinstance(v, dict) and "data_types" in v:
            dtypes = v.get("data_types") or []
            if any(t in ("categorical", "numerical") for t in dtypes):
                dt_str = ", ".join(dtypes) if dtypes else ""
                tabular_val = f"{k} ({dt_str})" if dt_str else k
                break
    rows.append(("Tabular backbone", tabular_val))

    image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—"
    rows.append(("Image backbone", image_val))

    text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—"
    rows.append(("Text backbone", text_val))

    fusion_val = "—"
    for k in model_cfg.keys():
        if str(k).startswith("fusion"):
            fusion_val = k
            break
    rows.append(("Fusion backbone", fusion_val))

    # Optimizer block
    optim_cfg = cfg.get("optim", {}) if isinstance(cfg, dict) else {}
    optim_map = [
        ("optim_type", "Optimizer"),
        ("lr", "Learning rate"),
        ("weight_decay", "Weight decay"),
        ("lr_decay", "LR decay"),
        ("max_epochs", "Max epochs"),
        ("max_steps", "Max steps"),
        ("patience", "Early-stop patience"),
        ("check_val_every_n_epoch", "Val check every N epochs"),
        ("top_k", "Top K checkpoints"),
        ("top_k_average_method", "Top K averaging"),
    ]
    for key, label in optim_map:
        if key in optim_cfg:
            rows.append((label, optim_cfg[key]))

    env_cfg = cfg.get("env", {}) if isinstance(cfg, dict) else {}
    if "batch_size" in env_cfg:
        rows.append(("Global batch size", env_cfg["batch_size"]))

    return rows

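
# Hedged usage sketch: the config fragment below is illustrative (the component
# and checkpoint names are assumptions, not taken from a real run) and shows the
# kind of rows _summarize_config produces.
def _demo_summarize_config() -> None:
    demo_cfg = {
        "model": {
            "names": ["ft_transformer", "timm_image", "fusion_mlp"],
            "ft_transformer": {"data_types": ["categorical", "numerical"]},
            "timm_image": {"checkpoint_name": "swin_base_patch4_window7_224"},
        },
        "optim": {"optim_type": "adamw", "lr": 1e-4},
        "env": {"batch_size": 128},
    }
    for label, value in _summarize_config(demo_cfg, args=None):
        # e.g. "Tabular backbone: ft_transformer (categorical, numerical)"
        print(f"{label}: {value}")
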

def write_outputs(
    args,
    predictor,
    problem_type: str,
    eval_results: dict,
    data_ctx: dict,
    raw_folds=None,
    ag_folds=None,
    raw_metrics_std=None,
    ag_by_split_std=None,
):
    """Write the JSON metrics file, the HTML report, the predictor path, and the config copy."""
    from plot_logic import (
        build_summary_html,
        build_test_html_and_plots,
        build_feature_html,
        assemble_full_html_report,
        build_train_html_and_plots,
    )
    from autogluon.multimodal import MultiModalPredictor
    from metrics_logic import aggregate_metrics

    raw_metrics = eval_results.get("raw_metrics", {})
    ag_by_split = eval_results.get("ag_eval", {})
    fit_summary_obj = eval_results.get("fit_summary")

    df_train = data_ctx.get("train")
    df_val = data_ctx.get("val")
    df_test_internal = data_ctx.get("test_internal")
    df_test_external = data_ctx.get("test_external")
    df_test = df_test_external if df_test_external is not None else df_test_internal
    df_train_full = df_train if df_val is None else pd.concat([df_train, df_val], ignore_index=True)

    # Aggregate folds if provided without stds
    if raw_folds and raw_metrics_std is None:
        raw_metrics, raw_metrics_std = aggregate_metrics(raw_folds)
    if ag_folds and ag_by_split_std is None:
        ag_by_split, ag_by_split_std = aggregate_metrics(ag_folds)

    # Inject AG eval into raw metrics for visibility
    def _inject_ag(src: dict, dst: dict):
        for k, v in (src or {}).items():
            try:
                dst[f"AG_{k}"] = float(v)
            except Exception:
                dst[f"AG_{k}"] = v

    if "Train" in raw_metrics and "Train" in ag_by_split:
        _inject_ag(ag_by_split["Train"], raw_metrics["Train"])
    if "Validation" in raw_metrics and "Validation" in ag_by_split:
        _inject_ag(ag_by_split["Validation"], raw_metrics["Validation"])
    if "Test" in raw_metrics and "Test" in ag_by_split:
        _inject_ag(ag_by_split["Test"], raw_metrics["Test"])

    # JSON
    with open(args.output_json, "w") as f:
        json.dump(
            {
                "train": raw_metrics.get("Train", {}),
                "val": raw_metrics.get("Validation", {}),
                "test": raw_metrics.get("Test", {}),
                "test_external": raw_metrics.get("Test (external)", {}),
                "ag_eval": ag_by_split,
                "ag_eval_std": ag_by_split_std,
                "fit_summary": fit_summary_obj,
                "problem_type": problem_type,
                "predictor_path": getattr(predictor, "path", None),
                "threshold": args.threshold,
                "threshold_test": args.threshold,
                "preset": args.preset,
                "eval_metric": args.eval_metric,
                "folds": {
                    "raw_folds": raw_folds,
                    "ag_folds": ag_folds,
                    "summary_mean": raw_metrics if raw_folds else None,
                    "summary_std": raw_metrics_std,
                    "ag_summary_mean": ag_by_split,
                    "ag_summary_std": ag_by_split_std,
                },
            },
            f,
            indent=2,
            default=str,
        )
    logger.info(f"Wrote full JSON → {args.output_json}")

    # HTML report assembly
    label_col = args.target_column

    class_balance_block_html = build_class_balance_html(
        df_train=df_train,
        label_col=label_col,
        df_val=df_val,
        df_test=df_test,
    )
    summary_perf_table_html = build_model_performance_summary_table(
        train_scores=raw_metrics.get("Train", {}),
        val_scores=raw_metrics.get("Validation", {}),
        test_scores=raw_metrics.get("Test", {}),
        include_test=True,
        title=None,
        show_title=False,
    )

    cfg_yaml = _load_config_yaml(args, predictor)
    config_rows = _summarize_config(cfg_yaml, args)
    threshold_rows = []
    if problem_type == "binary" and args.threshold is not None:
        threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
    extra_run_rows = [
        ("Target column", label_col),
        ("Model evaluation metric", args.eval_metric or "AutoGluon default"),
        ("Experiment quality", args.preset or "AutoGluon default"),
    ] + threshold_rows + config_rows

    summary_html = build_summary_html(
        predictor=predictor,
        df_train=df_train_full,
        df_val=df_val,
        df_test=df_test,
        label_column=label_col,
        extra_run_rows=extra_run_rows,
        class_balance_html=class_balance_block_html,
        perf_table_html=summary_perf_table_html,
    )

    train_tab_perf_html = build_model_performance_summary_table(
        train_scores=raw_metrics.get("Train", {}),
        val_scores=raw_metrics.get("Validation", {}),
        test_scores=raw_metrics.get("Test", {}),
        include_test=False,
        title=None,
        show_title=False,
    )

    train_html = build_train_html_and_plots(
        predictor=predictor,
        problem_type=problem_type,
        df_train=df_train,
        df_val=df_val,
        label_column=label_col,
        tmpdir=tempfile.mkdtemp(),
        seed=int(args.random_seed),
        perf_table_html=train_tab_perf_html,
        threshold=args.threshold,
    )

    test_html_template, plots = build_test_html_and_plots(
        predictor,
        problem_type,
        df_test,
        label_col,
        tempfile.mkdtemp(),
        threshold=args.threshold,
    )

    def _fmt_val(v):
        if isinstance(v, (int, np.integer)):
            return f"{int(v)}"
        if isinstance(v, (float, np.floating)):
            return f"{v:.6f}"
        return str(v)

    test_scores = raw_metrics.get("Test", {})
    # Drop the AutoGluon-injected ROC AUC line from the Test Performance Summary
    filtered_test_scores = {k: v for k, v in test_scores.items() if k != "AG_roc_auc"}
    metric_rows = "".join(
        f"<tr><td>{k.replace('_', ' ')}</td><td>{_fmt_val(v)}</td></tr>"
        for k, v in filtered_test_scores.items()
    )
    test_html_filled = test_html_template.format(metric_rows)

    is_multimodal = isinstance(predictor, MultiModalPredictor)
    leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
    inputs_html = ""
    ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
    presets_hparams_html = build_presets_hparams_html(predictor)
    notices: List[str] = []
    if args.threshold is not None and problem_type == "binary":
        notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
    warnings_html = build_warnings_html([], notices)
    repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None))

    transparency_blocks = "\n".join(
        [
            leaderboard_html,
            inputs_html,
            ignored_features_html,
            presets_hparams_html,
            warnings_html,
            repro_html,
        ]
    )

    try:
        feature_text = (
            build_feature_html(predictor, df_test, label_col, tempfile.mkdtemp(), args.random_seed)
            if df_test is not None
            else ""
        )
    except Exception:
        feature_text = "<p>Feature analysis unavailable for this model.</p>"

    full_html = assemble_full_html_report(
        summary_html,
        train_html,
        test_html_filled,
        plots,
        feature_text + transparency_blocks,
    )
    with open(args.output_html, "w") as f:
        f.write(full_html)
    logger.info(f"Wrote HTML report → {args.output_html}")

    pred_path = _write_predictor_path(predictor)
    _copy_config_if_available(pred_path, args.output_config)

    outputs_to_check = [
        (args.output_json, "JSON results"),
        (args.output_html, "HTML report"),
    ]
    if args.output_config:
        outputs_to_check.append((args.output_config, "AutoGluon config"))
    verify_outputs(outputs_to_check)

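
# Hedged sketch of the inputs write_outputs expects. The metric names and args
# attributes shown are illustrative assumptions inferred from the function body;
# the real values come from the calling pipeline (metrics_logic / plot_logic).
def _demo_write_outputs_inputs() -> tuple:
    eval_results = {
        "raw_metrics": {
            "Train": {"accuracy": 0.93},
            "Validation": {"accuracy": 0.90},
            "Test": {"accuracy": 0.88},
        },
        "ag_eval": {"Test": {"roc_auc": 0.91}},
        "fit_summary": None,
    }
    data_ctx = {
        "train": pd.DataFrame({"label": [0, 1]}),
        "val": None,
        "test_internal": pd.DataFrame({"label": [0, 1]}),
        "test_external": None,
    }
    # args additionally needs: output_json, output_html, output_config,
    # target_column, threshold, preset, eval_metric, random_seed
    return eval_results, data_ctx
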

def get_html_template() -> str:
    """
    Returns the opening HTML, <head> (with CSS/JS), and opens <body> + .container.
    Includes:
      - Base styling for layout and tables
      - Sortable table headers with 3-state arrows (none ⇅, asc ↑, desc ↓)
      - A scroll helper class (.scroll-rows-30) that approximates ~30 visible rows
      - A guarded script so initialization runs only once even if injected twice
    """
    return """
<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8">
  <title>Galaxy-Ludwig Report</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 0;
      padding: 20px;
      background-color: #f4f4f4;
    }
    .container {
      max-width: 1200px;
      margin: auto;
      background: white;
      padding: 20px;
      box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
      overflow-x: auto;
    }
    h1 {
      text-align: center;
      color: #333;
    }
    h2 {
      border-bottom: 2px solid #4CAF50;
      color: #4CAF50;
      padding-bottom: 5px;
      margin-top: 28px;
    }

    /* baseline table setup */
    table {
      border-collapse: collapse;
      margin: 20px 0;
      width: 100%;
      table-layout: fixed;
      background: #fff;
    }
    table, th, td {
      border: 1px solid #ddd;
    }
    th, td {
      padding: 10px;
      text-align: center;
      vertical-align: middle;
      word-break: break-word;
      white-space: normal;
      overflow-wrap: anywhere;
    }
    th {
      background-color: #4CAF50;
      color: white;
    }

    .plot {
      text-align: center;
      margin: 20px 0;
    }
    .plot img {
      max-width: 100%;
      height: auto;
      border: 1px solid #ddd;
    }

    /* -------------------
       sortable columns (3-state: none ⇅, asc ↑, desc ↓)
       ------------------- */
    table.performance-summary th.sortable {
      cursor: pointer;
      position: relative;
      user-select: none;
    }
    /* default icon space */
    table.performance-summary th.sortable::after {
      content: '⇅';
      position: absolute;
      right: 12px;
      top: 50%;
      transform: translateY(-50%);
      font-size: 0.8em;
      color: #eaf5ea; /* light on green */
      text-shadow: 0 0 1px rgba(0,0,0,0.15);
    }
    /* three states override the default */
    table.performance-summary th.sortable.sorted-none::after { content: '⇅'; color: #eaf5ea; }
    table.performance-summary th.sortable.sorted-asc::after  { content: '↑'; color: #ffffff; }
    table.performance-summary th.sortable.sorted-desc::after { content: '↓'; color: #ffffff; }

    /* show ~30 rows with a scrollbar (tweak if you want) */
    .scroll-rows-30 {
      max-height: 900px;   /* ~30 rows depending on row height */
      overflow-y: auto;    /* vertical scrollbar */
      overflow-x: auto;
    }

    /* Tabs + Help button (used by build_tabbed_html) */
    .tabs {
      display: flex;
      align-items: center;
      border-bottom: 2px solid #ccc;
      margin-bottom: 1rem;
      gap: 6px;
      flex-wrap: wrap;
    }
    .tab {
      padding: 10px 20px;
      cursor: pointer;
      border: 1px solid #ccc;
      border-bottom: none;
      background: #f9f9f9;
      margin-right: 5px;
      border-top-left-radius: 8px;
      border-top-right-radius: 8px;
    }
    .tab.active {
      background: white;
      font-weight: bold;
    }
    .help-btn {
      margin-left: auto;
      padding: 6px 12px;
      font-size: 0.9rem;
      border: 1px solid #4CAF50;
      border-radius: 4px;
      background: #4CAF50;
      color: white;
      cursor: pointer;
    }
    .tab-content {
      display: none;
      padding: 20px;
      border: 1px solid #ccc;
      border-top: none;
      background: #fff;
    }
    .tab-content.active {
      display: block;
    }

    /* Modal (used by get_metrics_help_modal) */
    .modal {
      display: none;
      position: fixed;
      z-index: 9999;
      left: 0; top: 0;
      width: 100%; height: 100%;
      overflow: auto;
      background-color: rgba(0,0,0,0.4);
    }
    .modal-content {
      background-color: #fefefe;
      margin: 8% auto;
      padding: 20px;
      border: 1px solid #888;
      width: 90%;
      max-width: 900px;
      border-radius: 8px;
    }
    .modal .close {
      color: #777;
      float: right;
      font-size: 28px;
      font-weight: bold;
      line-height: 1;
      margin-left: 8px;
    }
    .modal .close:hover,
    .modal .close:focus {
      color: black;
      text-decoration: none;
      cursor: pointer;
    }
    .metrics-guide h3 { margin-top: 20px; }
    .metrics-guide p { margin: 6px 0; }
    .metrics-guide ul { margin: 10px 0; padding-left: 20px; }
  </style>

  <script>
    // Guard to avoid double-initialization if this block is included twice
    (function(){
      if (window.__perfSummarySortInit) return;
      window.__perfSummarySortInit = true;

      function initPerfSummarySorting() {
        // Record original order for "back to original"
        document.querySelectorAll('table.performance-summary tbody').forEach(tbody => {
          Array.from(tbody.rows).forEach((row, i) => { row.dataset.originalOrder = i; });
        });

        const getText = td => (td?.innerText || '').trim();
        const cmp = (idx, asc) => (a, b) => {
          const v1 = getText(a.children[idx]);
          const v2 = getText(b.children[idx]);
          const n1 = parseFloat(v1), n2 = parseFloat(v2);
          if (!isNaN(n1) && !isNaN(n2)) return asc ? n1 - n2 : n2 - n1;  // numeric
          return asc ? v1.localeCompare(v2) : v2.localeCompare(v1);      // lexical
        };

        document.querySelectorAll('table.performance-summary th.sortable').forEach(th => {
          // initialize to "none"
          th.classList.remove('sorted-asc','sorted-desc');
          th.classList.add('sorted-none');

          th.addEventListener('click', () => {
            const table = th.closest('table');
            const headerRow = th.parentNode;
            const allTh = headerRow.querySelectorAll('th.sortable');
            const tbody = table.querySelector('tbody');

            // Determine current state BEFORE clearing
            const isAsc = th.classList.contains('sorted-asc');
            const isDesc = th.classList.contains('sorted-desc');

            // Reset all headers in this row
            allTh.forEach(x => x.classList.remove('sorted-asc','sorted-desc','sorted-none'));

            // Compute next state: none -> asc -> desc -> none
            let next;
            if (!isAsc && !isDesc) {
              next = 'asc';
            } else if (isAsc) {
              next = 'desc';
            } else {
              next = 'none';
            }
            th.classList.add('sorted-' + next);

            // Sort rows according to the chosen state
            const rows = Array.from(tbody.rows);
            if (next === 'none') {
              rows.sort((a, b) => (a.dataset.originalOrder - b.dataset.originalOrder));
            } else {
              const idx = Array.from(headerRow.children).indexOf(th);
              rows.sort(cmp(idx, next === 'asc'));
            }
            rows.forEach(r => tbody.appendChild(r));
          });
        });
      }

      // Run after DOM is ready
      if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initPerfSummarySorting);
      } else {
        initPerfSummarySorting();
      }
    })();
  </script>
</head>
<body>
<div class="container">
"""


def get_html_closing():
    """Closes .container, body, and html."""
    return """
</div>
</body>
</html>
"""

def build_tabbed_html(
    summary_html: str,
    train_html: str,
    test_html: str,
    feature_html: str,
    explainer_html: Optional[str] = None,
) -> str:
    """
    Renders the tab headers, contents, and the JS needed to switch tabs.
    Optional tabs (feature/explainer) appear only when their content is non-empty.
    """
    tabs = [
        '<div class="tabs">',
        '<div class="tab active" onclick="showTab(\'summary\')">Model Metric Summary and Config</div>',
        '<div class="tab" onclick="showTab(\'train\')">Train and Validation Summary</div>',
        '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>',
    ]
    if feature_html:
        tabs.append('<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>')
    if explainer_html:
        tabs.append('<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>')
    tabs.append('<button id="openMetricsHelp" class="help-btn">Help</button>')
    tabs.append('</div>')
    tabs_section = "\n".join(tabs)

    contents = [
        f'<div id="summary" class="tab-content active">{summary_html}</div>',
        f'<div id="train" class="tab-content">{train_html}</div>',
        f'<div id="test" class="tab-content">{test_html}</div>',
    ]
    if feature_html:
        contents.append(f'<div id="feature" class="tab-content">{feature_html}</div>')
    if explainer_html:
        contents.append(f'<div id="explainer" class="tab-content">{explainer_html}</div>')
    content_section = "\n".join(contents)

    js = """
<script>
function showTab(id) {
  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
  document.getElementById(id).classList.add('active');
  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
}
</script>
"""
    return tabs_section + "\n" + content_section + "\n" + js

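
# Hedged usage sketch: how the template, tabbed body, and closing compose into a
# complete report page. The tab contents here are placeholder strings.
def _demo_assemble_page() -> str:
    body = build_tabbed_html(
        summary_html="<p>summary placeholder</p>",
        train_html="<p>train placeholder</p>",
        test_html="<p>test placeholder</p>",
        feature_html="<p>feature placeholder</p>",
    )
    return get_html_template() + "<h1>Demo Report</h1>\n" + body + get_html_closing()
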
def encode_image_to_base64(image_path: str) -> str:
    """
    Reads an image file from disk and returns a base64-encoded string
    for embedding directly in HTML <img> tags.
    """
    try:
        with open(image_path, "rb") as img_f:
            return base64.b64encode(img_f.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"Failed to encode image '{image_path}': {e}")
        return ""

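
# Hedged usage sketch: embedding an encoded image in a report block. The file
# name is a placeholder, not a path this module produces.
def _demo_embed_image(image_path: str = "confusion_matrix.png") -> str:
    encoded = encode_image_to_base64(image_path)
    if not encoded:
        return "<p>Image unavailable.</p>"
    return f'<div class="plot"><img src="data:image/png;base64,{encoded}" /></div>'
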

def get_model_architecture(predictor: Any) -> str:
    """
    Returns a human-friendly description of the final model architecture based on the
    MultiModalPredictor configuration (e.g., timm_image=resnet50, hf_text=bert-base-uncased).
    """
    # MultiModalPredictor path: read backbones from config if available
    archs = []
    for attr in ("_config", "config"):
        cfg = getattr(predictor, attr, None)
        try:
            model_cfg = getattr(cfg, "model", None)
            if model_cfg:
                # OmegaConf-like mapping
                for name, sub in dict(model_cfg).items():
                    ck = None
                    # sub may be an object or a dict-like node
                    for k in ("checkpoint_name", "name", "model_name"):
                        try:
                            ck = getattr(sub, k)
                        except Exception:
                            ck = sub.get(k) if isinstance(sub, dict) else ck
                        if ck:
                            break
                    if ck:
                        archs.append(f"{name}={ck}")
        except Exception:
            continue

    if archs:
        return ", ".join(archs)

    # Fallback
    return type(predictor).__name__


def collect_run_context(args, predictor, problem_type: str,
                        df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame,
                        warnings_list: List[str],
                        notes_list: List[str]) -> Dict[str, Any]:
    """Build a dictionary with run/system context for transparency."""
    # System info (best-effort; does not depend on AutoGluon stdout)
    try:
        import psutil  # optional dependency
        mem = psutil.virtual_memory()
        mem_total_gb = mem.total / (1024 ** 3)
        mem_avail_gb = mem.available / (1024 ** 3)
    except Exception:
        mem_total_gb = mem_avail_gb = None

    ctx = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "python_version": platform.python_version(),
        "platform": {
            "system": platform.system(),
            "release": platform.release(),
            "version": platform.version(),
            "machine": platform.machine(),
        },
        "cpu_count": os.cpu_count(),
        "memory_total_gb": mem_total_gb,
        "memory_available_gb": mem_avail_gb,
        "packages": {},
        "problem_type": problem_type,
        "label_column": args.label_column,
        "time_limit_sec": args.time_limit,
        "random_seed": args.random_seed,
        "splits": {
            "train_rows": int(len(df_train)),
            "val_rows": int(len(df_val)),
            "test_rows": int(len(df_test)),
            "n_features_raw": int(len(df_train.columns) - 1),  # minus label
        },
        "warnings": warnings_list,
        "notes": notes_list,
    }
    # Package versions (safe best-effort)
    for pkg_label, mod_name in [
        ("autogluon", "autogluon"),
        ("torch", "torch"),
        ("scikit_learn", "sklearn"),
        ("numpy", "numpy"),
        ("pandas", "pandas"),
    ]:
        try:
            mod = __import__(mod_name)
            ctx["packages"][pkg_label] = getattr(mod, "__version__", "unknown")
        except Exception:
            pass
    return ctx

def build_class_balance_html(
    df_train: Optional[pd.DataFrame],
    label_col: str,
    df_val: Optional[pd.DataFrame] = None,
    df_test: Optional[pd.DataFrame] = None,
) -> str:
    """
    Render label counts for each available split (Train/Validation/Test).
    """
    def _count_labels(frame: Optional[pd.DataFrame]) -> pd.Series:
        if frame is None or label_col not in frame:
            return pd.Series(dtype=int)
        series = frame[label_col]
        if series.dtype.kind in "ifu":
            return pd.Series(series).value_counts(dropna=False).sort_index()
        return pd.Series(series.astype(str)).value_counts(dropna=False)

    counts_train = _count_labels(df_train)
    counts_val = _count_labels(df_val)
    counts_test = _count_labels(df_test)

    labels: list[Any] = []
    for idx in (counts_train.index, counts_val.index, counts_test.index):
        for label in idx:
            if label not in labels:
                labels.append(label)

    has_train = df_train is not None
    has_val = df_val is not None
    has_test = df_test is not None

    def _fmt_count(counts: pd.Series, label: Any, enabled: bool) -> str:
        if not enabled:
            return "—"
        return str(int(counts.get(label, 0)))

    rows = [
        f"<tr><td>{_escape(label)}</td>"
        f"<td>{_fmt_count(counts_train, label, has_train)}</td>"
        f"<td>{_fmt_count(counts_val, label, has_val)}</td>"
        f"<td>{_fmt_count(counts_test, label, has_test)}</td></tr>"
        for label in labels
    ]

    if not rows:
        return "<p>No label distribution available.</p>"

    return f"""
<h3>Label Counts by Split</h3>
<table class="table">
  <thead><tr><th>Label</th><th>Train</th><th>Validation</th><th>Test</th></tr></thead>
  <tbody>
    {''.join(rows)}
  </tbody>
</table>
"""

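
# Hedged usage sketch: toy DataFrames showing the per-split label-count table.
def _demo_class_balance() -> str:
    df_tr = pd.DataFrame({"label": ["a", "a", "b"]})
    df_te = pd.DataFrame({"label": ["b", "b"]})
    # The Validation column renders as "—" because no validation split is supplied.
    return build_class_balance_html(df_tr, "label", df_val=None, df_test=df_te)
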

def build_leaderboard_html(predictor) -> str:
    try:
        lb = predictor.leaderboard(silent=True)
        # keep common helpful columns if present
        cols_pref = ["model", "score_val", "eval_metric", "pred_time_val", "fit_time",
                     "pred_time_val_marginal", "fit_time_marginal", "stack_level",
                     "can_infer", "fit_order"]
        cols = [c for c in cols_pref if c in lb.columns] or list(lb.columns)
        return "<h3>Model Leaderboard (Validation)</h3>" + lb[cols].to_html(index=False)
    except Exception as e:
        return f"<h3>Model Leaderboard</h3><p>Unavailable: {_escape(e)}</p>"


def build_ignored_features_html(predictor, df_any: pd.DataFrame) -> str:
    # MultiModalPredictor does not always expose .features(); guard accordingly.
    used = set()
    try:
        used = set(predictor.features())
    except Exception:
        # If we can't determine the used features, don't emit a misleading section
        return ""
    raw_cols = [c for c in df_any.columns if c != getattr(predictor, "label", None)]
    ignored = [c for c in raw_cols if c not in used]
    if not ignored:
        return ""
    items = "".join(f"<li>{html.escape(c)}</li>" for c in ignored)
    return f"""
<h3>Ignored / Unused Features</h3>
<p>The following columns were not used by the trained predictor at inference time:</p>
<ul>{items}</ul>
"""


def build_presets_hparams_html(predictor) -> str:
    # MultiModalPredictor path
    mm_hp = {}
    for attr in ("_config", "config", "_fit_args"):
        if hasattr(predictor, attr):
            try:
                val = getattr(predictor, attr)
                # make it JSON-ish
                mm_hp[attr] = str(val)
            except Exception:
                continue
    hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>"
    return (
        "<h3>Training Presets & Hyperparameters</h3>"
        f"<details open><summary>Show hyperparameters</summary>{hp_html}</details>"
    )


def build_warnings_html(warnings_list: List[str], notes_list: List[str]) -> str:
    if not warnings_list and not notes_list:
        return ""
    w_html = "".join(f"<li>{_escape(w)}</li>" for w in warnings_list)
    n_html = "".join(f"<li>{_escape(n)}</li>" for n in notes_list)
    return f"""
<h3>Warnings & Notices</h3>
{'<h4>Warnings</h4><ul>' + w_html + '</ul>' if warnings_list else ''}
{'<h4>Notices</h4><ul>' + n_html + '</ul>' if notes_list else ''}
"""


def build_reproducibility_html(args, ctx: Dict[str, Any], model_path: Optional[str]) -> str:
    cmd = " ".join(_escape(x) for x in sys.argv)
    load_snippet = ""
    if model_path:
        load_snippet = f"""<pre>
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("{_escape(model_path)}")
</pre>"""
    pkg_rows = "".join(
        f"<tr><td>{_escape(k)}</td><td>{_escape(v)}</td></tr>"
        for k, v in (ctx.get("packages") or {}).items()
    )
    sys_table = f"""
<table class="table">
  <tbody>
    <tr><th>Timestamp</th><td>{_escape(ctx.get('timestamp'))}</td></tr>
    <tr><th>Python</th><td>{_escape(ctx.get('python_version'))}</td></tr>
    <tr><th>Platform</th><td>{_escape(ctx.get('platform'))}</td></tr>
    <tr><th>CPU Count</th><td>{_escape(ctx.get('cpu_count'))}</td></tr>
    <tr><th>Memory (GB)</th><td>Total: {_escape(ctx.get('memory_total_gb'))} | Avail: {_escape(ctx.get('memory_available_gb'))}</td></tr>
    <tr><th>Seed</th><td>{_escape(ctx.get('random_seed'))}</td></tr>
    <tr><th>Time Limit (s)</th><td>{_escape(ctx.get('time_limit_sec'))}</td></tr>
  </tbody>
</table>
"""
    pkgs_table = f"""
<h4>Package Versions</h4>
<table class="table">
  <thead><tr><th>Package</th><th>Version</th></tr></thead>
  <tbody>{pkg_rows}</tbody>
</table>
"""
    return f"""
<h3>Reproducibility</h3>
<h4>Command</h4>
<pre>{cmd}</pre>
{sys_table}
{pkgs_table}
<h4>Load Trained Model</h4>
{load_snippet or '<i>Model path not available</i>'}
"""


def build_modalities_html(predictor, df_any: pd.DataFrame, label_col: str, image_col: Optional[str]) -> str:
    """Summarize which inputs/modalities are used by the MultiModalPredictor."""
    cols = list(df_any.columns)
    # exclude the label from the feature list
    feat_cols = [c for c in cols if c != label_col]
    # identify image vs tabular columns from args / presence
    img_present = (image_col in df_any.columns) if image_col else False
    tab_cols = [c for c in feat_cols if c != image_col]

    # brief lists (avoid dumping all columns unless the list is small)
    def list_or_count(arr, max_show=20):
        if len(arr) <= max_show:
            items = "".join(f"<li>{html.escape(str(x))}</li>" for x in arr)
            return f"<ul>{items}</ul>"
        return f"<p>{len(arr)} columns</p>"

    img_block = (
        f"<p><b>Image column:</b> {html.escape(image_col)}</p>"
        if img_present
        else "<p><b>Image column:</b> None</p>"
    )
    tab_block = f"<div><b>Structured columns:</b> {len(tab_cols)}{list_or_count(tab_cols, max_show=15)}</div>"

    return f"""
<h3>Modalities & Inputs</h3>
<p>This run used <b>MultiModalPredictor</b> (images + structured features).</p>
<p><b>Label column:</b> {html.escape(label_col)}</p>
{img_block}
{tab_block}
"""

def build_model_performance_summary_table(
    train_scores: dict,
    val_scores: dict,
    test_scores: dict | None = None,
    include_test: bool = True,
    title: str | None = 'Model Performance Summary',
    show_title: bool = True,
) -> str:
    """
    Returns an HTML table for metrics, optionally hiding the Test column.
    Keys across score dicts are unioned; missing values render as '—'.
    """
    def fmt(v):
        if v is None:
            return '—'
        if isinstance(v, (int, float)):
            return f'{v:.4f}'
        return str(v)

    # Collect the union of metric keys across splits
    metrics = (
        set(train_scores.keys())
        | set(val_scores.keys())
        | (set(test_scores.keys()) if (include_test and test_scores) else set())
    )

    # Remove AG_roc_auc entirely as requested
    metrics.discard('AG_roc_auc')

    # Helper: normalize metric keys for matching preferred names
    def _norm(k: str) -> str:
        return ''.join(ch for ch in str(k).lower() if ch.isalnum())

    # Preferred metrics appear at the end, in this specific order (display names):
    preferred_display = ['Accuracy', 'ROC-AUC', 'Precision', 'Recall', 'F1-Score',
                         'PR-AUC', 'Specificity', 'MCC', 'LogLoss']
    # Mapping of normalized key -> display label
    norm_to_display = {
        'accuracy': 'Accuracy',
        'acc': 'Accuracy',
        'rocauc': 'ROC-AUC',
        'roc_auc': 'ROC-AUC',
        'rocaucscore': 'ROC-AUC',
        'precision': 'Precision',
        'prec': 'Precision',
        'recall': 'Recall',
        'recallsensitivitytpr': 'Recall',
        'f1': 'F1-Score',
        'f1score': 'F1-Score',
        'pr_auc': 'PR-AUC',
        'prauc': 'PR-AUC',
        'averageprecision': 'PR-AUC',
        'specificity': 'Specificity',
        'tnr': 'Specificity',
        'mcc': 'MCC',
        'logloss': 'LogLoss',
        'crossentropy': 'LogLoss',
    }

    # Order: non-preferred metrics alphabetically, then preferred metrics in the
    # canonical order above.
    preferred_norms = [_norm(x) for x in preferred_display]
    preferred_present = []
    others = []
    for m in sorted(metrics):
        nm = _norm(m)
        if nm in preferred_norms or any(
            p in nm for p in ["rocauc", "prauc", "f1", "mcc", "logloss",
                              "accuracy", "precision", "recall", "specificity"]
        ):
            # Defer preferred-like metrics to the end (placed in canonical order)
            preferred_present.append(m)
        else:
            others.append(m)

    final_metrics = list(others)
    claimed = set()
    for disp in preferred_display:
        # find an original key matching this display (by normalized mapping)
        target_norm = _norm(disp)
        found = None
        for m in preferred_present:
            if m in claimed:
                continue
            if _norm(m) == target_norm or norm_to_display.get(_norm(m)) == disp:
                found = m
                break
            # also allow substring matches (e.g., 'roc_auc' vs 'rocauc')
            if target_norm in _norm(m):
                found = m
                break
        if found:
            claimed.add(found)
            final_metrics.append(found)
    # Keep any preferred-like metrics that no display name claimed, so that no
    # metric is silently dropped from the table.
    final_metrics.extend(m for m in preferred_present if m not in claimed)

    metrics = final_metrics

    # All headers are sortable; the script emitted by get_html_template() hooks
    # every table carrying the performance-summary class.
    header_cells = [
        '<th class="sortable">Metric</th>',
        '<th class="sortable">Train</th>',
        '<th class="sortable">Validation</th>',
    ]
    if include_test and test_scores:
        header_cells.append('<th class="sortable">Test</th>')

    rows_html = []
    for m in metrics:
        # Display label mapping: clean up common verbose names
        nm = _norm(m)
        if nm in norm_to_display:
            disp = norm_to_display[nm]
        else:
            # generic cleanup: underscores to spaces, drop parenthetical qualifiers
            disp = str(m).replace('_', ' ')
            disp = disp.replace('(Sensitivity/TPR)', '')
            disp = disp.replace('(TNR)', '')
            disp = disp.strip()

        cells = [
            f'<td>{_escape(disp)}</td>',
            f'<td>{fmt(train_scores.get(m))}</td>',
            f'<td>{fmt(val_scores.get(m))}</td>',
        ]
        if include_test and test_scores:
            cells.append(f'<td>{fmt(test_scores.get(m))}</td>')

        rows_html.append('<tr>' + ''.join(cells) + '</tr>')

    title_html = f'<h3 style="margin-top:0">{title}</h3>' if (show_title and title) else ''

    table_html = f"""
{title_html}
<table class="performance-summary">
  <thead><tr>{''.join(header_cells)}</tr></thead>
  <tbody>{''.join(rows_html)}</tbody>
</table>
"""
    return table_html
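

# Hedged usage sketch: illustrates the ordering rules (non-preferred metrics
# first, alphabetically; canonical metrics last) with made-up scores.
def _demo_performance_table() -> str:
    train = {'accuracy': 0.97, 'roc_auc': 0.99, 'brier': 0.05}
    val = {'accuracy': 0.91, 'roc_auc': 0.94, 'brier': 0.09}
    # Row order in the rendered table: brier, Accuracy, ROC-AUC.
    return build_model_performance_summary_table(train, val, include_test=False)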