goeckslab / multimodal_learner: report_utils.py @ 0:375c36923da1 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0

This is the initial revision of the file (diff against the null revision 000000000000).

| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
```python
import base64
import html
import json
import logging
import os
import platform
import shutil
import sys
import tempfile
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import yaml
from utils import verify_outputs

logger = logging.getLogger(__name__)


def _escape(s: Any) -> str:
    return html.escape(str(s))


def _write_predictor_path(predictor):
    try:
        pred_path = getattr(predictor, "path", None)
        if pred_path:
            with open("predictor_path.txt", "w") as pf:
                pf.write(str(pred_path))
            logger.info("Wrote predictor path → predictor_path.txt")
            return pred_path
    except Exception:
        logger.warning("Could not write predictor_path.txt")
    return None


def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]):
    if not output_config:
        return
    try:
        config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None
        if config_yaml_path and os.path.isfile(config_yaml_path):
            shutil.copy2(config_yaml_path, output_config)
            logger.info(f"Wrote AutoGluon config → {output_config}")
        else:
            with open(output_config, "w") as cfg_out:
                cfg_out.write("# config.yaml not found for this run\n")
            logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}")
    except Exception as e:
        logger.error(f"Failed to write config output '{output_config}': {e}")
        try:
            with open(output_config, "w") as cfg_out:
                cfg_out.write(f"# Failed to copy config.yaml: {e}\n")
        except Exception:
            pass


def _load_config_yaml(args, predictor) -> dict:
    """
    Load config.yaml either from the predictor path or the exported output_config.
    """
    candidates = []
    pred_path = getattr(predictor, "path", None)
    if pred_path:
        cfg_path = os.path.join(pred_path, "config.yaml")
        if os.path.isfile(cfg_path):
            candidates.append(cfg_path)
    if args.output_config and os.path.isfile(args.output_config):
        candidates.append(args.output_config)

    for p in candidates:
        try:
            with open(p, "r") as f:
                return yaml.safe_load(f) or {}
        except Exception:
            continue
    return {}


def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]:
    """
    Build rows describing model components and key hyperparameters from a
    loaded config.yaml. (`args` is accepted so callers can supply CLI
    fallbacks; it is not currently consulted.)
    """
    rows: List[tuple[str, str]] = []
    model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
    names = model_cfg.get("names") or []
    if names:
        rows.append(("Model components", ", ".join(names)))

    # Tabular backbone with data types
    tabular_val = "—"
    for k, v in model_cfg.items():
        if k in ("names", "hf_text", "timm_image"):
            continue
        if isinstance(v, dict) and "data_types" in v:
            dtypes = v.get("data_types") or []
            if any(t in ("categorical", "numerical") for t in dtypes):
                dt_str = ", ".join(dtypes) if dtypes else ""
                tabular_val = f"{k} ({dt_str})" if dt_str else k
                break
    rows.append(("Tabular backbone", tabular_val))

    image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—"
    rows.append(("Image backbone", image_val))

    text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—"
    rows.append(("Text backbone", text_val))

    fusion_val = "—"
    for k in model_cfg.keys():
        if str(k).startswith("fusion"):
            fusion_val = k
            break
    rows.append(("Fusion backbone", fusion_val))

    # Optimizer block
    optim_cfg = cfg.get("optim", {}) if isinstance(cfg, dict) else {}
    optim_map = [
        ("optim_type", "Optimizer"),
        ("lr", "Learning rate"),
        ("weight_decay", "Weight decay"),
        ("lr_decay", "LR decay"),
        ("max_epochs", "Max epochs"),
        ("max_steps", "Max steps"),
        ("patience", "Early-stop patience"),
        ("check_val_every_n_epoch", "Val check every N epochs"),
        ("top_k", "Top K checkpoints"),
        ("top_k_average_method", "Top K averaging"),
    ]
    for key, label in optim_map:
        if key in optim_cfg:
            rows.append((label, optim_cfg[key]))

    env_cfg = cfg.get("env", {}) if isinstance(cfg, dict) else {}
    if "batch_size" in env_cfg:
        rows.append(("Global batch size", env_cfg["batch_size"]))

    return rows
```
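For orientation, here is a minimal sketch of how `_summarize_config` flattens a config into report rows. The `cfg` fragment below is hypothetical and far smaller than a real AutoGluon `config.yaml`; since `args` is not consulted by the current implementation, `None` suffices:

```python
cfg = {
    "model": {
        "names": ["ft_transformer", "timm_image", "fusion_mlp"],
        "ft_transformer": {"data_types": ["categorical", "numerical"]},
        "timm_image": {"checkpoint_name": "resnet50", "data_types": ["image"]},
        "fusion_mlp": {},
    },
    "optim": {"optim_type": "adamw", "lr": 0.0001, "max_epochs": 10},
    "env": {"batch_size": 128},
}

for label, value in _summarize_config(cfg, args=None):
    print(f"{label}: {value}")
# Model components: ft_transformer, timm_image, fusion_mlp
# Tabular backbone: ft_transformer (categorical, numerical)
# Image backbone: resnet50
# Text backbone: —
# Fusion backbone: fusion_mlp
# Optimizer: adamw
# ...
```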
```python
def write_outputs(
    args,
    predictor,
    problem_type: str,
    eval_results: dict,
    data_ctx: dict,
    raw_folds=None,
    ag_folds=None,
    raw_metrics_std=None,
    ag_by_split_std=None,
):
    from plot_logic import (
        build_summary_html,
        build_test_html_and_plots,
        build_feature_html,
        assemble_full_html_report,
        build_train_html_and_plots,
    )
    from autogluon.multimodal import MultiModalPredictor
    from metrics_logic import aggregate_metrics

    raw_metrics = eval_results.get("raw_metrics", {})
    ag_by_split = eval_results.get("ag_eval", {})
    fit_summary_obj = eval_results.get("fit_summary")

    df_train = data_ctx.get("train")
    df_val = data_ctx.get("val")
    df_test_internal = data_ctx.get("test_internal")
    df_test_external = data_ctx.get("test_external")
    df_test = df_test_external if df_test_external is not None else df_test_internal
    df_train_full = df_train if df_val is None else pd.concat([df_train, df_val], ignore_index=True)

    # Aggregate folds if provided without stds
    if raw_folds and raw_metrics_std is None:
        raw_metrics, raw_metrics_std = aggregate_metrics(raw_folds)
    if ag_folds and ag_by_split_std is None:
        ag_by_split, ag_by_split_std = aggregate_metrics(ag_folds)

    # Inject AG eval into raw metrics for visibility
    def _inject_ag(src: dict, dst: dict):
        for k, v in (src or {}).items():
            try:
                dst[f"AG_{k}"] = float(v)
            except Exception:
                dst[f"AG_{k}"] = v

    if "Train" in raw_metrics and "Train" in ag_by_split:
        _inject_ag(ag_by_split["Train"], raw_metrics["Train"])
    if "Validation" in raw_metrics and "Validation" in ag_by_split:
        _inject_ag(ag_by_split["Validation"], raw_metrics["Validation"])
    if "Test" in raw_metrics and "Test" in ag_by_split:
        _inject_ag(ag_by_split["Test"], raw_metrics["Test"])

    # JSON
    with open(args.output_json, "w") as f:
        json.dump(
            {
                "train": raw_metrics.get("Train", {}),
                "val": raw_metrics.get("Validation", {}),
                "test": raw_metrics.get("Test", {}),
                "test_external": raw_metrics.get("Test (external)", {}),
                "ag_eval": ag_by_split,
                "ag_eval_std": ag_by_split_std,
                "fit_summary": fit_summary_obj,
                "problem_type": problem_type,
                "predictor_path": getattr(predictor, "path", None),
                "threshold": args.threshold,
                "threshold_test": args.threshold,
                "preset": args.preset,
                "eval_metric": args.eval_metric,
                "folds": {
                    "raw_folds": raw_folds,
                    "ag_folds": ag_folds,
                    "summary_mean": raw_metrics if raw_folds else None,
                    "summary_std": raw_metrics_std,
                    "ag_summary_mean": ag_by_split,
                    "ag_summary_std": ag_by_split_std,
                },
            },
            f,
            indent=2,
            default=str,
        )
    logger.info(f"Wrote full JSON → {args.output_json}")

    # HTML report assembly
    label_col = args.target_column

    class_balance_block_html = build_class_balance_html(
        df_train=df_train,
        label_col=label_col,
        df_val=df_val,
        df_test=df_test,
    )
    summary_perf_table_html = build_model_performance_summary_table(
        train_scores=raw_metrics.get("Train", {}),
        val_scores=raw_metrics.get("Validation", {}),
        test_scores=raw_metrics.get("Test", {}),
        include_test=True,
        title=None,
        show_title=False,
    )

    cfg_yaml = _load_config_yaml(args, predictor)
    config_rows = _summarize_config(cfg_yaml, args)
    threshold_rows = []
    if problem_type == "binary" and args.threshold is not None:
        threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
    extra_run_rows = [
        ("Target column", label_col),
        ("Model evaluation metric", args.eval_metric or "AutoGluon default"),
        ("Experiment quality", args.preset or "AutoGluon default"),
    ] + threshold_rows + config_rows

    summary_html = build_summary_html(
        predictor=predictor,
        df_train=df_train_full,
        df_val=df_val,
        df_test=df_test,
        label_column=label_col,
        extra_run_rows=extra_run_rows,
        class_balance_html=class_balance_block_html,
        perf_table_html=summary_perf_table_html,
    )

    train_tab_perf_html = build_model_performance_summary_table(
        train_scores=raw_metrics.get("Train", {}),
        val_scores=raw_metrics.get("Validation", {}),
        test_scores=raw_metrics.get("Test", {}),
        include_test=False,
        title=None,
        show_title=False,
    )

    train_html = build_train_html_and_plots(
        predictor=predictor,
        problem_type=problem_type,
        df_train=df_train,
        df_val=df_val,
        label_column=label_col,
        tmpdir=tempfile.mkdtemp(),
        seed=int(args.random_seed),
        perf_table_html=train_tab_perf_html,
        threshold=args.threshold,
    )

    test_html_template, plots = build_test_html_and_plots(
        predictor,
        problem_type,
        df_test,
        label_col,
        tempfile.mkdtemp(),
        threshold=args.threshold,
    )

    def _fmt_val(v):
        if isinstance(v, (int, np.integer)):
            return f"{int(v)}"
        if isinstance(v, (float, np.floating)):
            return f"{v:.6f}"
        return str(v)

    test_scores = raw_metrics.get("Test", {})
    # Drop AutoGluon-injected ROC AUC line from the Test Performance Summary
    filtered_test_scores = {k: v for k, v in test_scores.items() if k != "AG_roc_auc"}
    # Underscores become spaces; parenthetical qualifiers like (TNR) are kept as-is
    metric_rows = "".join(
        f"<tr><td>{k.replace('_', ' ')}</td><td>{_fmt_val(v)}</td></tr>"
        for k, v in filtered_test_scores.items()
    )
    test_html_filled = test_html_template.format(metric_rows)

    is_multimodal = isinstance(predictor, MultiModalPredictor)
    leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
    inputs_html = ""
    ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
    presets_hparams_html = build_presets_hparams_html(predictor)
    notices: List[str] = []
    if args.threshold is not None and problem_type == "binary":
        notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
    warnings_html = build_warnings_html([], notices)
    repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None))

    transparency_blocks = "\n".join(
        [
            leaderboard_html,
            inputs_html,
            ignored_features_html,
            presets_hparams_html,
            warnings_html,
            repro_html,
        ]
    )

    try:
        feature_text = (
            build_feature_html(predictor, df_test, label_col, tempfile.mkdtemp(), args.random_seed)
            if df_test is not None
            else ""
        )
    except Exception:
        feature_text = "<p>Feature analysis unavailable for this model.</p>"

    full_html = assemble_full_html_report(
        summary_html,
        train_html,
        test_html_filled,
        plots,
        feature_text + transparency_blocks,
    )
    with open(args.output_html, "w") as f:
        f.write(full_html)
    logger.info(f"Wrote HTML report → {args.output_html}")

    pred_path = _write_predictor_path(predictor)
    _copy_config_if_available(pred_path, args.output_config)

    outputs_to_check = [
        (args.output_json, "JSON results"),
        (args.output_html, "HTML report"),
    ]
    if args.output_config:
        outputs_to_check.append((args.output_config, "AutoGluon config"))
    verify_outputs(outputs_to_check)
```
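A sketch of a call site for `write_outputs`, assuming the tool's companion modules (`plot_logic`, `metrics_logic`, `utils`) are importable, `predictor` is a trained `MultiModalPredictor`, and the `df_*` variables are pandas DataFrames. The attribute names on `args` mirror what the function actually reads; the values are made up:

```python
from types import SimpleNamespace

args = SimpleNamespace(
    output_json="results.json",
    output_html="report.html",
    output_config="config_out.yaml",
    target_column="label",
    eval_metric="roc_auc",
    preset="medium_quality",
    threshold=0.5,
    random_seed=42,
)
eval_results = {
    "raw_metrics": {"Train": {}, "Validation": {}, "Test": {}},
    "ag_eval": {},
    "fit_summary": None,
}
data_ctx = {"train": df_train, "val": df_val,
            "test_internal": df_test, "test_external": None}

write_outputs(args, predictor, problem_type="binary",
              eval_results=eval_results, data_ctx=data_ctx)
```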
```python
def get_html_template() -> str:
    """
    Returns the opening HTML, <head> (with CSS/JS), and opens <body> + .container.
    Includes:
      - Base styling for layout and tables
      - Sortable table headers with 3-state arrows (none ⇅, asc ↑, desc ↓)
      - A scroll helper class (.scroll-rows-30) that approximates ~30 visible rows
      - A guarded script so initialization runs only once even if injected twice
    """
    return """
<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8">
  <title>Galaxy-Ludwig Report</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 0;
      padding: 20px;
      background-color: #f4f4f4;
    }
    .container {
      max-width: 1200px;
      margin: auto;
      background: white;
      padding: 20px;
      box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
      overflow-x: auto;
    }
    h1 {
      text-align: center;
      color: #333;
    }
    h2 {
      border-bottom: 2px solid #4CAF50;
      color: #4CAF50;
      padding-bottom: 5px;
      margin-top: 28px;
    }

    /* baseline table setup */
    table {
      border-collapse: collapse;
      margin: 20px 0;
      width: 100%;
      table-layout: fixed;
      background: #fff;
    }
    table, th, td {
      border: 1px solid #ddd;
    }
    th, td {
      padding: 10px;
      text-align: center;
      vertical-align: middle;
      word-break: break-word;
      white-space: normal;
      overflow-wrap: anywhere;
    }
    th {
      background-color: #4CAF50;
      color: white;
    }

    .plot {
      text-align: center;
      margin: 20px 0;
    }
    .plot img {
      max-width: 100%;
      height: auto;
      border: 1px solid #ddd;
    }

    /* -------------------
       sortable columns (3-state: none ⇅, asc ↑, desc ↓)
       ------------------- */
    table.performance-summary th.sortable {
      cursor: pointer;
      position: relative;
      user-select: none;
    }
    /* default icon space */
    table.performance-summary th.sortable::after {
      content: '⇅';
      position: absolute;
      right: 12px;
      top: 50%;
      transform: translateY(-50%);
      font-size: 0.8em;
      color: #eaf5ea; /* light on green */
      text-shadow: 0 0 1px rgba(0,0,0,0.15);
    }
    /* three states override the default */
    table.performance-summary th.sortable.sorted-none::after { content: '⇅'; color: #eaf5ea; }
    table.performance-summary th.sortable.sorted-asc::after  { content: '↑'; color: #ffffff; }
    table.performance-summary th.sortable.sorted-desc::after { content: '↓'; color: #ffffff; }

    /* show ~30 rows with a scrollbar (tweak if you want) */
    .scroll-rows-30 {
      max-height: 900px;   /* ~30 rows depending on row height */
      overflow-y: auto;    /* vertical scrollbar ("sidebar") */
      overflow-x: auto;
    }

    /* Tabs + Help button (used by build_tabbed_html) */
    .tabs {
      display: flex;
      align-items: center;
      border-bottom: 2px solid #ccc;
      margin-bottom: 1rem;
      gap: 6px;
      flex-wrap: wrap;
    }
    .tab {
      padding: 10px 20px;
      cursor: pointer;
      border: 1px solid #ccc;
      border-bottom: none;
      background: #f9f9f9;
      margin-right: 5px;
      border-top-left-radius: 8px;
      border-top-right-radius: 8px;
    }
    .tab.active {
      background: white;
      font-weight: bold;
    }
    .help-btn {
      margin-left: auto;
      padding: 6px 12px;
      font-size: 0.9rem;
      border: 1px solid #4CAF50;
      border-radius: 4px;
      background: #4CAF50;
      color: white;
      cursor: pointer;
    }
    .tab-content {
      display: none;
      padding: 20px;
      border: 1px solid #ccc;
      border-top: none;
      background: #fff;
    }
    .tab-content.active {
      display: block;
    }

    /* Modal (used by get_metrics_help_modal) */
    .modal {
      display: none;
      position: fixed;
      z-index: 9999;
      left: 0; top: 0;
      width: 100%; height: 100%;
      overflow: auto;
      background-color: rgba(0,0,0,0.4);
    }
    .modal-content {
      background-color: #fefefe;
      margin: 8% auto;
      padding: 20px;
      border: 1px solid #888;
      width: 90%;
      max-width: 900px;
      border-radius: 8px;
    }
    .modal .close {
      color: #777;
      float: right;
      font-size: 28px;
      font-weight: bold;
      line-height: 1;
      margin-left: 8px;
    }
    .modal .close:hover,
    .modal .close:focus {
      color: black;
      text-decoration: none;
      cursor: pointer;
    }
    .metrics-guide h3 { margin-top: 20px; }
    .metrics-guide p { margin: 6px 0; }
    .metrics-guide ul { margin: 10px 0; padding-left: 20px; }
  </style>

  <script>
  // Guard to avoid double-initialization if this block is included twice
  (function(){
    if (window.__perfSummarySortInit) return;
    window.__perfSummarySortInit = true;

    function initPerfSummarySorting() {
      // Record original order for "back to original"
      document.querySelectorAll('table.performance-summary tbody').forEach(tbody => {
        Array.from(tbody.rows).forEach((row, i) => { row.dataset.originalOrder = i; });
      });

      const getText = td => (td?.innerText || '').trim();
      const cmp = (idx, asc) => (a, b) => {
        const v1 = getText(a.children[idx]);
        const v2 = getText(b.children[idx]);
        const n1 = parseFloat(v1), n2 = parseFloat(v2);
        if (!isNaN(n1) && !isNaN(n2)) return asc ? n1 - n2 : n2 - n1; // numeric
        return asc ? v1.localeCompare(v2) : v2.localeCompare(v1);     // lexical
      };

      document.querySelectorAll('table.performance-summary th.sortable').forEach(th => {
        // initialize to "none"
        th.classList.remove('sorted-asc','sorted-desc');
        th.classList.add('sorted-none');

        th.addEventListener('click', () => {
          const table = th.closest('table');
          const headerRow = th.parentNode;
          const allTh = headerRow.querySelectorAll('th.sortable');
          const tbody = table.querySelector('tbody');

          // Determine current state BEFORE clearing
          const isAsc = th.classList.contains('sorted-asc');
          const isDesc = th.classList.contains('sorted-desc');

          // Reset all headers in this row
          allTh.forEach(x => x.classList.remove('sorted-asc','sorted-desc','sorted-none'));

          // Compute next state (none -> asc -> desc -> none)
          let next;
          if (!isAsc && !isDesc) {
            next = 'asc';
          } else if (isAsc) {
            next = 'desc';
          } else {
            next = 'none';
          }
          th.classList.add('sorted-' + next);

          // Sort rows according to the chosen state
          const rows = Array.from(tbody.rows);
          if (next === 'none') {
            rows.sort((a, b) => (a.dataset.originalOrder - b.dataset.originalOrder));
          } else {
            const idx = Array.from(headerRow.children).indexOf(th);
            rows.sort(cmp(idx, next === 'asc'));
          }
          rows.forEach(r => tbody.appendChild(r));
        });
      });
    }

    // Run after DOM is ready
    if (document.readyState === 'loading') {
      document.addEventListener('DOMContentLoaded', initPerfSummarySorting);
    } else {
      initPerfSummarySorting();
    }
  })();
  </script>
</head>
<body>
  <div class="container">
"""


def get_html_closing():
    """Closes .container, body, and html."""
    return """
  </div>
</body>
</html>
"""
```
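Taken together, the two helpers simply bracket arbitrary body content; a minimal page can be assembled like this (the body string here is a placeholder):

```python
page = get_html_template() + "<h2>Demo Section</h2><p>Hello.</p>" + get_html_closing()
with open("minimal_report.html", "w") as f:
    f.write(page)
```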
```python
def build_tabbed_html(
    summary_html: str,
    train_html: str,
    test_html: str,
    feature_html: str,
    explainer_html: Optional[str] = None,
) -> str:
    """
    Renders the tab headers, contents, and the JS to switch tabs.
    Note: feature_html is accepted but is not rendered as a tab of its own
    by this function.
    """
    tabs = [
        '<div class="tabs">',
        '<div class="tab active" onclick="showTab(\'summary\')">Model Metric Summary and Config</div>',
        '<div class="tab" onclick="showTab(\'train\')">Train and Validation Summary</div>',
        '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>',
    ]
    if explainer_html:
        tabs.append('<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>')
    tabs.append('<button id="openMetricsHelp" class="help-btn">Help</button>')
    tabs.append('</div>')
    tabs_section = "\n".join(tabs)

    contents = [
        f'<div id="summary" class="tab-content active">{summary_html}</div>',
        f'<div id="train" class="tab-content">{train_html}</div>',
        f'<div id="test" class="tab-content">{test_html}</div>',
    ]
    if explainer_html:
        contents.append(f'<div id="explainer" class="tab-content">{explainer_html}</div>')
    content_section = "\n".join(contents)

    js = """
<script>
function showTab(id) {
  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
  document.getElementById(id).classList.add('active');
  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
}
</script>
"""
    return tabs_section + "\n" + content_section + "\n" + js
```
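A hedged usage sketch: the HTML fragments below are placeholders for what the `plot_logic` builders would normally produce, and `feature_html` is included only to show the full signature:

```python
body = build_tabbed_html(
    summary_html="<p>summary content</p>",
    train_html="<p>train content</p>",
    test_html="<p>test content</p>",
    feature_html="<p>feature content</p>",  # currently not shown as its own tab
)
page = get_html_template() + body + get_html_closing()
```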
```python
def encode_image_to_base64(image_path: str) -> str:
    """
    Reads an image file from disk and returns a base64-encoded string
    for embedding directly in HTML <img> tags.
    """
    try:
        with open(image_path, "rb") as img_f:
            return base64.b64encode(img_f.read()).decode("utf-8")
    except Exception as e:
        logger.error(f"Failed to encode image '{image_path}': {e}")
        return ""
```
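Typical use is to inline a saved plot into the report; `roc_curve.png` here is a hypothetical path:

```python
b64 = encode_image_to_base64("roc_curve.png")
img_html = f'<div class="plot"><img src="data:image/png;base64,{b64}" alt="ROC curve"/></div>'
```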
```python
def get_model_architecture(predictor: Any) -> str:
    """
    Returns a human-friendly description of the final model architecture based on the
    MultiModalPredictor configuration (e.g., timm_image=resnet50, hf_text=bert-base-uncased).
    """
    # MultiModalPredictor path: read backbones from config if available
    archs = []
    for attr in ("_config", "config"):
        cfg = getattr(predictor, attr, None)
        try:
            model_cfg = getattr(cfg, "model", None)
            if model_cfg:
                # OmegaConf-like mapping
                for name, sub in dict(model_cfg).items():
                    ck = None
                    # sub may be an object or a dict-like node
                    for k in ("checkpoint_name", "name", "model_name"):
                        try:
                            ck = getattr(sub, k)
                        except Exception:
                            ck = sub.get(k) if isinstance(sub, dict) else ck
                        if ck:
                            break
                    if ck:
                        archs.append(f"{name}={ck}")
        except Exception:
            continue

    if archs:
        return ", ".join(archs)

    # Fallback
    return type(predictor).__name__
```
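Because the function only probes attributes, it can be exercised with a stand-in object; this sketch mimics the OmegaConf-style `model` section with plain dicts:

```python
from types import SimpleNamespace

mock_cfg = SimpleNamespace(model={
    "timm_image": {"checkpoint_name": "resnet50"},
    "hf_text": {"checkpoint_name": "bert-base-uncased"},
})
mock_predictor = SimpleNamespace(_config=mock_cfg)

print(get_model_architecture(mock_predictor))
# timm_image=resnet50, hf_text=bert-base-uncased
```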
```python
def collect_run_context(args, predictor, problem_type: str,
                        df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame,
                        warnings_list: List[str],
                        notes_list: List[str]) -> Dict[str, Any]:
    """Build a dictionary with run/system context for transparency."""
    # System info (best-effort; does not depend on AutoGluon stdout)
    try:
        import psutil  # optional dependency
        mem = psutil.virtual_memory()
        mem_total_gb = mem.total / (1024 ** 3)
        mem_avail_gb = mem.available / (1024 ** 3)
    except Exception:
        mem_total_gb = mem_avail_gb = None

    ctx = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "python_version": platform.python_version(),
        "platform": {
            "system": platform.system(),
            "release": platform.release(),
            "version": platform.version(),
            "machine": platform.machine(),
        },
        "cpu_count": os.cpu_count(),
        "memory_total_gb": mem_total_gb,
        "memory_available_gb": mem_avail_gb,
        "packages": {},
        "problem_type": problem_type,
        "label_column": args.label_column,
        "time_limit_sec": args.time_limit,
        "random_seed": args.random_seed,
        "splits": {
            "train_rows": int(len(df_train)),
            "val_rows": int(len(df_val)),
            "test_rows": int(len(df_test)),
            "n_features_raw": int(len(df_train.columns) - 1),  # minus label
        },
        "warnings": warnings_list,
        "notes": notes_list,
    }
    # Package versions (safe best-effort)
    for mod_name, key in [
        ("autogluon", "autogluon"),
        ("torch", "torch"),
        ("sklearn", "scikit_learn"),
        ("numpy", "numpy"),
        ("pandas", "pandas"),
    ]:
        try:
            mod = __import__(mod_name)
            ctx["packages"][key] = getattr(mod, "__version__", "unknown")
        except Exception:
            pass
    return ctx
```
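A quick smoke test with stand-in arguments; only the attributes the function reads are provided, and `predictor` is unused by the current implementation:

```python
from types import SimpleNamespace

df = pd.DataFrame({"x": [1, 2, 3], "label": [0, 1, 0]})
args = SimpleNamespace(label_column="label", time_limit=3600, random_seed=42)

ctx = collect_run_context(args, predictor=None, problem_type="binary",
                          df_train=df, df_val=df, df_test=df,
                          warnings_list=[], notes_list=["demo run"])
print(json.dumps(ctx, indent=2, default=str))
```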
```python
def build_class_balance_html(
    df_train: Optional[pd.DataFrame],
    label_col: str,
    df_val: Optional[pd.DataFrame] = None,
    df_test: Optional[pd.DataFrame] = None,
) -> str:
    """
    Render label counts for each available split (Train/Validation/Test).
    """
    def _count_labels(frame: Optional[pd.DataFrame]) -> pd.Series:
        if frame is None or label_col not in frame:
            return pd.Series(dtype=int)
        series = frame[label_col]
        if series.dtype.kind in "ifu":
            # numeric labels: keep numeric order
            return pd.Series(series).value_counts(dropna=False).sort_index()
        return pd.Series(series.astype(str)).value_counts(dropna=False)

    counts_train = _count_labels(df_train)
    counts_val = _count_labels(df_val)
    counts_test = _count_labels(df_test)

    labels: list[Any] = []
    for idx in (counts_train.index, counts_val.index, counts_test.index):
        for label in idx:
            if label not in labels:
                labels.append(label)

    has_train = df_train is not None
    has_val = df_val is not None
    has_test = df_test is not None

    def _fmt_count(counts: pd.Series, label: Any, enabled: bool) -> str:
        if not enabled:
            return "—"
        return str(int(counts.get(label, 0)))

    rows = [
        f"<tr><td>{_escape(label)}</td>"
        f"<td>{_fmt_count(counts_train, label, has_train)}</td>"
        f"<td>{_fmt_count(counts_val, label, has_val)}</td>"
        f"<td>{_fmt_count(counts_test, label, has_test)}</td></tr>"
        for label in labels
    ]

    if not rows:
        return "<p>No label distribution available.</p>"

    return f"""
    <h3>Label Counts by Split</h3>
    <table class="table">
      <thead><tr><th>Label</th><th>Train</th><th>Validation</th><th>Test</th></tr></thead>
      <tbody>
        {''.join(rows)}
      </tbody>
    </table>
    """
```
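For example, with toy splits and no test set (so the Test column renders as a dash):

```python
train = pd.DataFrame({"label": ["cat", "dog", "dog"]})
val = pd.DataFrame({"label": ["cat"]})
block = build_class_balance_html(train, "label", df_val=val, df_test=None)
# One row per label ("cat", "dog") with counts under Train/Validation
# and "—" under Test.
```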
```python
def build_leaderboard_html(predictor) -> str:
    try:
        lb = predictor.leaderboard(silent=True)
        # keep common helpful columns if present
        cols_pref = ["model", "score_val", "eval_metric", "pred_time_val", "fit_time",
                     "pred_time_val_marginal", "fit_time_marginal", "stack_level",
                     "can_infer", "fit_order"]
        cols = [c for c in cols_pref if c in lb.columns] or list(lb.columns)
        return "<h3>Model Leaderboard (Validation)</h3>" + lb[cols].to_html(index=False)
    except Exception as e:
        return f"<h3>Model Leaderboard</h3><p>Unavailable: {_escape(e)}</p>"


def build_ignored_features_html(predictor, df_any: pd.DataFrame) -> str:
    # MultiModalPredictor does not always expose .features(); guard accordingly.
    used = set()
    try:
        used = set(predictor.features())
    except Exception:
        # If we can't determine the used features, don't emit a misleading section
        return ""
    raw_cols = [c for c in df_any.columns if c != getattr(predictor, "label", None)]
    ignored = [c for c in raw_cols if c not in used]
    if not ignored:
        return ""
    items = "".join(f"<li>{html.escape(c)}</li>" for c in ignored)
    return f"""
    <h3>Ignored / Unused Features</h3>
    <p>The following columns were not used by the trained predictor at inference time:</p>
    <ul>{items}</ul>
    """


def build_presets_hparams_html(predictor) -> str:
    # MultiModalPredictor path
    mm_hp = {}
    for attr in ("_config", "config", "_fit_args"):
        if hasattr(predictor, attr):
            try:
                val = getattr(predictor, attr)
                # make it JSON-ish
                mm_hp[attr] = str(val)
            except Exception:
                continue
    hp_html = f"<pre>{html.escape(json.dumps(mm_hp, indent=2))}</pre>" if mm_hp else "<i>Unavailable</i>"
    return ("<h3>Training Presets & Hyperparameters</h3>"
            f"<details open><summary>Show hyperparameters</summary>{hp_html}</details>")


def build_warnings_html(warnings_list: List[str], notes_list: List[str]) -> str:
    if not warnings_list and not notes_list:
        return ""
    w_html = "".join(f"<li>{_escape(w)}</li>" for w in warnings_list)
    n_html = "".join(f"<li>{_escape(n)}</li>" for n in notes_list)
    return f"""
    <h3>Warnings & Notices</h3>
    {'<h4>Warnings</h4><ul>' + w_html + '</ul>' if warnings_list else ''}
    {'<h4>Notices</h4><ul>' + n_html + '</ul>' if notes_list else ''}
    """
```
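For instance, mirroring the notice format that `write_outputs` emits for binary problems (the warning string is illustrative):

```python
print(build_warnings_html(
    warnings_list=["Validation split is small; metrics may be noisy."],
    notes_list=["Using decision threshold = 0.500 on Test."],
))
```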
```python
def build_reproducibility_html(args, ctx: Dict[str, Any], model_path: Optional[str]) -> str:
    cmd = " ".join(_escape(x) for x in sys.argv)
    load_snippet = ""
    if model_path:
        load_snippet = f"""<pre>
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("{_escape(model_path)}")
</pre>"""
    pkg_rows = "".join(
        f"<tr><td>{_escape(k)}</td><td>{_escape(v)}</td></tr>"
        for k, v in (ctx.get("packages") or {}).items()
    )
    sys_table = f"""
    <table class="table">
      <tbody>
        <tr><th>Timestamp</th><td>{_escape(ctx.get('timestamp'))}</td></tr>
        <tr><th>Python</th><td>{_escape(ctx.get('python_version'))}</td></tr>
        <tr><th>Platform</th><td>{_escape(ctx.get('platform'))}</td></tr>
        <tr><th>CPU Count</th><td>{_escape(ctx.get('cpu_count'))}</td></tr>
        <tr><th>Memory (GB)</th><td>Total: {_escape(ctx.get('memory_total_gb'))} | Avail: {_escape(ctx.get('memory_available_gb'))}</td></tr>
        <tr><th>Seed</th><td>{_escape(ctx.get('random_seed'))}</td></tr>
        <tr><th>Time Limit (s)</th><td>{_escape(ctx.get('time_limit_sec'))}</td></tr>
      </tbody>
    </table>
    """
    pkgs_table = f"""
    <h4>Package Versions</h4>
    <table class="table">
      <thead><tr><th>Package</th><th>Version</th></tr></thead>
      <tbody>{pkg_rows}</tbody>
    </table>
    """
    return f"""
    <h3>Reproducibility</h3>
    <h4>Command</h4>
    <pre>{cmd}</pre>
    {sys_table}
    {pkgs_table}
    <h4>Load Trained Model</h4>
    {load_snippet or '<i>Model path not available</i>'}
    """


def build_modalities_html(predictor, df_any: pd.DataFrame, label_col: str, image_col: Optional[str]) -> str:
    """Summarize which inputs/modalities are used for MultiModalPredictor."""
    cols = list(df_any.columns)
    # exclude the label from the feature list
    feat_cols = [c for c in cols if c != label_col]
    # identify image vs tabular columns from args / presence
    img_present = (image_col in df_any.columns) if image_col else False
    tab_cols = [c for c in feat_cols if c != image_col]

    # brief lists (avoid dumping all columns unless the list is small)
    def list_or_count(arr, max_show=20):
        if len(arr) <= max_show:
            items = "".join(f"<li>{html.escape(str(x))}</li>" for x in arr)
            return f"<ul>{items}</ul>"
        return f"<p>{len(arr)} columns</p>"

    img_block = (f"<p><b>Image column:</b> {html.escape(image_col)}</p>"
                 if img_present else "<p><b>Image column:</b> None</p>")
    tab_block = f"<div><b>Structured columns:</b> {len(tab_cols)}{list_or_count(tab_cols, max_show=15)}</div>"

    return f"""
    <h3>Modalities & Inputs</h3>
    <p>This run used <b>MultiModalPredictor</b> (images + structured features).</p>
    <p><b>Label column:</b> {html.escape(label_col)}</p>
    {img_block}
    {tab_block}
    """
```
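A small illustration with a hypothetical frame; `predictor` is not consulted by the current implementation:

```python
df = pd.DataFrame({"image_path": ["a.png"], "age": [42], "label": [1]})
block = build_modalities_html(None, df, label_col="label", image_col="image_path")
# Reports label column "label", image column "image_path",
# and one structured column ("age").
```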
```python
def build_model_performance_summary_table(
    train_scores: dict,
    val_scores: dict,
    test_scores: dict | None = None,
    include_test: bool = True,
    title: str | None = 'Model Performance Summary',
    show_title: bool = True,
) -> str:
    """
    Returns an HTML table for metrics, optionally hiding the Test column.
    Keys across score dicts are unioned; missing values render as '—'.
    """
    def fmt(v):
        if v is None:
            return '—'
        if isinstance(v, (int, float)):
            return f'{v:.4f}'
        return str(v)

    # Collect union of metric keys across splits
    metrics = set(train_scores.keys()) | set(val_scores.keys()) | (
        set(test_scores.keys()) if (include_test and test_scores) else set()
    )

    # Remove AG_roc_auc entirely as requested
    metrics.discard('AG_roc_auc')

    # Helper: normalize metric keys for matching preferred names
    def _norm(k: str) -> str:
        return ''.join(ch for ch in str(k).lower() if ch.isalnum())

    # Preferred metrics to appear at the end in this specific order (display names):
    preferred_display = ['Accuracy', 'ROC-AUC', 'Precision', 'Recall', 'F1-Score',
                         'PR-AUC', 'Specificity', 'MCC', 'LogLoss']
    # Mapping of normalized key -> display label
    norm_to_display = {
        'accuracy': 'Accuracy',
        'acc': 'Accuracy',
        'rocauc': 'ROC-AUC',
        'rocaucscore': 'ROC-AUC',
        'precision': 'Precision',
        'prec': 'Precision',
        'recall': 'Recall',
        'recallsensitivitytpr': 'Recall',
        'f1': 'F1-Score',
        'f1score': 'F1-Score',
        'prauc': 'PR-AUC',
        'averageprecision': 'PR-AUC',
        'specificity': 'Specificity',
        'tnr': 'Specificity',
        'mcc': 'MCC',
        'logloss': 'LogLoss',
        'crossentropy': 'LogLoss',
    }

    # Partition: non-preferred metrics stay in alphabetical order; preferred-like
    # metrics are deferred to the end (placed in canonical order below)
    preferred_norms = [_norm(x) for x in preferred_display]
    preferred_like = ["rocauc", "prauc", "f1", "mcc", "logloss", "accuracy",
                      "precision", "recall", "specificity"]
    preferred_present = []
    others = []
    for m in sorted(metrics):
        nm = _norm(m)
        if nm in preferred_norms or any(p in nm for p in preferred_like):
            preferred_present.append(m)
        else:
            others.append(m)

    # Assemble final metric order: others (alphabetical), then preferred metrics
    # in the exact requested order. Exact normalized matches take priority over
    # substring matches so injected variants (e.g. 'AG_accuracy') cannot shadow
    # the plain metric; each key is used at most once, and preferred-like keys
    # that match no display label are kept rather than silently dropped.
    final_metrics = list(others)
    used: set = set()
    for disp in preferred_display:
        target_norm = _norm(disp)
        found = None
        for m in preferred_present:
            if m in used:
                continue
            if _norm(m) == target_norm or norm_to_display.get(_norm(m)) == disp:
                found = m
                break
        if found is None:
            # Fall back to substring matches (e.g. 'AG_roc_auc' vs 'rocauc')
            for m in preferred_present:
                if m not in used and target_norm in _norm(m):
                    found = m
                    break
        if found:
            used.add(found)
            final_metrics.append(found)
    final_metrics.extend(m for m in preferred_present if m not in used)

    metrics = final_metrics

    # Make all headers sortable by adding the 'sortable' class; the sorting JS
    # emitted by get_html_template() hooks table.performance-summary
    header_cells = [
        '<th class="sortable">Metric</th>',
        '<th class="sortable">Train</th>',
        '<th class="sortable">Validation</th>',
    ]
    if include_test and test_scores:
        header_cells.append('<th class="sortable">Test</th>')

    rows_html = []
    for m in metrics:
        # Display label mapping: clean up common verbose names
        nm = _norm(m)
        if nm in norm_to_display:
            disp = norm_to_display[nm]
        else:
            # generic cleanup: underscores to spaces, drop parenthetical qualifiers
            disp = str(m).replace('_', ' ')
            disp = disp.replace('(Sensitivity/TPR)', '')
            disp = disp.replace('(TNR)', '')
            disp = disp.strip()

        cells = [
            f'<td>{_escape(disp)}</td>',
            f'<td>{fmt(train_scores.get(m))}</td>',
            f'<td>{fmt(val_scores.get(m))}</td>',
        ]
        if include_test and test_scores:
            cells.append(f'<td>{fmt(test_scores.get(m))}</td>')

        rows_html.append('<tr>' + ''.join(cells) + '</tr>')

    title_html = f'<h3 style="margin-top:0">{title}</h3>' if (show_title and title) else ''

    table_html = f"""
    {title_html}
    <table class="performance-summary">
      <thead><tr>{''.join(header_cells)}</tr></thead>
      <tbody>{''.join(rows_html)}</tbody>
    </table>
    """
    return table_html
```
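Finally, a sketch of the table builder on hand-rolled score dicts, showing how preferred metrics are pulled to the end in canonical order:

```python
train = {"accuracy": 0.91, "roc_auc": 0.95, "logloss": 0.31, "n_samples": 800}
val = {"accuracy": 0.88, "roc_auc": 0.93, "logloss": 0.35, "n_samples": 100}
test = {"accuracy": 0.87, "roc_auc": 0.92, "logloss": 0.36, "n_samples": 100}

table = build_model_performance_summary_table(train, val, test)
# Rows: "n samples" first (alphabetical "others"), then Accuracy, ROC-AUC,
# LogLoss in the preferred order; numeric values are formatted to 4 decimals.
```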
