Comparison: ludwig_backend.py @ 15:d17e3a1b8659 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 15:45:49 +0000 |
| parents | bcfa2e234a80 |
| children | |
| 14:94cd9ac4a9b1 | 15:d17e3a1b8659 |
|---|---|
| 29 TEST_STATISTICS_FILE_NAME, | 29 TEST_STATISTICS_FILE_NAME, |
| 30 TRAIN_SET_METADATA_FILE_NAME, | 30 TRAIN_SET_METADATA_FILE_NAME, |
| 31 ) | 31 ) |
| 32 from ludwig.utils.data_utils import get_split_path | 32 from ludwig.utils.data_utils import get_split_path |
| 33 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS | 33 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS |
| 34 from plotly_plots import build_classification_plots | 34 from plotly_plots import ( |
| 35 build_classification_plots, | |
| 36 build_prediction_diagnostics, | |
| 37 build_regression_test_plots, | |
| 38 build_regression_train_val_plots, | |
| 39 build_train_validation_plots, | |
| 40 ) | |
| 35 from utils import detect_output_type, extract_metrics_from_json | 41 from utils import detect_output_type, extract_metrics_from_json |
| 36 | 42 |
| 37 logger = logging.getLogger("ImageLearner") | 43 logger = logging.getLogger("ImageLearner") |
| 38 | 44 |
| 39 | 45 |
| 69 ... | 75 ... |
| 70 | 76 |
| 71 | 77 |
| 72 class LudwigDirectBackend: | 78 class LudwigDirectBackend: |
| 73 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" | 79 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
| 80 | |
| 81 _torchvision_patched = False | |
| 74 | 82 |
| 75 def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]: | 83 def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]: |
| 76 """Detect image dimensions from the first image in the dataset.""" | 84 """Detect image dimensions from the first image in the dataset.""" |
| 77 try: | 85 try: |
| 78 import zipfile | 86 import zipfile |
| 342 image_feat["preprocessing"]["infer_image_max_height"] = height | 350 image_feat["preprocessing"]["infer_image_max_height"] = height |
| 343 image_feat["preprocessing"]["infer_image_max_width"] = width | 351 image_feat["preprocessing"]["infer_image_max_width"] = width |
| 344 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") | 352 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") |
| 345 except (ValueError, IndexError): | 353 except (ValueError, IndexError): |
| 346 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") | 354 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") |
| 355 | |
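
Only the tail of the image-resize handling is visible in this hunk. Condensed and hedged (the `"224x224"` value and variable names are illustrative; the actual parsing sits in the elided lines above), the behaviour guarded by the `except (ValueError, IndexError)` clause looks roughly like this:

```python
# Hedged sketch: an image_resize string such as "224x224" becomes max
# height/width for Ludwig's inferred image dimensions; a malformed value
# raises the ValueError/IndexError caught in the hunk above.
resize = "224x224"                                  # e.g. config_params["image_resize"]
height, width = (int(v) for v in resize.split("x")[:2])
preprocessing = {
    "infer_image_dimensions": True,
    "infer_image_max_height": height,               # 224
    "infer_image_max_width": width,                 # 224
}
```
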
| 356 def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: | |
| 357 """Pick a validation metric that Ludwig will accept for the resolved task.""" | |
| 358 default_map = { | |
| 359 "regression": "pearson_r", | |
| 360 "binary": "roc_auc", | |
| 361 "category": "accuracy", | |
| 362 } | |
| 363 allowed_map = { | |
| 364 "regression": { | |
| 365 "pearson_r", | |
| 366 "mean_absolute_error", | |
| 367 "mean_squared_error", | |
| 368 "root_mean_squared_error", | |
| 369 "mean_absolute_percentage_error", | |
| 370 "r2", | |
| 371 "explained_variance", | |
| 372 "loss", | |
| 373 }, | |
| 374 # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set. | |
| 375 "binary": { | |
| 376 "roc_auc", | |
| 377 "accuracy", | |
| 378 "precision", | |
| 379 "recall", | |
| 380 "specificity", | |
| 381 "log_loss", | |
| 382 "loss", | |
| 383 }, | |
| 384 "category": { | |
| 385 "accuracy", | |
| 386 "balanced_accuracy", | |
| 387 "precision", | |
| 388 "recall", | |
| 389 "f1", | |
| 390 "specificity", | |
| 391 "log_loss", | |
| 392 "loss", | |
| 393 }, | |
| 394 } | |
| 395 alias_map = { | |
| 396 "regression": { | |
| 397 "mae": "mean_absolute_error", | |
| 398 "mse": "mean_squared_error", | |
| 399 "rmse": "root_mean_squared_error", | |
| 400 "mape": "mean_absolute_percentage_error", | |
| 401 }, | |
| 402 } | |
| 403 | |
| 404 default_metric = default_map.get(task) | |
| 405 allowed = allowed_map.get(task, set()) | |
| 406 metric = requested or default_metric | |
| 407 | |
| 408 if metric is None: | |
| 409 return None | |
| 410 | |
| 411 metric = alias_map.get(task, {}).get(metric, metric) | |
| 412 | |
| 413 if metric not in allowed: | |
| 414 if requested: | |
| 415 logger.warning( | |
| 416 f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead." | |
| 417 ) | |
| 418 metric = default_metric | |
| 419 return metric | |
| 420 | |
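
For readers skimming the diff: the new `_resolve_validation_metric` helper is just a default map, an allow-list, and an alias table. A condensed standalone sketch (allow-lists trimmed for brevity; the full sets are in the hunk above, and the warning log is omitted) behaves like this:

```python
from typing import Optional

DEFAULTS = {"regression": "pearson_r", "binary": "roc_auc", "category": "accuracy"}
ALLOWED = {
    "regression": {"pearson_r", "mean_absolute_error", "root_mean_squared_error", "r2", "loss"},
    "binary": {"roc_auc", "accuracy", "precision", "recall", "loss"},
    "category": {"accuracy", "balanced_accuracy", "f1", "loss"},
}
ALIASES = {"regression": {"mae": "mean_absolute_error", "rmse": "root_mean_squared_error"}}


def resolve(task: str, requested: Optional[str]) -> Optional[str]:
    """Condensed version of _resolve_validation_metric (warning log omitted)."""
    metric = requested or DEFAULTS.get(task)
    if metric is None:
        return None
    metric = ALIASES.get(task, {}).get(metric, metric)   # expand mae/mse/rmse/mape aliases
    return metric if metric in ALLOWED.get(task, set()) else DEFAULTS.get(task)


assert resolve("regression", "rmse") == "root_mean_squared_error"  # alias expanded
assert resolve("binary", "f1") == "roc_auc"                        # unsupported for binary -> task default
assert resolve("category", None) == "accuracy"                     # nothing requested -> default
```
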
| 347 if task_type == "regression": | 421 if task_type == "regression": |
| 348 output_feat = { | 422 output_feat = { |
| 349 "name": LABEL_COLUMN_NAME, | 423 "name": LABEL_COLUMN_NAME, |
| 350 "type": "number", | 424 "type": "number", |
| 351 "decoder": {"type": "regressor"}, | 425 "decoder": {"type": "regressor"}, |
| 352 "loss": {"type": "mean_squared_error"}, | 426 "loss": {"type": "mean_squared_error"}, |
| 353 } | 427 } |
| 354 val_metric = config_params.get("validation_metric", "mean_squared_error") | 428 val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric")) |
| 355 | 429 |
| 356 else: | 430 else: |
| 357 if num_unique_labels == 2: | 431 if num_unique_labels == 2: |
| 358 output_feat = { | 432 output_feat = { |
| 359 "name": LABEL_COLUMN_NAME, | 433 "name": LABEL_COLUMN_NAME, |
| 366 output_feat = { | 440 output_feat = { |
| 367 "name": LABEL_COLUMN_NAME, | 441 "name": LABEL_COLUMN_NAME, |
| 368 "type": "category", | 442 "type": "category", |
| 369 "loss": {"type": "softmax_cross_entropy"}, | 443 "loss": {"type": "softmax_cross_entropy"}, |
| 370 } | 444 } |
| 371 val_metric = None | 445 val_metric = _resolve_validation_metric( |
| 446 "binary" if num_unique_labels == 2 else "category", | |
| 447 config_params.get("validation_metric"), | |
| 448 ) | |
| 372 | 449 |
| 373 conf: Dict[str, Any] = { | 450 conf: Dict[str, Any] = { |
| 374 "model_type": "ecd", | 451 "model_type": "ecd", |
| 375 "input_features": [image_feat], | 452 "input_features": [image_feat], |
| 376 "output_features": [output_feat], | 453 "output_features": [output_feat], |
| 378 "trainer": { | 455 "trainer": { |
| 379 "epochs": epochs, | 456 "epochs": epochs, |
| 380 "early_stop": early_stop, | 457 "early_stop": early_stop, |
| 381 "batch_size": batch_size_cfg, | 458 "batch_size": batch_size_cfg, |
| 382 "learning_rate": learning_rate, | 459 "learning_rate": learning_rate, |
| 383 # only set validation_metric for regression | 460 # set validation_metric when provided |
| 384 **({"validation_metric": val_metric} if val_metric else {}), | 461 **({"validation_metric": val_metric} if val_metric else {}), |
| 385 }, | 462 }, |
| 386 "preprocessing": { | 463 "preprocessing": { |
| 387 "split": split_config, | 464 "split": split_config, |
| 388 "num_processes": num_processes, | 465 "num_processes": num_processes, |
| 400 "LudwigDirectBackend: Failed to serialize YAML.", | 477 "LudwigDirectBackend: Failed to serialize YAML.", |
| 401 exc_info=True, | 478 exc_info=True, |
| 402 ) | 479 ) |
| 403 raise | 480 raise |
| 404 | 481 |
| 482 def _patch_torchvision_download(self) -> None: | |
| 483 """ | |
| 484 Torchvision weight downloads sometimes fail checksum validation behind | |
| 485 corporate proxies that rewrite binaries. Skip hash checking to allow | |
| 486 pre-trained weights to load in those environments. | |
| 487 """ | |
| 488 if LudwigDirectBackend._torchvision_patched: | |
| 489 return | |
| 490 try: | |
| 491 import torch.hub as torch_hub | |
| 492 | |
| 493 original = torch_hub.load_state_dict_from_url | |
| 494 original_download = torch_hub.download_url_to_file | |
| 495 | |
| 496 def _no_hash(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): | |
| 497 return original( | |
| 498 url, | |
| 499 model_dir=model_dir, | |
| 500 map_location=map_location, | |
| 501 progress=progress, | |
| 502 check_hash=False, | |
| 503 file_name=file_name, | |
| 504 ) | |
| 505 | |
| 506 def _download_no_hash(url, dst, hash_prefix=None, progress=True): | |
| 507 # Torchvision's download_url_to_file signature does not accept check_hash in older versions. | |
| 508 return original_download(url, dst, hash_prefix=None, progress=progress) | |
| 509 | |
| 510 torch_hub.load_state_dict_from_url = _no_hash # type: ignore[assignment] | |
| 511 torch_hub.download_url_to_file = _download_no_hash # type: ignore[assignment] | |
| 512 LudwigDirectBackend._torchvision_patched = True | |
| 513 logger.info("Disabled torchvision weight hash verification to avoid proxy-corrupted downloads.") | |
| 514 except Exception as exc: | |
| 515 logger.warning(f"Could not patch torchvision download hash check: {exc}") | |
| 516 | |
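
Stripped of logging and of the `download_url_to_file` wrapper, the once-only patch above reduces to the guard-and-wrap sketch below. Class and function names here are illustrative; `check_hash` is an existing keyword of `torch.hub.load_state_dict_from_url`:

```python
import torch.hub as torch_hub


class HashPatchSketch:
    """Illustrative once-only monkey patch; not the class used in the diff."""

    _patched = False

    @classmethod
    def apply(cls) -> None:
        if cls._patched:                      # guard: wrap at most once per process
            return
        original = torch_hub.load_state_dict_from_url

        def no_hash(url, *args, check_hash=False, **kwargs):
            # force check_hash=False regardless of what the caller requested
            return original(url, *args, check_hash=False, **kwargs)

        torch_hub.load_state_dict_from_url = no_hash
        cls._patched = True
```
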
| 405 def run_experiment( | 517 def run_experiment( |
| 406 self, | 518 self, |
| 407 dataset_path: Path, | 519 dataset_path: Path, |
| 408 config_path: Path, | 520 config_path: Path, |
| 409 output_dir: Path, | 521 output_dir: Path, |
| 410 random_seed: int = 42, | 522 random_seed: int = 42, |
| 411 ) -> None: | 523 ) -> None: |
| 412 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" | 524 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" |
| 413 logger.info("LudwigDirectBackend: Starting experiment execution.") | 525 logger.info("LudwigDirectBackend: Starting experiment execution.") |
| 526 | |
| 527 # Avoid strict hash validation for torchvision weights (common in proxied environments) | |
| 528 self._patch_torchvision_download() | |
| 414 | 529 |
| 415 try: | 530 try: |
| 416 from ludwig.experiment import experiment_cli | 531 from ludwig.experiment import experiment_cli |
| 417 except ImportError as e: | 532 except ImportError as e: |
| 418 logger.error( | 533 logger.error( |
| 504 | 619 |
| 505 def generate_plots(self, output_dir: Path) -> None: | 620 def generate_plots(self, output_dir: Path) -> None: |
| 506 """Generate all registered Ludwig visualizations for the latest experiment run.""" | 621 """Generate all registered Ludwig visualizations for the latest experiment run.""" |
| 507 logger.info("Generating all Ludwig visualizations…") | 622 logger.info("Generating all Ludwig visualizations…") |
| 508 | 623 |
| 624 # Keep only lightweight plots (drop compare_performance/roc_curves) | |
| 509 test_plots = { | 625 test_plots = { |
| 510 "compare_performance", | |
| 511 "compare_classifiers_performance_from_prob", | |
| 512 "compare_classifiers_performance_from_pred", | |
| 513 "compare_classifiers_performance_changing_k", | |
| 514 "compare_classifiers_multiclass_multimetric", | |
| 515 "compare_classifiers_predictions", | |
| 516 "confidence_thresholding_2thresholds_2d", | |
| 517 "confidence_thresholding_2thresholds_3d", | |
| 518 "confidence_thresholding", | |
| 519 "confidence_thresholding_data_vs_acc", | |
| 520 "binary_threshold_vs_metric", | |
| 521 "roc_curves", | |
| 522 "roc_curves_from_test_statistics", | 626 "roc_curves_from_test_statistics", |
| 523 "calibration_1_vs_all", | |
| 524 "calibration_multiclass", | |
| 525 "confusion_matrix", | 627 "confusion_matrix", |
| 526 "frequency_vs_f1", | |
| 527 } | 628 } |
| 528 train_plots = { | 629 train_plots = { |
| 529 "learning_curves", | 630 "learning_curves", |
| 530 "compare_classifiers_performance_subset", | 631 "compare_classifiers_performance_subset", |
| 531 } | 632 } |
| 625 key=lambda p: p.stat().st_mtime, | 726 key=lambda p: p.stat().st_mtime, |
| 626 ) | 727 ) |
| 627 if not exp_dirs: | 728 if not exp_dirs: |
| 628 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") | 729 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
| 629 exp_dir = exp_dirs[-1] | 730 exp_dir = exp_dirs[-1] |
| 731 train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME | |
| 732 label_metadata_path = config.get("label_column_data_path") | |
| 733 if label_metadata_path: | |
| 734 label_metadata_path = Path(label_metadata_path) | |
| 735 | |
| 736 # Pull additional config details from description.json if available | |
| 737 config_for_summary = dict(config) | |
| 738 if "target_column" not in config_for_summary or not config_for_summary.get("target_column"): | |
| 739 config_for_summary["target_column"] = LABEL_COLUMN_NAME | |
| 740 desc_path = exp_dir / DESCRIPTION_FILE_NAME | |
| 741 if desc_path.exists(): | |
| 742 try: | |
| 743 with open(desc_path, "r") as f: | |
| 744 desc_cfg = json.load(f).get("config", {}) | |
| 745 encoder_cfg = ( | |
| 746 desc_cfg.get("input_features", [{}])[0].get("encoder", {}) | |
| 747 if isinstance(desc_cfg.get("input_features", [{}]), list) | |
| 748 else {} | |
| 749 ) | |
| 750 output_cfg = ( | |
| 751 desc_cfg.get("output_features", [{}])[0] | |
| 752 if isinstance(desc_cfg.get("output_features", [{}]), list) | |
| 753 else {} | |
| 754 ) | |
| 755 trainer_cfg = desc_cfg.get("trainer", {}) if isinstance(desc_cfg, dict) else {} | |
| 756 loss_cfg = output_cfg.get("loss", {}) if isinstance(output_cfg, dict) else {} | |
| 757 opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {} | |
| 758 clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {} | |
| 759 | |
| 760 arch_type = encoder_cfg.get("type") | |
| 761 arch_variant = encoder_cfg.get("model_variant") | |
| 762 arch_name = None | |
| 763 if arch_type: | |
| 764 arch_base = str(arch_type).replace("_", " ").title() | |
| 765 arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base | |
| 766 | |
| 767 summary_fields = { | |
| 768 "architecture": arch_name, | |
| 769 "model_variant": arch_variant, | |
| 770 "pretrained": encoder_cfg.get("use_pretrained"), | |
| 771 "trainable": encoder_cfg.get("trainable"), | |
| 772 "target_column": output_cfg.get("column"), | |
| 773 "task_type": output_cfg.get("type"), | |
| 774 "validation_metric": trainer_cfg.get("validation_metric"), | |
| 775 "loss_function": loss_cfg.get("type"), | |
| 776 "threshold": output_cfg.get("threshold"), | |
| 777 "total_epochs": trainer_cfg.get("epochs"), | |
| 778 "early_stop": trainer_cfg.get("early_stop"), | |
| 779 "batch_size": trainer_cfg.get("batch_size"), | |
| 780 "optimizer": opt_cfg.get("type"), | |
| 781 "learning_rate": trainer_cfg.get("learning_rate"), | |
| 782 "random_seed": desc_cfg.get("random_seed") or config.get("random_seed"), | |
| 783 "use_mixed_precision": trainer_cfg.get("use_mixed_precision"), | |
| 784 "gradient_clipping": clip_cfg.get("clipglobalnorm"), | |
| 785 } | |
| 786 for k, v in summary_fields.items(): | |
| 787 if v is None: | |
| 788 continue | |
| 789 # Do not override user-passed target/image column names in config | |
| 790 if k in {"target_column", "image_column"} and config_for_summary.get(k): | |
| 791 continue | |
| 792 config_for_summary.setdefault(k, v) | |
| 793 except Exception as e: # pragma: no cover - defensive | |
| 794 logger.warning(f"Could not merge description.json into config summary: {e}") | |
| 630 | 795 |
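
The merge above reads only a few key paths out of the `config` section of `description.json`. An illustrative fragment (values are made up; only the key paths matter) shows the shape it expects:

```python
# Illustrative desc_cfg = json.load(desc_path)["config"]; values are examples only.
desc_cfg = {
    "input_features": [
        {"encoder": {"type": "resnet", "model_variant": 18, "use_pretrained": True, "trainable": False}}
    ],
    "output_features": [
        {"column": "label", "type": "binary", "loss": {"type": "binary_weighted_cross_entropy"}}
    ],
    "trainer": {
        "epochs": 10, "early_stop": 5, "batch_size": 16, "learning_rate": 0.001,
        "validation_metric": "roc_auc",
        "optimizer": {"type": "adam"},
        "gradient_clipping": {"clipglobalnorm": 0.5},
    },
    "random_seed": 42,
}

encoder_cfg = desc_cfg["input_features"][0].get("encoder", {})
arch_base = str(encoder_cfg["type"]).replace("_", " ").title()      # "Resnet"
arch_name = f"{arch_base} {encoder_cfg['model_variant']}"           # "Resnet 18"
```
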
| 631 base_viz_dir = exp_dir / "visualizations" | 796 base_viz_dir = exp_dir / "visualizations" |
| 632 train_viz_dir = base_viz_dir / "train" | 797 train_viz_dir = base_viz_dir / "train" |
| 633 test_viz_dir = base_viz_dir / "test" | 798 test_viz_dir = base_viz_dir / "test" |
| 634 | 799 |
| 696 html += f"<h1>{title}</h1>" | 861 html += f"<h1>{title}</h1>" |
| 697 | 862 |
| 698 metrics_html = "" | 863 metrics_html = "" |
| 699 train_val_metrics_html = "" | 864 train_val_metrics_html = "" |
| 700 test_metrics_html = "" | 865 test_metrics_html = "" |
| 866 output_type = None | |
| 867 train_stats_path = exp_dir / "training_statistics.json" | |
| 868 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | |
| 701 try: | 869 try: |
| 702 train_stats_path = exp_dir / "training_statistics.json" | |
| 703 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME | |
| 704 if train_stats_path.exists() and test_stats_path.exists(): | 870 if train_stats_path.exists() and test_stats_path.exists(): |
| 705 with open(train_stats_path) as f: | 871 with open(train_stats_path) as f: |
| 706 train_stats = json.load(f) | 872 train_stats = json.load(f) |
| 707 with open(test_stats_path) as f: | 873 with open(test_stats_path) as f: |
| 708 test_stats = json.load(f) | 874 test_stats = json.load(f) |
| 723 | 889 |
| 724 config_html = "" | 890 config_html = "" |
| 725 training_progress = self.get_training_process(output_dir) | 891 training_progress = self.get_training_process(output_dir) |
| 726 try: | 892 try: |
| 727 config_html = format_config_table_html( | 893 config_html = format_config_table_html( |
| 728 config, split_info, training_progress, output_type | 894 config_for_summary, split_info, training_progress, output_type |
| 729 ) | 895 ) |
| 730 except Exception as e: | 896 except Exception as e: |
| 731 logger.warning(f"Could not load config for HTML report: {e}") | 897 logger.warning(f"Could not load config for HTML report: {e}") |
| 898 config_html = ( | |
| 899 "<h2 style='text-align: center;'>Model and Training Summary</h2>" | |
| 900 "<p style='text-align:center; color:#666;'>Configuration details unavailable.</p>" | |
| 901 ) | |
| 902 if not config_html: | |
| 903 config_html = ( | |
| 904 "<h2 style='text-align: center;'>Model and Training Summary</h2>" | |
| 905 "<p style='text-align:center; color:#666;'>No configuration details found.</p>" | |
| 906 ) | |
| 732 | 907 |
| 733 # ---------- image rendering with exclusions ---------- | 908 # ---------- image rendering with exclusions ---------- |
| 734 def render_img_section( | 909 def render_img_section( |
| 735 title: str, | 910 title: str, |
| 736 dir_path: Path, | 911 dir_path: Path, |
| 774 imgs = [ | 949 imgs = [ |
| 775 img | 950 img |
| 776 for img in imgs | 951 for img in imgs |
| 777 if img.name not in default_exclude | 952 if img.name not in default_exclude |
| 778 and img.name not in exclude_names | 953 and img.name not in exclude_names |
| 954 and not ( | |
| 955 "learning_curves" in img.stem | |
| 956 and "loss" in img.stem | |
| 957 and "label" in img.stem | |
| 958 ) | |
| 779 ] | 959 ] |
| 780 | 960 |
| 781 if not imgs: | 961 if not imgs: |
| 782 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 962 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" |
| 783 | 963 |
| 800 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 980 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' |
| 801 f"</div>" | 981 f"</div>" |
| 802 ) | 982 ) |
| 803 return html_section | 983 return html_section |
| 804 | 984 |
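
The stem filter added to `render_img_section` drops only the per-label loss learning-curve PNGs while keeping the other panels. The file names below are illustrative of Ludwig's usual learning-curve outputs:

```python
from pathlib import Path

# Dropped: stem contains "learning_curves", "loss" and "label".
assert all(t in Path("learning_curves_label_loss.png").stem
           for t in ("learning_curves", "loss", "label"))
# Kept: no "loss" token, so the accuracy curve still renders.
assert "loss" not in Path("learning_curves_label_accuracy.png").stem
```
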
| 805 tab1_content = config_html + metrics_html | 985 # Show performance first, then config |
| 986 tab1_content = metrics_html + config_html | |
| 806 | 987 |
| 807 tab2_content = train_val_metrics_html + render_img_section( | 988 tab2_content = train_val_metrics_html + render_img_section( |
| 808 "Training and Validation Visualizations", | 989 "Training and Validation Visualizations", |
| 809 train_viz_dir, | 990 train_viz_dir, |
| 810 output_type, | 991 output_type, |
| 813 "roc_curves_from_prediction_statistics.png", | 994 "roc_curves_from_prediction_statistics.png", |
| 814 "precision_recall_curves_from_prediction_statistics.png", | 995 "precision_recall_curves_from_prediction_statistics.png", |
| 815 "precision_recall_curve.png", | 996 "precision_recall_curve.png", |
| 816 }, | 997 }, |
| 817 ) | 998 ) |
| 999 if train_stats_path.exists(): | |
| 1000 try: | |
| 1001 if output_type == "regression": | |
| 1002 tv_plots = build_regression_train_val_plots(str(train_stats_path)) | |
| 1003 else: | |
| 1004 tv_plots = build_train_validation_plots(str(train_stats_path)) | |
| 1005 for plot in tv_plots: | |
| 1006 tab2_content += ( | |
| 1007 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1008 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1009 ) | |
| 1010 if tv_plots: | |
| 1011 logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots") | |
| 1012 except Exception as e: | |
| 1013 logger.warning(f"Could not generate train/val plots: {e}") | |
| 818 | 1014 |
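
The Plotly builders called here, and again in the test tab below, are all assumed to return a list of `{"title": ..., "html": ...}` dicts, and the same embedding snippet is repeated each time. A small helper (hypothetical, not part of this diff) captures the pattern:

```python
def embed_plots(plots):
    """Wrap [{"title": str, "html": str}, ...] into the centered sections used in this report."""
    return "".join(
        f"<h2 style='text-align: center;'>{p['title']}</h2>"
        f"<div class='plotly-center'>{p['html']}</div>"
        for p in plots
    )
```
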
| 819 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- | 1015 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- |
| 820 preds_section = "" | 1016 preds_section = "" |
| 821 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | 1017 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
| 822 if output_type == "regression" and parquet_path.exists(): | 1018 if output_type == "regression" and parquet_path.exists(): |
| 847 preds_section = ( | 1043 preds_section = ( |
| 848 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" | 1044 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" |
| 849 "<div class='preds-controls'>" | 1045 "<div class='preds-controls'>" |
| 850 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" | 1046 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" |
| 851 "</div>" | 1047 "</div>" |
| 852 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>" | 1048 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:350px; margin-bottom:20px;'>" |
| 853 + preds_html | 1049 + preds_html |
| 854 + "</div>" | 1050 + "</div>" |
| 855 ) | 1051 ) |
| 856 except Exception as e: | 1052 except Exception as e: |
| 857 logger.warning(f"Could not build Predictions vs GT table: {e}") | 1053 logger.warning(f"Could not build Predictions vs GT table: {e}") |
| 858 | 1054 |
| 859 tab3_content = test_metrics_html + preds_section | 1055 tab3_content = test_metrics_html + preds_section |
| 1056 test_plotly_added = False | |
| 1057 | |
| 1058 if output_type == "regression" and train_stats_path.exists(): | |
| 1059 try: | |
| 1060 test_plots = build_regression_test_plots(str(train_stats_path)) | |
| 1061 for plot in test_plots: | |
| 1062 tab3_content += ( | |
| 1063 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1064 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1065 ) | |
| 1066 if test_plots: | |
| 1067 test_plotly_added = True | |
| 1068 logger.info(f"Generated {len(test_plots)} regression test plots") | |
| 1069 except Exception as e: | |
| 1070 logger.warning(f"Could not generate regression test plots: {e}") | |
| 860 | 1071 |
| 861 if output_type in ("binary", "category") and test_stats_path.exists(): | 1072 if output_type in ("binary", "category") and test_stats_path.exists(): |
| 862 try: | 1073 try: |
| 863 interactive_plots = build_classification_plots( | 1074 interactive_plots = build_classification_plots( |
| 864 str(test_stats_path), | 1075 str(test_stats_path), |
| 865 str(train_stats_path) if train_stats_path.exists() else None, | 1076 str(train_stats_path) if train_stats_path.exists() else None, |
| 1077 metadata_csv_path=str(label_metadata_path) | |
| 1078 if label_metadata_path and label_metadata_path.exists() | |
| 1079 else None, | |
| 1080 train_set_metadata_path=str(train_set_metadata_path) | |
| 1081 if train_set_metadata_path.exists() | |
| 1082 else None, | |
| 866 ) | 1083 ) |
| 867 for plot in interactive_plots: | 1084 for plot in interactive_plots: |
| 868 tab3_content += ( | 1085 tab3_content += ( |
| 869 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1086 f"<h2 style='text-align: center;'>{plot['title']}</h2>" |
| 870 f"<div class='plotly-center'>{plot['html']}</div>" | 1087 f"<div class='plotly-center'>{plot['html']}</div>" |
| 871 ) | 1088 ) |
| 1089 if interactive_plots: | |
| 1090 test_plotly_added = True | |
| 872 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") | 1091 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") |
| 873 except Exception as e: | 1092 except Exception as e: |
| 874 logger.warning(f"Could not generate Plotly plots: {e}") | 1093 logger.warning(f"Could not generate Plotly plots: {e}") |
| 875 | 1094 |
| 1095 # Add prediction diagnostics from predictions.csv | |
| 1096 predictions_csv_path = exp_dir / "predictions.csv" | |
| 1097 try: | |
| 1098 diag_plots = build_prediction_diagnostics( | |
| 1099 str(predictions_csv_path), | |
| 1100 label_data_path=str(config.get("label_column_data_path")) | |
| 1101 if config.get("label_column_data_path") | |
| 1102 else None, | |
| 1103 threshold=config.get("threshold"), | |
| 1104 ) | |
| 1105 for plot in diag_plots: | |
| 1106 tab3_content += ( | |
| 1107 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1108 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1109 ) | |
| 1110 if diag_plots: | |
| 1111 test_plotly_added = True | |
| 1112 logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots") | |
| 1113 except Exception as e: | |
| 1114 logger.warning(f"Could not generate prediction diagnostics: {e}") | |
| 1115 | |
| 1116 # Fallback: include static PNGs if no interactive plots were added | |
| 1117 if not test_plotly_added: | |
| 1118 tab3_content += render_img_section( | |
| 1119 "Test Visualizations (PNG fallback)", | |
| 1120 test_viz_dir, | |
| 1121 output_type, | |
| 1122 ) | |
| 1123 | |
| 876 # Add static TEST PNGs (with default dedupe/exclusions) | 1124 # Add static TEST PNGs (with default dedupe/exclusions) |
| 877 tab3_content += render_img_section( | |
| 878 "Test Visualizations", test_viz_dir, output_type | |
| 879 ) | |
| 880 | |
| 881 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1125 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
| 882 modal_html = get_metrics_help_modal() | 1126 modal_html = get_metrics_help_modal() |
| 883 html += tabbed_html + modal_html + get_html_closing() | 1127 html += tabbed_html + modal_html + get_html_closing() |
| 884 | 1128 |
| 885 try: | 1129 try: |
