comparison: ludwig_backend.py @ 17:db9be962dc13 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| author | goeckslab |
|---|---|
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | 8729f69e9207 |
| children | |
| 16:8729f69e9207 | 17:db9be962dc13 |
|---|---|
| 1 import inspect | |
| 1 import json | 2 import json |
| 2 import logging | 3 import logging |
| 3 import os | 4 import os |
| 4 from pathlib import Path | 5 from pathlib import Path |
| 5 from typing import Any, Dict, Optional, Protocol, Tuple | 6 from typing import Any, Dict, List, Optional, Protocol, Tuple |
| 6 | 7 |
| 7 import pandas as pd | 8 import pandas as pd |
| 8 import pandas.api.types as ptypes | 9 import pandas.api.types as ptypes |
| 9 import yaml | 10 import yaml |
| 10 from constants import ( | 11 from constants import ( |
| 15 ) | 16 ) |
| 16 from html_structure import ( | 17 from html_structure import ( |
| 17 build_tabbed_html, | 18 build_tabbed_html, |
| 18 encode_image_to_base64, | 19 encode_image_to_base64, |
| 19 format_config_table_html, | 20 format_config_table_html, |
| 21 format_dataset_overview_table, | |
| 20 format_stats_table_html, | 22 format_stats_table_html, |
| 21 format_test_merged_stats_table_html, | 23 format_test_merged_stats_table_html, |
| 22 format_train_val_stats_table_html, | 24 format_train_val_stats_table_html, |
| 23 get_html_closing, | 25 get_html_closing, |
| 24 get_html_template, | 26 get_html_template, |
| 31 TRAIN_SET_METADATA_FILE_NAME, | 33 TRAIN_SET_METADATA_FILE_NAME, |
| 32 ) | 34 ) |
| 33 from ludwig.utils.data_utils import get_split_path | 35 from ludwig.utils.data_utils import get_split_path |
| 34 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS | 36 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS |
| 35 from plotly_plots import ( | 37 from plotly_plots import ( |
| 38 build_binary_threshold_plot, | |
| 36 build_classification_plots, | 39 build_classification_plots, |
| 40 build_multiclass_metric_plots, | |
| 37 build_prediction_diagnostics, | 41 build_prediction_diagnostics, |
| 38 build_regression_test_plots, | 42 build_regression_test_plots, |
| 39 build_regression_train_val_plots, | 43 build_regression_train_val_plots, |
| 40 build_train_validation_plots, | 44 build_train_validation_plots, |
| 41 ) | 45 ) |
| 265 "trainable": trainable, | 269 "trainable": trainable, |
| 266 } | 270 } |
| 267 else: | 271 else: |
| 268 encoder_config = {"type": raw_encoder} | 272 encoder_config = {"type": raw_encoder} |
| 269 | 273 |
| 274 # Set a human-friendly architecture string for reporting | |
| 275 arch_display = None | |
| 276 if is_metaformer and custom_model: | |
| 277 arch_display = str(custom_model) | |
| 278 elif isinstance(raw_encoder, dict): | |
| 279 enc_type = raw_encoder.get("type") | |
| 280 enc_variant = raw_encoder.get("model_variant") | |
| 281 if enc_type: | |
| 282 base = str(enc_type).replace("_", " ").title() | |
| 283 arch_display = f"{base} {enc_variant}" if enc_variant is not None else base | |
| 284 else: | |
| 285 arch_display = str(raw_encoder).replace("_", " ").title() | |
| 286 | |
| 287 if not arch_display: | |
| 288 arch_display = str(model_name) | |
| 289 config_params["architecture"] = arch_display | |
| 290 | |
| 270 batch_size_cfg = batch_size or "auto" | 291 batch_size_cfg = batch_size or "auto" |
| 271 | 292 |
| 272 label_column_path = config_params.get("label_column_data_path") | 293 label_column_path = config_params.get("label_column_data_path") |
| 273 label_series = None | 294 label_series = None |
| 274 label_metadata_hint = config_params.get("label_metadata") or {} | 295 label_metadata_hint = config_params.get("label_metadata") or {} |
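The architecture-reporting block added above boils down to a small, self-contained derivation. Below is a minimal sketch of the same logic as a standalone helper; the `humanize_architecture` name and the sample inputs are illustrative, not part of the tool.

```python
from typing import Any, Optional


def humanize_architecture(
    raw_encoder: Any,
    custom_model: Optional[str] = None,
    is_metaformer: bool = False,
    model_name: str = "model",
) -> str:
    """Mirror the arch_display derivation: prefer the custom model name,
    then a title-cased encoder type plus variant, then the model name."""
    if is_metaformer and custom_model:
        return str(custom_model)
    if isinstance(raw_encoder, dict):
        enc_type = raw_encoder.get("type")
        enc_variant = raw_encoder.get("model_variant")
        if enc_type:
            base = str(enc_type).replace("_", " ").title()
            return f"{base} {enc_variant}" if enc_variant is not None else base
    elif raw_encoder:
        return str(raw_encoder).replace("_", " ").title()
    return str(model_name)


# Illustrative inputs only:
print(humanize_architecture({"type": "resnet", "model_variant": 50}))               # "Resnet 50"
print(humanize_architecture("stacked_cnn"))                                         # "Stacked Cnn"
print(humanize_architecture({}, custom_model="caformer_s18", is_metaformer=True))   # "caformer_s18"
```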
| 341 image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality | 362 image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality |
| 342 image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization | 363 image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization |
| 343 # Force Ludwig to respect our dimensions by setting additional parameters | 364 # Force Ludwig to respect our dimensions by setting additional parameters |
| 344 image_feat["preprocessing"]["requires_equal_dimensions"] = False | 365 image_feat["preprocessing"]["requires_equal_dimensions"] = False |
| 345 logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") | 366 logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") |
| 367 config_params["image_size"] = f"{height}x{width}" | |
| 346 # Now set the encoder configuration | 368 # Now set the encoder configuration |
| 347 image_feat["encoder"] = encoder_config | 369 image_feat["encoder"] = encoder_config |
| 348 | 370 |
| 349 if config_params.get("augmentation") is not None: | 371 if config_params.get("augmentation") is not None: |
| 350 image_feat["augmentation"] = config_params["augmentation"] | 372 image_feat["augmentation"] = config_params["augmentation"] |
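For orientation, here is a hedged sketch of the image input feature that the MetaFormer branch above configures, assembled only from the preprocessing keys visible in this diff and assuming an illustrative 224x224 target; the real builder sets additional fields not shown here.

```python
# Illustrative only: keys taken from the MetaFormer preprocessing branch above,
# assuming a 224x224 target size.
height, width = 224, 224

image_feat = {
    "name": "image_path",                    # assumed column name
    "type": "image",
    "preprocessing": {
        "resize_method": "interpolate",      # interpolation for better quality
        "standardize_image": "imagenet1k",   # ImageNet normalization
        "requires_equal_dimensions": False,  # do not force H == W
        "infer_image_dimensions": True,      # validate against the max dimensions below
        "infer_image_max_height": height,
        "infer_image_max_width": width,
    },
}
# The encoder_config built earlier is attached afterwards:
image_feat["encoder"] = {"type": "..."}      # placeholder, not the actual MetaFormer config

# Reporting string, mirroring config_params["image_size"] = f"{height}x{width}":
print(f"{height}x{width}")  # "224x224"
```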
| 372 # but set explicit max dimensions to control the output size | 394 # but set explicit max dimensions to control the output size |
| 373 image_feat["preprocessing"]["infer_image_dimensions"] = True | 395 image_feat["preprocessing"]["infer_image_dimensions"] = True |
| 374 image_feat["preprocessing"]["infer_image_max_height"] = height | 396 image_feat["preprocessing"]["infer_image_max_height"] = height |
| 375 image_feat["preprocessing"]["infer_image_max_width"] = width | 397 image_feat["preprocessing"]["infer_image_max_width"] = width |
| 376 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") | 398 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") |
| 399 config_params["image_size"] = f"{height}x{width}" | |
| 377 except (ValueError, IndexError): | 400 except (ValueError, IndexError): |
| 378 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") | 401 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") |
| 402 elif not is_metaformer: | |
| 403 # No explicit resize provided; keep for reporting purposes | |
| 404 config_params.setdefault("image_size", "original") | |
| 379 | 405 |
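The resize handling expects `image_resize` as a `HEIGHTxWIDTH` string and records `image_size` for reporting, falling back to `original` when no resize is requested. A minimal standalone sketch of that flow follows; the `apply_resize` helper name is hypothetical and the exact parsing details are not shown in this hunk.

```python
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


def apply_resize(config_params: Dict[str, Any], image_feat: Dict[str, Any]) -> None:
    """Parse 'HxW' from config_params['image_resize']; otherwise report 'original'."""
    resize = config_params.get("image_resize")
    if resize:
        try:
            height, width = (int(part) for part in str(resize).lower().split("x"))
            prep = image_feat.setdefault("preprocessing", {})
            prep["infer_image_dimensions"] = True
            prep["infer_image_max_height"] = height
            prep["infer_image_max_width"] = width
            config_params["image_size"] = f"{height}x{width}"
        except (ValueError, IndexError):
            logger.warning("Invalid image resize format: %s, skipping resize preprocessing", resize)
    else:
        # No explicit resize provided; keep a value for reporting purposes only.
        config_params.setdefault("image_size", "original")


params = {"image_resize": "224x224"}
feat: Dict[str, Any] = {"type": "image"}
apply_resize(params, feat)
print(params["image_size"])  # "224x224"
```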
| 380 def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: | 406 def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: |
| 381 """Pick a validation metric that Ludwig will accept for the resolved task.""" | 407 """Pick a validation metric that Ludwig will accept for the resolved task.""" |
| 382 default_map = { | 408 default_map = { |
| 383 "regression": "pearson_r", | 409 "regression": "pearson_r", |
| 469 val_metric = _resolve_validation_metric( | 495 val_metric = _resolve_validation_metric( |
| 470 "binary" if num_unique_labels == 2 else "category", | 496 "binary" if num_unique_labels == 2 else "category", |
| 471 config_params.get("validation_metric"), | 497 config_params.get("validation_metric"), |
| 472 ) | 498 ) |
| 473 | 499 |
| 500 # Propagate the resolved validation metric (including any task-based fallback or alias normalization) | |
| 501 config_params["validation_metric"] = val_metric | |
| 502 | |
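`_resolve_validation_metric` itself is only partially visible here (the hunk shows just the start of its default map), so the following is an illustrative reconstruction under stated assumptions: only the `regression -> pearson_r` default appears in this diff, and the other defaults and accepted-metric sets below are guesses, not the tool's actual tables.

```python
from typing import Optional

# Assumed tables for the sketch (see the caveat above).
_DEFAULTS = {
    "regression": "pearson_r",
    "binary": "roc_auc",
    "category": "accuracy",
}
_ACCEPTED = {
    "regression": {"pearson_r", "mean_absolute_error", "root_mean_squared_error", "loss"},
    "binary": {"roc_auc", "accuracy", "loss"},
    "category": {"accuracy", "loss"},
}


def resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
    """Return the requested metric if the task supports it, otherwise the task default."""
    if requested and requested in _ACCEPTED.get(task, set()):
        return requested
    return _DEFAULTS.get(task)


print(resolve_validation_metric("binary", None))         # "roc_auc" (assumed default)
print(resolve_validation_metric("category", "roc_auc"))  # falls back to "accuracy"
```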
| 474 conf: Dict[str, Any] = { | 503 conf: Dict[str, Any] = { |
| 475 "model_type": "ecd", | 504 "model_type": "ecd", |
| 476 "input_features": [image_feat], | 505 "input_features": [image_feat], |
| 477 "output_features": [output_feat], | 506 "output_features": [output_feat], |
| 478 "combiner": {"type": "concat"}, | 507 "combiner": {"type": "concat"}, |
| 639 df.to_csv(csv_path, index=False) | 668 df.to_csv(csv_path, index=False) |
| 640 logger.info(f"Converted Parquet to CSV: {csv_path}") | 669 logger.info(f"Converted Parquet to CSV: {csv_path}") |
| 641 except Exception as e: | 670 except Exception as e: |
| 642 logger.error(f"Error converting Parquet to CSV: {e}") | 671 logger.error(f"Error converting Parquet to CSV: {e}") |
| 643 | 672 |
| 673 @staticmethod | |
| 674 def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]: | |
| 675 """Pull the first numeric metric list we can find for the requested split.""" | |
| 676 if not isinstance(stats, dict): | |
| 677 return None, None | |
| 678 | |
| 679 split_stats = stats.get(split, {}) | |
| 680 ordered_metrics: List[Tuple[str, List[float]]] = [] | |
| 681 | |
| 682 def _append_metrics(metric_map: Dict[str, Any]) -> None: | |
| 683 for metric_name, values in metric_map.items(): | |
| 684 if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values): | |
| 685 ordered_metrics.append((metric_name, values)) | |
| 686 | |
| 687 if isinstance(split_stats, dict): | |
| 688 combined = split_stats.get("combined") | |
| 689 if isinstance(combined, dict): | |
| 690 _append_metrics(combined) | |
| 691 | |
| 692 for feature_name, feature_metrics in split_stats.items(): | |
| 693 if feature_name == "combined" or not isinstance(feature_metrics, dict): | |
| 694 continue | |
| 695 _append_metrics(feature_metrics) | |
| 696 | |
| 697 if prefer: | |
| 698 for metric_name, values in ordered_metrics: | |
| 699 if metric_name == prefer: | |
| 700 return metric_name, values | |
| 701 | |
| 702 return ordered_metrics[0] if ordered_metrics else (None, None) | |
| 703 | |
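To make `_extract_metric_series` concrete, here is a toy stats dict (the shape mimics Ludwig's training statistics; the numbers are invented) with the results the method would return shown as comments.

```python
# Invented example data; the shape mimics Ludwig's training_statistics.json.
stats = {
    "validation": {
        "combined": {"loss": [0.92, 0.71, 0.63]},
        "label": {"accuracy": [0.55, 0.68, 0.74], "roc_auc": [0.61, 0.72, 0.80]},
    }
}

# With no preference, the first numeric series found wins ("combined" is scanned first):
#   _extract_metric_series(stats, "validation")                   -> ("loss", [0.92, 0.71, 0.63])
# A preferred metric is returned when present anywhere in the split:
#   _extract_metric_series(stats, "validation", prefer="roc_auc") -> ("roc_auc", [0.61, 0.72, 0.80])
# A missing split falls through to (None, None):
#   _extract_metric_series(stats, "test")                         -> (None, None)
```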
| 644 def generate_plots(self, output_dir: Path) -> None: | 704 def generate_plots(self, output_dir: Path) -> None: |
| 645 """Generate all registered Ludwig visualizations for the latest experiment run.""" | 705 """Generate Ludwig visualizations (train/val + test) for the latest experiment run.""" |
| 646 logger.info("Generating all Ludwig visualizations…") | 706 logger.info("Generating Ludwig visualizations (train/val + test)…") |
| 647 | 707 |
| 648 # Keep only lightweight plots (drop compare_performance/roc_curves) | 708 # Train/validation visualizations |
| 649 test_plots = { | |
| 650 "roc_curves_from_test_statistics", | |
| 651 "confusion_matrix", | |
| 652 } | |
| 653 train_plots = { | 709 train_plots = { |
| 654 "learning_curves", | 710 "learning_curves", |
| 655 "compare_classifiers_performance_subset", | 711 } |
| 712 | |
| 713 # Test visualizations (multi-class transparency) | |
| 714 test_plots = { | |
| 715 "confusion_matrix", | |
| 716 "compare_performance", | |
| 717 "compare_classifiers_multiclass_multimetric", | |
| 718 "frequency_vs_f1", | |
| 719 "confidence_thresholding", | |
| 720 "confidence_thresholding_data_vs_acc", | |
| 721 "confidence_thresholding_data_vs_acc_subset", | |
| 722 "confidence_thresholding_data_vs_acc_subset_per_class", | |
| 723 # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere | |
| 724 "binary_threshold_vs_metric", | |
| 725 "roc_curves", | |
| 726 "precision_recall_curves", | |
| 727 "calibration_1_vs_all", | |
| 728 "calibration_multiclass", | |
| 656 } | 729 } |
| 657 | 730 |
| 658 output_dir = Path(output_dir) | 731 output_dir = Path(output_dir) |
| 659 exp_dirs = sorted( | 732 exp_dirs = sorted( |
| 660 output_dir.glob("experiment_run*"), | 733 output_dir.glob("experiment_run*"), |
| 675 def _check(p: Path) -> Optional[str]: | 748 def _check(p: Path) -> Optional[str]: |
| 676 return str(p) if p.exists() else None | 749 return str(p) if p.exists() else None |
| 677 | 750 |
| 678 training_stats = _check(exp_dir / "training_statistics.json") | 751 training_stats = _check(exp_dir / "training_statistics.json") |
| 679 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | 752 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) |
| 680 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | |
| 681 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | 753 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) |
| 682 | 754 |
| 683 dataset_path = None | 755 dataset_path = None |
| 684 split_file = None | 756 split_file = None |
| 685 desc = exp_dir / DESCRIPTION_FILE_NAME | 757 desc = exp_dir / DESCRIPTION_FILE_NAME |
| 686 if desc.exists(): | 758 if desc.exists(): |
| 687 with open(desc, "r") as f: | 759 with open(desc, "r") as f: |
| 688 cfg = json.load(f) | 760 cfg = json.load(f) |
| 689 dataset_path = _check(Path(cfg.get("dataset", ""))) | 761 dataset_path = _check(Path(cfg.get("dataset", ""))) |
| 690 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | 762 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) |
| 763 model_name = cfg.get("model_name", "model") | |
| 764 else: | |
| 765 model_name = "model" | |
| 691 | 766 |
| 692 output_feature = "" | 767 output_feature = "" |
| 693 if desc.exists(): | 768 if desc.exists(): |
| 694 try: | 769 try: |
| 695 output_feature = cfg["config"]["output_features"][0]["name"] | 770 output_feature = cfg["config"]["output_features"][0]["name"] |
| 698 if not output_feature and test_stats: | 773 if not output_feature and test_stats: |
| 699 with open(test_stats, "r") as f: | 774 with open(test_stats, "r") as f: |
| 700 stats = json.load(f) | 775 stats = json.load(f) |
| 701 output_feature = next(iter(stats.keys()), "") | 776 output_feature = next(iter(stats.keys()), "") |
| 702 | 777 |
| 778 probs_path = None | |
| 779 prob_candidates = [ | |
| 780 exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv", | |
| 781 exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None, | |
| 782 exp_dir / "probabilities.csv", | |
| 783 exp_dir / "predictions.csv", | |
| 784 exp_dir / PREDICTIONS_PARQUET_FILE_NAME, | |
| 785 ] | |
| 786 for cand in prob_candidates: | |
| 787 if cand and Path(cand).exists(): | |
| 788 probs_path = str(cand) | |
| 789 break | |
| 790 | |
| 703 viz_registry = get_visualizations_registry() | 791 viz_registry = get_visualizations_registry() |
| 792 if not viz_registry: | |
| 793 logger.warning( | |
| 794 "Ludwig visualizations registry not available; train/test PNGs will be skipped." | |
| 795 ) | |
| 796 return | |
| 797 | |
| 798 base_kwargs = { | |
| 799 "training_statistics": [training_stats] if training_stats else [], | |
| 800 "test_statistics": [test_stats] if test_stats else [], | |
| 801 "probabilities": [probs_path] if probs_path else [], | |
| 802 "output_feature_name": output_feature, | |
| 803 "ground_truth_split": 2, | |
| 804 "top_n_classes": [20], | |
| 805 "top_k": 3, | |
| 806 "metrics": ["f1", "precision", "recall", "accuracy"], | |
| 807 "positive_label": 0, | |
| 808 "ground_truth_metadata": gt_metadata, | |
| 809 "ground_truth": dataset_path, | |
| 810 "split_file": split_file, | |
| 811 "output_directory": None, # set per plot below | |
| 812 "normalize": False, | |
| 813 "file_format": "png", | |
| 814 "model_names": [model_name], | |
| 815 } | |
| 704 for viz_name, viz_func in viz_registry.items(): | 816 for viz_name, viz_func in viz_registry.items(): |
| 705 if viz_name in train_plots: | 817 if viz_name in train_plots: |
| 706 viz_dir_plot = train_viz | 818 viz_dir_plot = train_viz |
| 707 elif viz_name in test_plots: | 819 elif viz_name in test_plots: |
| 708 viz_dir_plot = test_viz | 820 viz_dir_plot = test_viz |
| 709 else: | 821 else: |
| 710 continue | 822 continue |
| 711 | 823 |
| 712 try: | 824 try: |
| 825 # Build per-viz kwargs based on the function signature to avoid unexpected args | |
| 826 sig_params = set(inspect.signature(viz_func).parameters.keys()) | |
| 827 call_kwargs = { | |
| 828 k: v | |
| 829 for k, v in base_kwargs.items() | |
| 830 if k in sig_params and v is not None | |
| 831 } | |
| 832 if "output_directory" in sig_params: | |
| 833 call_kwargs["output_directory"] = str(viz_dir_plot) | |
| 834 | |
| 713 viz_func( | 835 viz_func( |
| 714 training_statistics=[training_stats] if training_stats else [], | 836 **call_kwargs, |
| 715 test_statistics=[test_stats] if test_stats else [], | |
| 716 probabilities=[probs_path] if probs_path else [], | |
| 717 output_feature_name=output_feature, | |
| 718 ground_truth_split=2, | |
| 719 top_n_classes=[0], | |
| 720 top_k=3, | |
| 721 ground_truth_metadata=gt_metadata, | |
| 722 ground_truth=dataset_path, | |
| 723 split_file=split_file, | |
| 724 output_directory=str(viz_dir_plot), | |
| 725 normalize=False, | |
| 726 file_format="png", | |
| 727 ) | 837 ) |
| 728 logger.info(f"✔ Generated {viz_name}") | 838 logger.info(f"✔ Generated {viz_name}") |
| 729 except Exception as e: | 839 except Exception as e: |
| 730 logger.warning(f"✘ Skipped {viz_name}: {e}") | 840 logger.warning(f"✘ Skipped {viz_name}: {e}") |
| 731 | |
| 732 logger.info(f"All visualizations written to {viz_dir}") | 841 logger.info(f"All visualizations written to {viz_dir}") |
| 733 | 842 |
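The `inspect.signature` filtering introduced in `generate_plots` above is a general pattern: build one superset of kwargs and pass each callee only the parameters it declares. A self-contained illustration with made-up visualization functions (not Ludwig's actual API):

```python
import inspect
from typing import Any, Callable, Dict


def call_with_supported_kwargs(func: Callable, base_kwargs: Dict[str, Any], **overrides: Any) -> Any:
    """Drop any keyword argument the callee does not declare, then call it."""
    sig_params = set(inspect.signature(func).parameters.keys())
    call_kwargs = {k: v for k, v in base_kwargs.items() if k in sig_params and v is not None}
    call_kwargs.update({k: v for k, v in overrides.items() if k in sig_params})
    return func(**call_kwargs)


# Toy "visualizations" with different signatures:
def learning_curves(training_statistics, output_directory, file_format="png"):
    return ("learning_curves", output_directory, file_format)


def confusion_matrix(test_statistics, ground_truth_metadata, output_directory, normalize=False):
    return ("confusion_matrix", normalize)


base = {
    "training_statistics": ["training_statistics.json"],
    "test_statistics": ["test_statistics.json"],
    "ground_truth_metadata": "train_set_metadata.json",
    "normalize": False,
    "file_format": "png",
    "probabilities": [],   # silently dropped for callees that do not accept it
}

print(call_with_supported_kwargs(learning_curves, base, output_directory="viz/train"))
print(call_with_supported_kwargs(confusion_matrix, base, output_directory="viz/test"))
```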
| 734 def generate_html_report( | 843 def generate_html_report( |
| 735 self, | 844 self, |
| 736 title: str, | 845 title: str, |
| 754 exp_dir = exp_dirs[-1] | 863 exp_dir = exp_dirs[-1] |
| 755 train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME | 864 train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME |
| 756 label_metadata_path = config.get("label_column_data_path") | 865 label_metadata_path = config.get("label_column_data_path") |
| 757 if label_metadata_path: | 866 if label_metadata_path: |
| 758 label_metadata_path = Path(label_metadata_path) | 867 label_metadata_path = Path(label_metadata_path) |
| 868 dataset_path_from_desc: Optional[Path] = None | |
| 759 | 869 |
| 760 # Pull additional config details from description.json if available | 870 # Pull additional config details from description.json if available |
| 761 config_for_summary = dict(config) | 871 config_for_summary = dict(config) |
| 762 if "target_column" not in config_for_summary or not config_for_summary.get("target_column"): | 872 if "target_column" not in config_for_summary or not config_for_summary.get("target_column"): |
| 763 config_for_summary["target_column"] = LABEL_COLUMN_NAME | 873 config_for_summary["target_column"] = LABEL_COLUMN_NAME |
| 764 desc_path = exp_dir / DESCRIPTION_FILE_NAME | 874 desc_path = exp_dir / DESCRIPTION_FILE_NAME |
| 765 if desc_path.exists(): | 875 if desc_path.exists(): |
| 766 try: | 876 try: |
| 767 with open(desc_path, "r") as f: | 877 with open(desc_path, "r") as f: |
| 768 desc_cfg = json.load(f).get("config", {}) | 878 desc_json = json.load(f) |
| 879 desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {} | |
| 769 encoder_cfg = ( | 880 encoder_cfg = ( |
| 770 desc_cfg.get("input_features", [{}])[0].get("encoder", {}) | 881 desc_cfg.get("input_features", [{}])[0].get("encoder", {}) |
| 771 if isinstance(desc_cfg.get("input_features", [{}]), list) | 882 if isinstance(desc_cfg.get("input_features", [{}]), list) |
| 772 else {} | 883 else {} |
| 773 ) | 884 ) |
| 781 opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {} | 892 opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {} |
| 782 clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {} | 893 clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {} |
| 783 | 894 |
| 784 arch_type = encoder_cfg.get("type") | 895 arch_type = encoder_cfg.get("type") |
| 785 arch_variant = encoder_cfg.get("model_variant") | 896 arch_variant = encoder_cfg.get("model_variant") |
| 897 arch_custom = encoder_cfg.get("custom_model") | |
| 786 arch_name = None | 898 arch_name = None |
| 899 if arch_custom: | |
| 900 arch_name = str(arch_custom) | |
| 787 if arch_type: | 901 if arch_type: |
| 788 arch_base = str(arch_type).replace("_", " ").title() | 902 arch_base = str(arch_type).replace("_", " ").title() |
| 789 arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base | 903 arch_type_name = ( |
| 904 f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base | |
| 905 ) | |
| 906 # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type | |
| 907 arch_name = arch_name or arch_type_name | |
| 908 if not arch_name and config.get("model_name"): | |
| 909 # As a last resort, show the user-selected model name (handles custom/MetaFormer cases) | |
| 910 arch_name = str(config.get("model_name")) | |
| 790 | 911 |
| 791 summary_fields = { | 912 summary_fields = { |
| 792 "architecture": arch_name, | 913 "architecture": arch_name, |
| 793 "model_variant": arch_variant, | 914 "model_variant": arch_variant, |
| 794 "pretrained": encoder_cfg.get("use_pretrained"), | 915 "pretrained": encoder_cfg.get("use_pretrained"), |
| 812 continue | 933 continue |
| 813 # Do not override user-passed target/image column names in config | 934 # Do not override user-passed target/image column names in config |
| 814 if k in {"target_column", "image_column"} and config_for_summary.get(k): | 935 if k in {"target_column", "image_column"} and config_for_summary.get(k): |
| 815 continue | 936 continue |
| 816 config_for_summary.setdefault(k, v) | 937 config_for_summary.setdefault(k, v) |
| 938 | |
| 939 dataset_field = None | |
| 940 if isinstance(desc_json, dict): | |
| 941 dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset") | |
| 942 if dataset_field: | |
| 943 try: | |
| 944 dataset_path_from_desc = Path(dataset_field) | |
| 945 except TypeError: | |
| 946 dataset_path_from_desc = None | |
| 947 if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()): | |
| 948 label_metadata_path = dataset_path_from_desc | |
| 817 except Exception as e: # pragma: no cover - defensive | 949 except Exception as e: # pragma: no cover - defensive |
| 818 logger.warning(f"Could not merge description.json into config summary: {e}") | 950 logger.warning(f"Could not merge description.json into config summary: {e}") |
| 819 | 951 |
| 820 base_viz_dir = exp_dir / "visualizations" | 952 base_viz_dir = exp_dir / "visualizations" |
| 821 train_viz_dir = base_viz_dir / "train" | 953 train_viz_dir = base_viz_dir / "train" |
| 822 test_viz_dir = base_viz_dir / "test" | |
| 823 | 954 |
| 824 html = get_html_template() | 955 html = get_html_template() |
| 825 | 956 |
| 826 # Extra CSS & JS: center Plotly and enable CSV download for predictions table | 957 # Extra CSS & JS: center Plotly and enable CSV download for predictions table |
| 827 html += """ | 958 html += """ |
| 878 document.body.removeChild(a); | 1009 document.body.removeChild(a); |
| 879 URL.revokeObjectURL(url); | 1010 URL.revokeObjectURL(url); |
| 880 }); | 1011 }); |
| 881 } | 1012 } |
| 882 }); | 1013 }); |
| 883 </script> | 1014 </script> |
| 884 """ | 1015 """ |
| 885 html += f"<h1>{title}</h1>" | 1016 html += f"<h1>{title}</h1>" |
| 1017 | |
| 1018 def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str: | |
| 1019 """Append Plotly blocks to a tab with consistent markup.""" | |
| 1020 if not plots: | |
| 1021 return tab_html | |
| 1022 suffix = title_suffix or "" | |
| 1023 for plot in plots: | |
| 1024 tab_html += ( | |
| 1025 f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>" | |
| 1026 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1027 ) | |
| 1028 return tab_html | |
| 1029 | |
| 1030 def build_dataset_overview( | |
| 1031 label_metadata: Optional[Path], | |
| 1032 output_type: Optional[str], | |
| 1033 split_probabilities: Optional[List[float]], | |
| 1034 label_split_counts: Optional[List[Dict[str, int]]] = None, | |
| 1035 split_counts: Optional[Dict[int, int]] = None, | |
| 1036 fallback_dataset: Optional[Path] = None, | |
| 1037 ) -> str: | |
| 1038 """Summarize dataset distribution across splits using the actual split config.""" | |
| 1039 if label_split_counts: | |
| 1040 # Use the actual counts captured during data prep instead of heuristics. | |
| 1041 return format_dataset_overview_table(label_split_counts, regression_mode=False) | |
| 1042 | |
| 1043 if output_type == "regression" and split_counts: | |
| 1044 rows = [ | |
| 1045 {"split": "train", "count": int(split_counts.get(0, 0))}, | |
| 1046 {"split": "validation", "count": int(split_counts.get(1, 0))}, | |
| 1047 {"split": "test", "count": int(split_counts.get(2, 0))}, | |
| 1048 ] | |
| 1049 return format_dataset_overview_table(rows, regression_mode=True) | |
| 1050 | |
| 1051 candidate_paths: List[Path] = [] | |
| 1052 if label_metadata and label_metadata.exists(): | |
| 1053 candidate_paths.append(label_metadata) | |
| 1054 if fallback_dataset and fallback_dataset.exists(): | |
| 1055 candidate_paths.append(fallback_dataset) | |
| 1056 if not candidate_paths: | |
| 1057 return format_dataset_overview_table([]) | |
| 1058 | |
| 1059 def _normalize_split_probabilities( | |
| 1060 probs: Optional[List[float]], | |
| 1061 ) -> Optional[List[float]]: | |
| 1062 if not probs or len(probs) != 3: | |
| 1063 return None | |
| 1064 try: | |
| 1065 probs = [float(p) for p in probs] | |
| 1066 except (TypeError, ValueError): | |
| 1067 return None | |
| 1068 total = sum(probs) | |
| 1069 if total <= 0: | |
| 1070 return None | |
| 1071 return [p / total for p in probs] | |
| 1072 | |
| 1073 def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]: | |
| 1074 if SPLIT_COLUMN_NAME not in df.columns: | |
| 1075 return {} | |
| 1076 split_series = pd.to_numeric( | |
| 1077 df[SPLIT_COLUMN_NAME], errors="coerce" | |
| 1078 ).dropna() | |
| 1079 if split_series.empty: | |
| 1080 return {} | |
| 1081 split_series = split_series.astype(int) | |
| 1082 return split_series.value_counts().to_dict() | |
| 1083 | |
| 1084 def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]: | |
| 1085 train_n = int(total * probs[0]) | |
| 1086 val_n = int(total * probs[1]) | |
| 1087 test_n = max(0, total - train_n - val_n) | |
| 1088 return {0: train_n, 1: val_n, 2: test_n} | |
| 1089 | |
| 1090 fallback_rows: Optional[List[Dict[str, int]]] = None | |
| 1091 for meta_path in candidate_paths: | |
| 1092 try: | |
| 1093 df_labels = pd.read_csv(meta_path) | |
| 1094 probs = _normalize_split_probabilities(split_probabilities) | |
| 1095 | |
| 1096 # Regression (or missing label column): only need split counts | |
| 1097 if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns: | |
| 1098 split_counts_found = _split_counts_from_column(df_labels) | |
| 1099 if split_counts_found: | |
| 1100 rows = [ | |
| 1101 {"split": "train", "count": int(split_counts_found.get(0, 0))}, | |
| 1102 {"split": "validation", "count": int(split_counts_found.get(1, 0))}, | |
| 1103 {"split": "test", "count": int(split_counts_found.get(2, 0))}, | |
| 1104 ] | |
| 1105 return format_dataset_overview_table(rows, regression_mode=True) | |
| 1106 if probs and fallback_rows is None: | |
| 1107 split_counts_found = _split_counts_from_probs(len(df_labels), probs) | |
| 1108 fallback_rows = [ | |
| 1109 {"split": "train", "count": int(split_counts_found.get(0, 0))}, | |
| 1110 {"split": "validation", "count": int(split_counts_found.get(1, 0))}, | |
| 1111 {"split": "test", "count": int(split_counts_found.get(2, 0))}, | |
| 1112 ] | |
| 1113 continue | |
| 1114 | |
| 1115 # Classification: prefer actual split assignments; fall back to configured probabilities | |
| 1116 if SPLIT_COLUMN_NAME in df_labels.columns: | |
| 1117 df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy() | |
| 1118 df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric( | |
| 1119 df_counts[SPLIT_COLUMN_NAME], errors="coerce" | |
| 1120 ) | |
| 1121 df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME]) | |
| 1122 if df_counts.empty: | |
| 1123 continue | |
| 1124 | |
| 1125 df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int) | |
| 1126 df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME]) | |
| 1127 counts = ( | |
| 1128 df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]) | |
| 1129 .size() | |
| 1130 .unstack(fill_value=0) | |
| 1131 .sort_index() | |
| 1132 ) | |
| 1133 rows = [] | |
| 1134 for lbl, row in counts.iterrows(): | |
| 1135 rows.append( | |
| 1136 { | |
| 1137 "label": str(lbl), | |
| 1138 "train": int(row.get(0, 0)), | |
| 1139 "validation": int(row.get(1, 0)), | |
| 1140 "test": int(row.get(2, 0)), | |
| 1141 } | |
| 1142 ) | |
| 1143 return format_dataset_overview_table(rows) | |
| 1144 | |
| 1145 if probs: | |
| 1146 label_series = df_labels[LABEL_COLUMN_NAME].dropna() | |
| 1147 label_counts = label_series.value_counts().sort_index() | |
| 1148 if label_counts.empty: | |
| 1149 continue | |
| 1150 rows = [] | |
| 1151 for lbl, count in label_counts.items(): | |
| 1152 train_n = int(count * probs[0]) | |
| 1153 val_n = int(count * probs[1]) | |
| 1154 test_n = max(0, count - train_n - val_n) | |
| 1155 rows.append( | |
| 1156 { | |
| 1157 "label": str(lbl), | |
| 1158 "train": train_n, | |
| 1159 "validation": val_n, | |
| 1160 "test": test_n, | |
| 1161 } | |
| 1162 ) | |
| 1163 fallback_rows = fallback_rows or rows | |
| 1164 except Exception as exc: | |
| 1165 logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc) | |
| 1166 continue | |
| 1167 | |
| 1168 if fallback_rows: | |
| 1169 return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression") | |
| 1170 return format_dataset_overview_table([]) | |
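The probability-based fallback in `build_dataset_overview` is simple arithmetic: normalize the three split probabilities, take integer train/validation counts, and give the rounding remainder to the test split. A standalone sketch of the two nested helpers with a worked example:

```python
from typing import Dict, List, Optional


def normalize_split_probabilities(probs: Optional[List[float]]) -> Optional[List[float]]:
    """Return the three probabilities scaled to sum to 1.0, or None if unusable."""
    if not probs or len(probs) != 3:
        return None
    try:
        probs = [float(p) for p in probs]
    except (TypeError, ValueError):
        return None
    total = sum(probs)
    if total <= 0:
        return None
    return [p / total for p in probs]


def split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
    """0=train, 1=validation, 2=test; the test split absorbs the rounding remainder."""
    train_n = int(total * probs[0])
    val_n = int(total * probs[1])
    test_n = max(0, total - train_n - val_n)
    return {0: train_n, 1: val_n, 2: test_n}


probs = normalize_split_probabilities([0.7, 0.1, 0.2])
print(split_counts_from_probs(1001, probs))  # {0: 700, 1: 100, 2: 201}
```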
| 886 | 1171 |
| 887 metrics_html = "" | 1172 metrics_html = "" |
| 888 train_val_metrics_html = "" | 1173 train_val_metrics_html = "" |
| 889 test_metrics_html = "" | 1174 test_metrics_html = "" |
| 890 output_type = None | 1175 output_type = None |
| 909 except Exception as e: | 1194 except Exception as e: |
| 910 logger.warning( | 1195 logger.warning( |
| 911 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | 1196 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
| 912 ) | 1197 ) |
| 913 | 1198 |
| 1199 if not output_type: | |
| 1200 # Fallback to configured task type when stats are unavailable (e.g., failed run). | |
| 1201 output_type = ( | |
| 1202 str(config_for_summary.get("task_type")).lower() | |
| 1203 if config_for_summary.get("task_type") | |
| 1204 else None | |
| 1205 ) | |
| 1206 | |
| 1207 dataset_overview_html = build_dataset_overview( | |
| 1208 label_metadata_path, | |
| 1209 output_type, | |
| 1210 config.get("split_probabilities"), | |
| 1211 config.get("label_split_counts"), | |
| 1212 config.get("split_counts"), | |
| 1213 dataset_path_from_desc, | |
| 1214 ) | |
| 1215 | |
| 914 config_html = "" | 1216 config_html = "" |
| 915 training_progress = self.get_training_process(output_dir) | 1217 training_progress = self.get_training_process(output_dir) |
| 916 try: | 1218 try: |
| 917 config_html = format_config_table_html( | 1219 config_html = format_config_table_html( |
| 918 config_for_summary, split_info, training_progress, output_type | 1220 config_for_summary, split_info, training_progress, output_type |
| 935 dir_path: Path, | 1237 dir_path: Path, |
| 936 output_type: str = None, | 1238 output_type: str = None, |
| 937 exclude_names: Optional[set] = None, | 1239 exclude_names: Optional[set] = None, |
| 938 ) -> str: | 1240 ) -> str: |
| 939 if not dir_path.exists(): | 1241 if not dir_path.exists(): |
| 940 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 1242 return "" |
| 941 | 1243 |
| 942 exclude_names = exclude_names or set() | 1244 exclude_names = exclude_names or set() |
| 943 | 1245 |
| 944 imgs = list(dir_path.glob("*.png")) | 1246 # Search recursively because Ludwig can nest figures under per-feature folders |
| 1247 imgs = list(dir_path.rglob("*.png")) | |
| 945 | 1248 |
| 946 # Exclude ROC curves and standard confusion matrices (keep only entropy version) | 1249 # Exclude ROC curves and standard confusion matrices (keep only entropy version) |
| 947 default_exclude = { | 1250 default_exclude = { |
| 948 # "roc_curves.png", # Remove ROC curves from test tab | 1251 # "roc_curves.png", # Remove ROC curves from test tab |
| 949 "confusion_matrix__label_top5.png", # Remove standard confusion matrix | 1252 "confusion_matrix__label_top5.png", # Remove standard confusion matrix |
| 981 and "label" in img.stem | 1284 and "label" in img.stem |
| 982 ) | 1285 ) |
| 983 ] | 1286 ] |
| 984 | 1287 |
| 985 if not imgs: | 1288 if not imgs: |
| 986 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 1289 return "" |
| 987 | 1290 |
| 988 # Sort images by name for consistent ordering (works with string and numeric labels) | 1291 # Sort images by name for consistent ordering (works with string and numeric labels) |
| 989 imgs = sorted(imgs, key=lambda x: x.name) | 1292 imgs = sorted(imgs, key=lambda x: x.name) |
| 990 | 1293 |
| 991 html_section = "" | 1294 html_section = "" |
| 1004 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 1307 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' |
| 1005 f"</div>" | 1308 f"</div>" |
| 1006 ) | 1309 ) |
| 1007 return html_section | 1310 return html_section |
| 1008 | 1311 |
| 1009 # Show performance first, then config | 1312 # Show dataset overview, performance first, then config |
| 1010 tab1_content = metrics_html + config_html | 1313 predictions_csv_path = exp_dir / "predictions.csv" |
| 1011 | 1314 |
| 1012 tab2_content = train_val_metrics_html + render_img_section( | 1315 tab1_content = dataset_overview_html + metrics_html + config_html |
| 1013 "Training and Validation Visualizations", | 1316 |
| 1014 train_viz_dir, | 1317 tab2_content = train_val_metrics_html |
| 1015 output_type, | 1318 # Preload binary threshold plot so it appears first in Train/Val tab |
| 1016 exclude_names={ | 1319 threshold_plot = None |
| 1017 "compare_classifiers_performance_from_prob.png", | 1320 threshold_value = ( |
| 1018 "roc_curves_from_prediction_statistics.png", | 1321 config_for_summary.get("threshold") |
| 1019 "precision_recall_curves_from_prediction_statistics.png", | 1322 if config_for_summary.get("threshold") is not None |
| 1020 "precision_recall_curve.png", | 1323 else config.get("threshold") |
| 1021 }, | |
| 1022 ) | 1324 ) |
| 1325 if threshold_value is None and output_type == "binary": | |
| 1326 threshold_value = 0.5 | |
| 1327 if output_type == "binary" and predictions_csv_path.exists(): | |
| 1328 try: | |
| 1329 threshold_plot = build_binary_threshold_plot( | |
| 1330 str(predictions_csv_path), | |
| 1331 label_data_path=str(config.get("label_column_data_path")) | |
| 1332 if config.get("label_column_data_path") | |
| 1333 else None, | |
| 1334 split_value=1, | |
| 1335 ) | |
| 1336 except Exception as e: | |
| 1337 logger.warning(f"Could not generate validation threshold plot: {e}") | |
| 1338 | |
| 1023 if train_stats_path.exists(): | 1339 if train_stats_path.exists(): |
| 1024 try: | 1340 try: |
| 1025 if output_type == "regression": | 1341 if output_type == "regression": |
| 1026 tv_plots = build_regression_train_val_plots(str(train_stats_path)) | 1342 tv_plots = build_regression_train_val_plots(str(train_stats_path)) |
| 1343 tab2_content = append_plot_blocks(tab2_content, tv_plots) | |
| 1027 else: | 1344 else: |
| 1028 tv_plots = build_train_validation_plots(str(train_stats_path)) | 1345 tv_plots = build_train_validation_plots(str(train_stats_path)) |
| 1029 for plot in tv_plots: | 1346 # Add threshold plot first, then other train/val plots |
| 1030 tab2_content += ( | 1347 if threshold_plot: |
| 1031 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1348 tab2_content = append_plot_blocks(tab2_content, [threshold_plot]) |
| 1032 f"<div class='plotly-center'>{plot['html']}</div>" | 1349 # Only append once; avoid duplicates if added elsewhere |
| 1033 ) | 1350 threshold_plot = None |
| 1034 if tv_plots: | 1351 tab2_content = append_plot_blocks(tab2_content, tv_plots) |
| 1035 logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots") | 1352 if threshold_plot or tv_plots: |
| 1353 logger.info( | |
| 1354 f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots" | |
| 1355 ) | |
| 1036 except Exception as e: | 1356 except Exception as e: |
| 1037 logger.warning(f"Could not generate train/val plots: {e}") | 1357 logger.warning(f"Could not generate train/val plots: {e}") |
| 1358 | |
| 1359 # Only include training PNGs for regression; classification is handled by filtered Plotly plots | |
| 1360 if output_type == "regression": | |
| 1361 tab2_content += render_img_section( | |
| 1362 "Training and Validation Visualizations", | |
| 1363 train_viz_dir, | |
| 1364 output_type, | |
| 1365 exclude_names={ | |
| 1366 "compare_classifiers_performance_from_prob.png", | |
| 1367 "roc_curves_from_prediction_statistics.png", | |
| 1368 "precision_recall_curves_from_prediction_statistics.png", | |
| 1369 "precision_recall_curve.png", | |
| 1370 }, | |
| 1371 ) | |
| 1372 | |
| 1373 # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1 | |
| 1374 if output_type in ("binary", "category") and predictions_csv_path.exists(): | |
| 1375 try: | |
| 1376 val_diag_plots = build_prediction_diagnostics( | |
| 1377 str(predictions_csv_path), | |
| 1378 label_data_path=str(config.get("label_column_data_path")) | |
| 1379 if config.get("label_column_data_path") | |
| 1380 else None, | |
| 1381 split_value=1, | |
| 1382 ) | |
| 1383 val_conf_plots = [ | |
| 1384 p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "") | |
| 1385 ] | |
| 1386 tab2_content = append_plot_blocks( | |
| 1387 tab2_content, val_conf_plots, " (Validation)" | |
| 1388 ) | |
| 1389 except Exception as e: | |
| 1390 logger.warning(f"Could not generate validation diagnostics: {e}") | |
| 1038 | 1391 |
| 1039 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- | 1392 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- |
| 1040 preds_section = "" | 1393 preds_section = "" |
| 1041 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | 1394 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
| 1042 if output_type == "regression" and parquet_path.exists(): | 1395 if output_type == "regression" and parquet_path.exists(): |
| 1075 ) | 1428 ) |
| 1076 except Exception as e: | 1429 except Exception as e: |
| 1077 logger.warning(f"Could not build Predictions vs GT table: {e}") | 1430 logger.warning(f"Could not build Predictions vs GT table: {e}") |
| 1078 | 1431 |
| 1079 tab3_content = test_metrics_html + preds_section | 1432 tab3_content = test_metrics_html + preds_section |
| 1080 test_plotly_added = False | |
| 1081 | 1433 |
| 1082 if output_type == "regression" and train_stats_path.exists(): | 1434 if output_type == "regression" and train_stats_path.exists(): |
| 1083 try: | 1435 try: |
| 1084 test_plots = build_regression_test_plots(str(train_stats_path)) | 1436 test_plots = build_regression_test_plots(str(train_stats_path)) |
| 1085 for plot in test_plots: | 1437 tab3_content = append_plot_blocks(tab3_content, test_plots) |
| 1086 tab3_content += ( | |
| 1087 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1088 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1089 ) | |
| 1090 if test_plots: | 1438 if test_plots: |
| 1091 test_plotly_added = True | |
| 1092 logger.info(f"Generated {len(test_plots)} regression test plots") | 1439 logger.info(f"Generated {len(test_plots)} regression test plots") |
| 1093 except Exception as e: | 1440 except Exception as e: |
| 1094 logger.warning(f"Could not generate regression test plots: {e}") | 1441 logger.warning(f"Could not generate regression test plots: {e}") |
| 1095 | 1442 |
| 1096 if output_type in ("binary", "category") and test_stats_path.exists(): | 1443 if output_type in ("binary", "category") and test_stats_path.exists(): |
| 1102 if label_metadata_path and label_metadata_path.exists() | 1449 if label_metadata_path and label_metadata_path.exists() |
| 1103 else None, | 1450 else None, |
| 1104 train_set_metadata_path=str(train_set_metadata_path) | 1451 train_set_metadata_path=str(train_set_metadata_path) |
| 1105 if train_set_metadata_path.exists() | 1452 if train_set_metadata_path.exists() |
| 1106 else None, | 1453 else None, |
| 1107 ) | 1454 threshold=threshold_value, |
| 1108 for plot in interactive_plots: | 1455 ) |
| 1109 tab3_content += ( | 1456 tab3_content = append_plot_blocks(tab3_content, interactive_plots) |
| 1110 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1111 f"<div class='plotly-center'>{plot['html']}</div>" | |
| 1112 ) | |
| 1113 if interactive_plots: | 1457 if interactive_plots: |
| 1114 test_plotly_added = True | 1458 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") |
| 1115 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") | |
| 1116 except Exception as e: | 1459 except Exception as e: |
| 1117 logger.warning(f"Could not generate Plotly plots: {e}") | 1460 logger.warning(f"Could not generate Plotly plots: {e}") |
| 1118 | 1461 |
| 1119 # Add prediction diagnostics from predictions.csv | 1462 # Multi-class transparency plots from test stats (replace ROC/PR for multi-class) |
| 1120 predictions_csv_path = exp_dir / "predictions.csv" | 1463 if output_type == "category" and test_stats_path.exists(): |
| 1121 try: | 1464 try: |
| 1122 diag_plots = build_prediction_diagnostics( | 1465 multi_curves = build_multiclass_metric_plots(str(test_stats_path)) |
| 1123 str(predictions_csv_path), | 1466 tab3_content = append_plot_blocks(tab3_content, multi_curves) |
| 1124 label_data_path=str(config.get("label_column_data_path")) | 1467 if multi_curves: |
| 1125 if config.get("label_column_data_path") | 1468 logger.info("Added multi-class per-class metric plots to test tab") |
| 1126 else None, | 1469 except Exception as e: |
| 1127 threshold=config.get("threshold"), | 1470 logger.warning(f"Could not generate multi-class metric plots: {e}") |
| 1128 ) | 1471 |
| 1129 for plot in diag_plots: | 1472 # Test diagnostics (confidence histogram) from predictions.csv, using split=2 |
| 1130 tab3_content += ( | 1473 if predictions_csv_path.exists(): |
| 1131 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1474 try: |
| 1132 f"<div class='plotly-center'>{plot['html']}</div>" | 1475 test_diag_plots = build_prediction_diagnostics( |
| 1476 str(predictions_csv_path), | |
| 1477 label_data_path=str(config.get("label_column_data_path")) | |
| 1478 if config.get("label_column_data_path") | |
| 1479 else None, | |
| 1480 split_value=2, | |
| 1133 ) | 1481 ) |
| 1134 if diag_plots: | 1482 test_conf_plots = [ |
| 1135 test_plotly_added = True | 1483 p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "") |
| 1136 logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots") | 1484 ] |
| 1137 except Exception as e: | 1485 if test_conf_plots: |
| 1138 logger.warning(f"Could not generate prediction diagnostics: {e}") | 1486 tab3_content = append_plot_blocks(tab3_content, test_conf_plots) |
| 1139 | 1487 logger.info("Added test prediction confidence plot") |
| 1140 # Fallback: include static PNGs if no interactive plots were added | 1488 except Exception as e: |
| 1141 if not test_plotly_added: | 1489 logger.warning(f"Could not generate test diagnostics: {e}") |
| 1142 tab3_content += render_img_section( | |
| 1143 "Test Visualizations (PNG fallback)", | |
| 1144 test_viz_dir, | |
| 1145 output_type, | |
| 1146 ) | |
| 1147 | 1490 |
| 1148 # Add static TEST PNGs (with default dedupe/exclusions) | 1491 # Add static TEST PNGs (with default dedupe/exclusions) |
| 1149 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1492 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
| 1150 modal_html = get_metrics_help_modal() | 1493 modal_html = get_metrics_help_modal() |
| 1151 html += tabbed_html + modal_html + get_html_closing() | 1494 html += tabbed_html + modal_html + get_html_closing() |
