comparison: ludwig_backend.py @ 17:db9be962dc13 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
author goeckslab
date Wed, 10 Dec 2025 00:24:13 +0000
parents 8729f69e9207
children
compared revisions: 16:8729f69e9207 → 17:db9be962dc13
1 import inspect
1 import json 2 import json
2 import logging 3 import logging
3 import os 4 import os
4 from pathlib import Path 5 from pathlib import Path
5 from typing import Any, Dict, Optional, Protocol, Tuple 6 from typing import Any, Dict, List, Optional, Protocol, Tuple
6 7
7 import pandas as pd 8 import pandas as pd
8 import pandas.api.types as ptypes 9 import pandas.api.types as ptypes
9 import yaml 10 import yaml
10 from constants import ( 11 from constants import (
15 ) 16 )
16 from html_structure import ( 17 from html_structure import (
17 build_tabbed_html, 18 build_tabbed_html,
18 encode_image_to_base64, 19 encode_image_to_base64,
19 format_config_table_html, 20 format_config_table_html,
21 format_dataset_overview_table,
20 format_stats_table_html, 22 format_stats_table_html,
21 format_test_merged_stats_table_html, 23 format_test_merged_stats_table_html,
22 format_train_val_stats_table_html, 24 format_train_val_stats_table_html,
23 get_html_closing, 25 get_html_closing,
24 get_html_template, 26 get_html_template,
31 TRAIN_SET_METADATA_FILE_NAME, 33 TRAIN_SET_METADATA_FILE_NAME,
32 ) 34 )
33 from ludwig.utils.data_utils import get_split_path 35 from ludwig.utils.data_utils import get_split_path
34 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS 36 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
35 from plotly_plots import ( 37 from plotly_plots import (
38 build_binary_threshold_plot,
36 build_classification_plots, 39 build_classification_plots,
40 build_multiclass_metric_plots,
37 build_prediction_diagnostics, 41 build_prediction_diagnostics,
38 build_regression_test_plots, 42 build_regression_test_plots,
39 build_regression_train_val_plots, 43 build_regression_train_val_plots,
40 build_train_validation_plots, 44 build_train_validation_plots,
41 ) 45 )
265 "trainable": trainable, 269 "trainable": trainable,
266 } 270 }
267 else: 271 else:
268 encoder_config = {"type": raw_encoder} 272 encoder_config = {"type": raw_encoder}
269 273
274 # Set a human-friendly architecture string for reporting
275 arch_display = None
276 if is_metaformer and custom_model:
277 arch_display = str(custom_model)
278 elif isinstance(raw_encoder, dict):
279 enc_type = raw_encoder.get("type")
280 enc_variant = raw_encoder.get("model_variant")
281 if enc_type:
282 base = str(enc_type).replace("_", " ").title()
283 arch_display = f"{base} {enc_variant}" if enc_variant is not None else base
284 else:
285 arch_display = str(raw_encoder).replace("_", " ").title()
286
287 if not arch_display:
288 arch_display = str(model_name)
289 config_params["architecture"] = arch_display
290
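For illustration, the title-casing above is what turns raw encoder identifiers into the architecture label shown in the report; a minimal standalone sketch (the encoder dicts below are hypothetical examples, not values from this repository):

# Hypothetical encoder configs, showing how the display string is derived
examples = [
    {"type": "stacked_cnn"},                  # -> "Stacked Cnn"
    {"type": "resnet", "model_variant": 18},  # -> "Resnet 18"
]
for enc in examples:
    base = str(enc["type"]).replace("_", " ").title()
    variant = enc.get("model_variant")
    print(f"{base} {variant}" if variant is not None else base)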
270 batch_size_cfg = batch_size or "auto" 291 batch_size_cfg = batch_size or "auto"
271 292
272 label_column_path = config_params.get("label_column_data_path") 293 label_column_path = config_params.get("label_column_data_path")
273 label_series = None 294 label_series = None
274 label_metadata_hint = config_params.get("label_metadata") or {} 295 label_metadata_hint = config_params.get("label_metadata") or {}
341 image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality 362 image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality
342 image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization 363 image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization
343 # Force Ludwig to respect our dimensions by setting additional parameters 364 # Force Ludwig to respect our dimensions by setting additional parameters
344 image_feat["preprocessing"]["requires_equal_dimensions"] = False 365 image_feat["preprocessing"]["requires_equal_dimensions"] = False
345 logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") 366 logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)")
367 config_params["image_size"] = f"{height}x{width}"
346 # Now set the encoder configuration 368 # Now set the encoder configuration
347 image_feat["encoder"] = encoder_config 369 image_feat["encoder"] = encoder_config
348 370
349 if config_params.get("augmentation") is not None: 371 if config_params.get("augmentation") is not None:
350 image_feat["augmentation"] = config_params["augmentation"] 372 image_feat["augmentation"] = config_params["augmentation"]
372 # but set explicit max dimensions to control the output size 394 # but set explicit max dimensions to control the output size
373 image_feat["preprocessing"]["infer_image_dimensions"] = True 395 image_feat["preprocessing"]["infer_image_dimensions"] = True
374 image_feat["preprocessing"]["infer_image_max_height"] = height 396 image_feat["preprocessing"]["infer_image_max_height"] = height
375 image_feat["preprocessing"]["infer_image_max_width"] = width 397 image_feat["preprocessing"]["infer_image_max_width"] = width
376 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") 398 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
399 config_params["image_size"] = f"{height}x{width}"
377 except (ValueError, IndexError): 400 except (ValueError, IndexError):
378 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") 401 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
402 elif not is_metaformer:
403 # No explicit resize provided; record the original image size for reporting
404 config_params.setdefault("image_size", "original")
379 405
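A condensed sketch of the resize handling above, assuming image_resize is supplied as a "HEIGHTxWIDTH" string (the sample value is made up; the actual parsing lives in the elided lines):

image_resize = "224x224"  # hypothetical user input
preprocessing = {}
try:
    height, width = (int(v) for v in image_resize.lower().split("x"))
    # Let Ludwig infer dimensions, but cap them so the requested size is respected
    preprocessing["infer_image_dimensions"] = True
    preprocessing["infer_image_max_height"] = height
    preprocessing["infer_image_max_width"] = width
except (ValueError, IndexError):
    # Malformed value: skip resize preprocessing and report the original size instead
    pass
print(preprocessing)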
380 def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: 406 def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
381 """Pick a validation metric that Ludwig will accept for the resolved task.""" 407 """Pick a validation metric that Ludwig will accept for the resolved task."""
382 default_map = { 408 default_map = {
383 "regression": "pearson_r", 409 "regression": "pearson_r",
469 val_metric = _resolve_validation_metric( 495 val_metric = _resolve_validation_metric(
470 "binary" if num_unique_labels == 2 else "category", 496 "binary" if num_unique_labels == 2 else "category",
471 config_params.get("validation_metric"), 497 config_params.get("validation_metric"),
472 ) 498 )
473 499
500 # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
501 config_params["validation_metric"] = val_metric
502
474 conf: Dict[str, Any] = { 503 conf: Dict[str, Any] = {
475 "model_type": "ecd", 504 "model_type": "ecd",
476 "input_features": [image_feat], 505 "input_features": [image_feat],
477 "output_features": [output_feat], 506 "output_features": [output_feat],
478 "combiner": {"type": "concat"}, 507 "combiner": {"type": "concat"},
639 df.to_csv(csv_path, index=False) 668 df.to_csv(csv_path, index=False)
640 logger.info(f"Converted Parquet to CSV: {csv_path}") 669 logger.info(f"Converted Parquet to CSV: {csv_path}")
641 except Exception as e: 670 except Exception as e:
642 logger.error(f"Error converting Parquet to CSV: {e}") 671 logger.error(f"Error converting Parquet to CSV: {e}")
643 672
673 @staticmethod
674 def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
675 """Pull the first numeric metric list we can find for the requested split."""
676 if not isinstance(stats, dict):
677 return None, None
678
679 split_stats = stats.get(split, {})
680 ordered_metrics: List[Tuple[str, List[float]]] = []
681
682 def _append_metrics(metric_map: Dict[str, Any]) -> None:
683 for metric_name, values in metric_map.items():
684 if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values):
685 ordered_metrics.append((metric_name, values))
686
687 if isinstance(split_stats, dict):
688 combined = split_stats.get("combined")
689 if isinstance(combined, dict):
690 _append_metrics(combined)
691
692 for feature_name, feature_metrics in split_stats.items():
693 if feature_name == "combined" or not isinstance(feature_metrics, dict):
694 continue
695 _append_metrics(feature_metrics)
696
697 if prefer:
698 for metric_name, values in ordered_metrics:
699 if metric_name == prefer:
700 return metric_name, values
701
702 return ordered_metrics[0] if ordered_metrics else (None, None)
703
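The helper above walks a Ludwig-style statistics dictionary (split -> feature -> metric -> per-epoch values), preferring the "combined" feature and an explicitly requested metric. A standalone mirror of that lookup order, with fabricated numbers, for illustration only:

stats = {
    "training": {
        "combined": {"loss": [0.9, 0.6, 0.4]},
        "label": {"accuracy": [0.55, 0.71, 0.80], "loss": [0.9, 0.6, 0.4]},
    },
    "validation": {
        "label": {"accuracy": [0.52, 0.66, 0.74]},
    },
}

def first_metric(split, prefer=None):
    split_stats = stats.get(split, {})
    ordered = []
    candidates = [split_stats.get("combined", {})] + [
        v for k, v in split_stats.items() if k != "combined"
    ]
    for metrics in candidates:
        if isinstance(metrics, dict):
            ordered += [(m, v) for m, v in metrics.items() if isinstance(v, list)]
    if prefer:
        for m, v in ordered:
            if m == prefer:
                return m, v
    return ordered[0] if ordered else (None, None)

print(first_metric("validation", prefer="accuracy"))  # ('accuracy', [0.52, 0.66, 0.74])
print(first_metric("training"))                       # ('loss', [0.9, 0.6, 0.4]) -- "combined" comes first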
644 def generate_plots(self, output_dir: Path) -> None: 704 def generate_plots(self, output_dir: Path) -> None:
645 """Generate all registered Ludwig visualizations for the latest experiment run.""" 705 """Generate Ludwig visualizations (train/val + test) for the latest experiment run."""
646 logger.info("Generating all Ludwig visualizations…") 706 logger.info("Generating Ludwig visualizations (train/val + test)…")
647 707
648 # Keep only lightweight plots (drop compare_performance/roc_curves) 708 # Train/validation visualizations
649 test_plots = {
650 "roc_curves_from_test_statistics",
651 "confusion_matrix",
652 }
653 train_plots = { 709 train_plots = {
654 "learning_curves", 710 "learning_curves",
655 "compare_classifiers_performance_subset", 711 }
712
713 # Test visualizations (multi-class transparency)
714 test_plots = {
715 "confusion_matrix",
716 "compare_performance",
717 "compare_classifiers_multiclass_multimetric",
718 "frequency_vs_f1",
719 "confidence_thresholding",
720 "confidence_thresholding_data_vs_acc",
721 "confidence_thresholding_data_vs_acc_subset",
722 "confidence_thresholding_data_vs_acc_subset_per_class",
723 # Binary-only visualizations will still be attempted; multi-class replacements are added in the HTML report
724 "binary_threshold_vs_metric",
725 "roc_curves",
726 "precision_recall_curves",
727 "calibration_1_vs_all",
728 "calibration_multiclass",
656 } 729 }
657 730
658 output_dir = Path(output_dir) 731 output_dir = Path(output_dir)
659 exp_dirs = sorted( 732 exp_dirs = sorted(
660 output_dir.glob("experiment_run*"), 733 output_dir.glob("experiment_run*"),
675 def _check(p: Path) -> Optional[str]: 748 def _check(p: Path) -> Optional[str]:
676 return str(p) if p.exists() else None 749 return str(p) if p.exists() else None
677 750
678 training_stats = _check(exp_dir / "training_statistics.json") 751 training_stats = _check(exp_dir / "training_statistics.json")
679 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) 752 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
680 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
681 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) 753 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
682 754
683 dataset_path = None 755 dataset_path = None
684 split_file = None 756 split_file = None
685 desc = exp_dir / DESCRIPTION_FILE_NAME 757 desc = exp_dir / DESCRIPTION_FILE_NAME
686 if desc.exists(): 758 if desc.exists():
687 with open(desc, "r") as f: 759 with open(desc, "r") as f:
688 cfg = json.load(f) 760 cfg = json.load(f)
689 dataset_path = _check(Path(cfg.get("dataset", ""))) 761 dataset_path = _check(Path(cfg.get("dataset", "")))
690 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) 762 split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
763 model_name = cfg.get("model_name", "model")
764 else:
765 model_name = "model"
691 766
692 output_feature = "" 767 output_feature = ""
693 if desc.exists(): 768 if desc.exists():
694 try: 769 try:
695 output_feature = cfg["config"]["output_features"][0]["name"] 770 output_feature = cfg["config"]["output_features"][0]["name"]
698 if not output_feature and test_stats: 773 if not output_feature and test_stats:
699 with open(test_stats, "r") as f: 774 with open(test_stats, "r") as f:
700 stats = json.load(f) 775 stats = json.load(f)
701 output_feature = next(iter(stats.keys()), "") 776 output_feature = next(iter(stats.keys()), "")
702 777
778 probs_path = None
779 prob_candidates = [
780 exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv",
781 exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None,
782 exp_dir / "probabilities.csv",
783 exp_dir / "predictions.csv",
784 exp_dir / PREDICTIONS_PARQUET_FILE_NAME,
785 ]
786 for cand in prob_candidates:
787 if cand and Path(cand).exists():
788 probs_path = str(cand)
789 break
790
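The loop above simply takes the first candidate file that exists; an equivalent compact form with pathlib, using hypothetical file names, would be:

from pathlib import Path

exp_dir = Path("experiment_run_0")  # hypothetical experiment directory
candidates = [
    exp_dir / "label_probabilities.csv",
    exp_dir / "probabilities.csv",
    exp_dir / "predictions.csv",
    exp_dir / "predictions.parquet",
]
probs_path = next((str(p) for p in candidates if p.exists()), None)
print(probs_path)  # None when no probability/prediction file has been written yet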
703 viz_registry = get_visualizations_registry() 791 viz_registry = get_visualizations_registry()
792 if not viz_registry:
793 logger.warning(
794 "Ludwig visualizations registry not available; train/test PNGs will be skipped."
795 )
796 return
797
798 base_kwargs = {
799 "training_statistics": [training_stats] if training_stats else [],
800 "test_statistics": [test_stats] if test_stats else [],
801 "probabilities": [probs_path] if probs_path else [],
802 "output_feature_name": output_feature,
803 "ground_truth_split": 2,
804 "top_n_classes": [20],
805 "top_k": 3,
806 "metrics": ["f1", "precision", "recall", "accuracy"],
807 "positive_label": 0,
808 "ground_truth_metadata": gt_metadata,
809 "ground_truth": dataset_path,
810 "split_file": split_file,
811 "output_directory": None, # set per plot below
812 "normalize": False,
813 "file_format": "png",
814 "model_names": [model_name],
815 }
704 for viz_name, viz_func in viz_registry.items(): 816 for viz_name, viz_func in viz_registry.items():
705 if viz_name in train_plots: 817 if viz_name in train_plots:
706 viz_dir_plot = train_viz 818 viz_dir_plot = train_viz
707 elif viz_name in test_plots: 819 elif viz_name in test_plots:
708 viz_dir_plot = test_viz 820 viz_dir_plot = test_viz
709 else: 821 else:
710 continue 822 continue
711 823
712 try: 824 try:
825 # Build per-viz kwargs based on the function signature to avoid unexpected args
826 sig_params = set(inspect.signature(viz_func).parameters.keys())
827 call_kwargs = {
828 k: v
829 for k, v in base_kwargs.items()
830 if k in sig_params and v is not None
831 }
832 if "output_directory" in sig_params:
833 call_kwargs["output_directory"] = str(viz_dir_plot)
834
713 viz_func( 835 viz_func(
714 training_statistics=[training_stats] if training_stats else [], 836 **call_kwargs,
715 test_statistics=[test_stats] if test_stats else [],
716 probabilities=[probs_path] if probs_path else [],
717 output_feature_name=output_feature,
718 ground_truth_split=2,
719 top_n_classes=[0],
720 top_k=3,
721 ground_truth_metadata=gt_metadata,
722 ground_truth=dataset_path,
723 split_file=split_file,
724 output_directory=str(viz_dir_plot),
725 normalize=False,
726 file_format="png",
727 ) 837 )
728 logger.info(f"✔ Generated {viz_name}") 838 logger.info(f"✔ Generated {viz_name}")
729 except Exception as e: 839 except Exception as e:
730 logger.warning(f"✘ Skipped {viz_name}: {e}") 840 logger.warning(f"✘ Skipped {viz_name}: {e}")
731
732 logger.info(f"All visualizations written to {viz_dir}") 841 logger.info(f"All visualizations written to {viz_dir}")
733 842
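The per-visualization kwargs filtering above is a general pattern: keep one shared argument dictionary and pass each callable only the parameters its signature declares (the backend additionally drops None values). A standalone sketch of that pattern, where plot_a and plot_b are hypothetical stand-ins for registry entries:

import inspect

def call_with_supported_kwargs(func, **kwargs):
    """Pass only the keyword arguments that appear in func's signature."""
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return func(**kwargs)  # func accepts **kwargs, so pass everything through
    return func(**{k: v for k, v in kwargs.items() if k in params})

def plot_a(training_statistics, output_directory="plots"):
    return ("A", training_statistics, output_directory)

def plot_b(test_statistics, normalize=False):
    return ("B", test_statistics, normalize)

shared = {"training_statistics": ["train.json"], "test_statistics": ["test.json"], "normalize": True}
print(call_with_supported_kwargs(plot_a, **shared))  # ('A', ['train.json'], 'plots')
print(call_with_supported_kwargs(plot_b, **shared))  # ('B', ['test.json'], True)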
734 def generate_html_report( 843 def generate_html_report(
735 self, 844 self,
736 title: str, 845 title: str,
754 exp_dir = exp_dirs[-1] 863 exp_dir = exp_dirs[-1]
755 train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME 864 train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
756 label_metadata_path = config.get("label_column_data_path") 865 label_metadata_path = config.get("label_column_data_path")
757 if label_metadata_path: 866 if label_metadata_path:
758 label_metadata_path = Path(label_metadata_path) 867 label_metadata_path = Path(label_metadata_path)
868 dataset_path_from_desc: Optional[Path] = None
759 869
760 # Pull additional config details from description.json if available 870 # Pull additional config details from description.json if available
761 config_for_summary = dict(config) 871 config_for_summary = dict(config)
762 if "target_column" not in config_for_summary or not config_for_summary.get("target_column"): 872 if "target_column" not in config_for_summary or not config_for_summary.get("target_column"):
763 config_for_summary["target_column"] = LABEL_COLUMN_NAME 873 config_for_summary["target_column"] = LABEL_COLUMN_NAME
764 desc_path = exp_dir / DESCRIPTION_FILE_NAME 874 desc_path = exp_dir / DESCRIPTION_FILE_NAME
765 if desc_path.exists(): 875 if desc_path.exists():
766 try: 876 try:
767 with open(desc_path, "r") as f: 877 with open(desc_path, "r") as f:
768 desc_cfg = json.load(f).get("config", {}) 878 desc_json = json.load(f)
879 desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {}
769 encoder_cfg = ( 880 encoder_cfg = (
770 desc_cfg.get("input_features", [{}])[0].get("encoder", {}) 881 desc_cfg.get("input_features", [{}])[0].get("encoder", {})
771 if isinstance(desc_cfg.get("input_features", [{}]), list) 882 if isinstance(desc_cfg.get("input_features", [{}]), list)
772 else {} 883 else {}
773 ) 884 )
781 opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {} 892 opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {}
782 clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {} 893 clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {}
783 894
784 arch_type = encoder_cfg.get("type") 895 arch_type = encoder_cfg.get("type")
785 arch_variant = encoder_cfg.get("model_variant") 896 arch_variant = encoder_cfg.get("model_variant")
897 arch_custom = encoder_cfg.get("custom_model")
786 arch_name = None 898 arch_name = None
899 if arch_custom:
900 arch_name = str(arch_custom)
787 if arch_type: 901 if arch_type:
788 arch_base = str(arch_type).replace("_", " ").title() 902 arch_base = str(arch_type).replace("_", " ").title()
789 arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base 903 arch_type_name = (
904 f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
905 )
906 # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type
907 arch_name = arch_name or arch_type_name
908 if not arch_name and config.get("model_name"):
909 # As a last resort, show the user-selected model name (handles custom/MetaFormer cases)
910 arch_name = str(config.get("model_name"))
790 911
791 summary_fields = { 912 summary_fields = {
792 "architecture": arch_name, 913 "architecture": arch_name,
793 "model_variant": arch_variant, 914 "model_variant": arch_variant,
794 "pretrained": encoder_cfg.get("use_pretrained"), 915 "pretrained": encoder_cfg.get("use_pretrained"),
812 continue 933 continue
813 # Do not override user-passed target/image column names in config 934 # Do not override user-passed target/image column names in config
814 if k in {"target_column", "image_column"} and config_for_summary.get(k): 935 if k in {"target_column", "image_column"} and config_for_summary.get(k):
815 continue 936 continue
816 config_for_summary.setdefault(k, v) 937 config_for_summary.setdefault(k, v)
938
939 dataset_field = None
940 if isinstance(desc_json, dict):
941 dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset")
942 if dataset_field:
943 try:
944 dataset_path_from_desc = Path(dataset_field)
945 except TypeError:
946 dataset_path_from_desc = None
947 if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()):
948 label_metadata_path = dataset_path_from_desc
817 except Exception as e: # pragma: no cover - defensive 949 except Exception as e: # pragma: no cover - defensive
818 logger.warning(f"Could not merge description.json into config summary: {e}") 950 logger.warning(f"Could not merge description.json into config summary: {e}")
819 951
820 base_viz_dir = exp_dir / "visualizations" 952 base_viz_dir = exp_dir / "visualizations"
821 train_viz_dir = base_viz_dir / "train" 953 train_viz_dir = base_viz_dir / "train"
822 test_viz_dir = base_viz_dir / "test"
823 954
824 html = get_html_template() 955 html = get_html_template()
825 956
826 # Extra CSS & JS: center Plotly and enable CSV download for predictions table 957 # Extra CSS & JS: center Plotly and enable CSV download for predictions table
827 html += """ 958 html += """
878 document.body.removeChild(a); 1009 document.body.removeChild(a);
879 URL.revokeObjectURL(url); 1010 URL.revokeObjectURL(url);
880 }); 1011 });
881 } 1012 }
882 }); 1013 });
883 </script> 1014 </script>
884 """ 1015 """
885 html += f"<h1>{title}</h1>" 1016 html += f"<h1>{title}</h1>"
1017
1018 def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str:
1019 """Append Plotly blocks to a tab with consistent markup."""
1020 if not plots:
1021 return tab_html
1022 suffix = title_suffix or ""
1023 for plot in plots:
1024 tab_html += (
1025 f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>"
1026 f"<div class='plotly-center'>{plot['html']}</div>"
1027 )
1028 return tab_html
1029
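append_plot_blocks assumes each Plotly helper returns dicts with "title" and "html" keys; a small illustration of the expected shape and the markup it produces (the plot contents are placeholders):

plots = [
    {"title": "Learning Curves", "html": "<div id='lc'>…</div>"},
    {"title": "Confusion Matrix", "html": "<div id='cm'>…</div>"},
]
tab_html = ""
for plot in plots:  # same markup as append_plot_blocks, with a "(Validation)" suffix
    tab_html += (
        f"<h2 style='text-align: center;'>{plot['title']} (Validation)</h2>"
        f"<div class='plotly-center'>{plot['html']}</div>"
    )
print(tab_html)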
1030 def build_dataset_overview(
1031 label_metadata: Optional[Path],
1032 output_type: Optional[str],
1033 split_probabilities: Optional[List[float]],
1034 label_split_counts: Optional[List[Dict[str, int]]] = None,
1035 split_counts: Optional[Dict[int, int]] = None,
1036 fallback_dataset: Optional[Path] = None,
1037 ) -> str:
1038 """Summarize dataset distribution across splits using the actual split config."""
1039 if label_split_counts:
1040 # Use the actual counts captured during data prep instead of heuristics.
1041 return format_dataset_overview_table(label_split_counts, regression_mode=False)
1042
1043 if output_type == "regression" and split_counts:
1044 rows = [
1045 {"split": "train", "count": int(split_counts.get(0, 0))},
1046 {"split": "validation", "count": int(split_counts.get(1, 0))},
1047 {"split": "test", "count": int(split_counts.get(2, 0))},
1048 ]
1049 return format_dataset_overview_table(rows, regression_mode=True)
1050
1051 candidate_paths: List[Path] = []
1052 if label_metadata and label_metadata.exists():
1053 candidate_paths.append(label_metadata)
1054 if fallback_dataset and fallback_dataset.exists():
1055 candidate_paths.append(fallback_dataset)
1056 if not candidate_paths:
1057 return format_dataset_overview_table([])
1058
1059 def _normalize_split_probabilities(
1060 probs: Optional[List[float]],
1061 ) -> Optional[List[float]]:
1062 if not probs or len(probs) != 3:
1063 return None
1064 try:
1065 probs = [float(p) for p in probs]
1066 except (TypeError, ValueError):
1067 return None
1068 total = sum(probs)
1069 if total <= 0:
1070 return None
1071 return [p / total for p in probs]
1072
1073 def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]:
1074 if SPLIT_COLUMN_NAME not in df.columns:
1075 return {}
1076 split_series = pd.to_numeric(
1077 df[SPLIT_COLUMN_NAME], errors="coerce"
1078 ).dropna()
1079 if split_series.empty:
1080 return {}
1081 split_series = split_series.astype(int)
1082 return split_series.value_counts().to_dict()
1083
1084 def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
1085 train_n = int(total * probs[0])
1086 val_n = int(total * probs[1])
1087 test_n = max(0, total - train_n - val_n)
1088 return {0: train_n, 1: val_n, 2: test_n}
1089
1090 fallback_rows: Optional[List[Dict[str, int]]] = None
1091 for meta_path in candidate_paths:
1092 try:
1093 df_labels = pd.read_csv(meta_path)
1094 probs = _normalize_split_probabilities(split_probabilities)
1095
1096 # Regression (or missing label column): only need split counts
1097 if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns:
1098 split_counts_found = _split_counts_from_column(df_labels)
1099 if split_counts_found:
1100 rows = [
1101 {"split": "train", "count": int(split_counts_found.get(0, 0))},
1102 {"split": "validation", "count": int(split_counts_found.get(1, 0))},
1103 {"split": "test", "count": int(split_counts_found.get(2, 0))},
1104 ]
1105 return format_dataset_overview_table(rows, regression_mode=True)
1106 if probs and fallback_rows is None:
1107 split_counts_found = _split_counts_from_probs(len(df_labels), probs)
1108 fallback_rows = [
1109 {"split": "train", "count": int(split_counts_found.get(0, 0))},
1110 {"split": "validation", "count": int(split_counts_found.get(1, 0))},
1111 {"split": "test", "count": int(split_counts_found.get(2, 0))},
1112 ]
1113 continue
1114
1115 # Classification: prefer actual split assignments; fall back to configured probabilities
1116 if SPLIT_COLUMN_NAME in df_labels.columns:
1117 df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy()
1118 df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric(
1119 df_counts[SPLIT_COLUMN_NAME], errors="coerce"
1120 )
1121 df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME])
1122 if df_counts.empty:
1123 continue
1124
1125 df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int)
1126 df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME])
1127 counts = (
1128 df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
1129 .size()
1130 .unstack(fill_value=0)
1131 .sort_index()
1132 )
1133 rows = []
1134 for lbl, row in counts.iterrows():
1135 rows.append(
1136 {
1137 "label": str(lbl),
1138 "train": int(row.get(0, 0)),
1139 "validation": int(row.get(1, 0)),
1140 "test": int(row.get(2, 0)),
1141 }
1142 )
1143 return format_dataset_overview_table(rows)
1144
1145 if probs:
1146 label_series = df_labels[LABEL_COLUMN_NAME].dropna()
1147 label_counts = label_series.value_counts().sort_index()
1148 if label_counts.empty:
1149 continue
1150 rows = []
1151 for lbl, count in label_counts.items():
1152 train_n = int(count * probs[0])
1153 val_n = int(count * probs[1])
1154 test_n = max(0, count - train_n - val_n)
1155 rows.append(
1156 {
1157 "label": str(lbl),
1158 "train": train_n,
1159 "validation": val_n,
1160 "test": test_n,
1161 }
1162 )
1163 fallback_rows = fallback_rows or rows
1164 except Exception as exc:
1165 logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc)
1166 continue
1167
1168 if fallback_rows:
1169 return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression")
1170 return format_dataset_overview_table([])
886 1171
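When neither recorded counts nor a split column are available, the overview estimates counts from the configured split probabilities; the arithmetic in isolation (the totals and probabilities are illustrative):

split_probabilities = [0.7, 0.1, 0.2]  # hypothetical train/validation/test fractions
total = sum(split_probabilities)
probs = [p / total for p in split_probabilities]  # normalize, as _normalize_split_probabilities does

n_samples = 1000
train_n = int(n_samples * probs[0])           # 700
val_n = int(n_samples * probs[1])             # 100
test_n = max(0, n_samples - train_n - val_n)  # 200; the remainder absorbs rounding
print({0: train_n, 1: val_n, 2: test_n})      # split codes: 0=train, 1=validation, 2=test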
887 metrics_html = "" 1172 metrics_html = ""
888 train_val_metrics_html = "" 1173 train_val_metrics_html = ""
889 test_metrics_html = "" 1174 test_metrics_html = ""
890 output_type = None 1175 output_type = None
909 except Exception as e: 1194 except Exception as e:
910 logger.warning( 1195 logger.warning(
911 f"Could not load stats for HTML report: {type(e).__name__}: {e}" 1196 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
912 ) 1197 )
913 1198
1199 if not output_type:
1200 # Fall back to the configured task type when stats are unavailable (e.g., a failed run).
1201 output_type = (
1202 str(config_for_summary.get("task_type")).lower()
1203 if config_for_summary.get("task_type")
1204 else None
1205 )
1206
1207 dataset_overview_html = build_dataset_overview(
1208 label_metadata_path,
1209 output_type,
1210 config.get("split_probabilities"),
1211 config.get("label_split_counts"),
1212 config.get("split_counts"),
1213 dataset_path_from_desc,
1214 )
1215
914 config_html = "" 1216 config_html = ""
915 training_progress = self.get_training_process(output_dir) 1217 training_progress = self.get_training_process(output_dir)
916 try: 1218 try:
917 config_html = format_config_table_html( 1219 config_html = format_config_table_html(
918 config_for_summary, split_info, training_progress, output_type 1220 config_for_summary, split_info, training_progress, output_type
935 dir_path: Path, 1237 dir_path: Path,
936 output_type: str = None, 1238 output_type: str = None,
937 exclude_names: Optional[set] = None, 1239 exclude_names: Optional[set] = None,
938 ) -> str: 1240 ) -> str:
939 if not dir_path.exists(): 1241 if not dir_path.exists():
940 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" 1242 return ""
941 1243
942 exclude_names = exclude_names or set() 1244 exclude_names = exclude_names or set()
943 1245
944 imgs = list(dir_path.glob("*.png")) 1246 # Search recursively because Ludwig can nest figures under per-feature folders
1247 imgs = list(dir_path.rglob("*.png"))
945 1248
946 # Exclude ROC curves and standard confusion matrices (keep only entropy version) 1249 # Exclude ROC curves and standard confusion matrices (keep only entropy version)
947 default_exclude = { 1250 default_exclude = {
948 # "roc_curves.png", # Remove ROC curves from test tab 1251 # "roc_curves.png", # Remove ROC curves from test tab
949 "confusion_matrix__label_top5.png", # Remove standard confusion matrix 1252 "confusion_matrix__label_top5.png", # Remove standard confusion matrix
981 and "label" in img.stem 1284 and "label" in img.stem
982 ) 1285 )
983 ] 1286 ]
984 1287
985 if not imgs: 1288 if not imgs:
986 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" 1289 return ""
987 1290
988 # Sort images by name for consistent ordering (works with string and numeric labels) 1291 # Sort images by name for consistent ordering (works with string and numeric labels)
989 imgs = sorted(imgs, key=lambda x: x.name) 1292 imgs = sorted(imgs, key=lambda x: x.name)
990 1293
991 html_section = "" 1294 html_section = ""
1004 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' 1307 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
1005 f"</div>" 1308 f"</div>"
1006 ) 1309 )
1007 return html_section 1310 return html_section
1008 1311
1009 # Show performance first, then config 1312 # Show the dataset overview first, then performance metrics, then the config table
1010 tab1_content = metrics_html + config_html 1313 predictions_csv_path = exp_dir / "predictions.csv"
1011 1314
1012 tab2_content = train_val_metrics_html + render_img_section( 1315 tab1_content = dataset_overview_html + metrics_html + config_html
1013 "Training and Validation Visualizations", 1316
1014 train_viz_dir, 1317 tab2_content = train_val_metrics_html
1015 output_type, 1318 # Preload binary threshold plot so it appears first in Train/Val tab
1016 exclude_names={ 1319 threshold_plot = None
1017 "compare_classifiers_performance_from_prob.png", 1320 threshold_value = (
1018 "roc_curves_from_prediction_statistics.png", 1321 config_for_summary.get("threshold")
1019 "precision_recall_curves_from_prediction_statistics.png", 1322 if config_for_summary.get("threshold") is not None
1020 "precision_recall_curve.png", 1323 else config.get("threshold")
1021 },
1022 ) 1324 )
1325 if threshold_value is None and output_type == "binary":
1326 threshold_value = 0.5
1327 if output_type == "binary" and predictions_csv_path.exists():
1328 try:
1329 threshold_plot = build_binary_threshold_plot(
1330 str(predictions_csv_path),
1331 label_data_path=str(config.get("label_column_data_path"))
1332 if config.get("label_column_data_path")
1333 else None,
1334 split_value=1,
1335 )
1336 except Exception as e:
1337 logger.warning(f"Could not generate validation threshold plot: {e}")
1338
1023 if train_stats_path.exists(): 1339 if train_stats_path.exists():
1024 try: 1340 try:
1025 if output_type == "regression": 1341 if output_type == "regression":
1026 tv_plots = build_regression_train_val_plots(str(train_stats_path)) 1342 tv_plots = build_regression_train_val_plots(str(train_stats_path))
1343 tab2_content = append_plot_blocks(tab2_content, tv_plots)
1027 else: 1344 else:
1028 tv_plots = build_train_validation_plots(str(train_stats_path)) 1345 tv_plots = build_train_validation_plots(str(train_stats_path))
1029 for plot in tv_plots: 1346 # Add threshold plot first, then other train/val plots
1030 tab2_content += ( 1347 if threshold_plot:
1031 f"<h2 style='text-align: center;'>{plot['title']}</h2>" 1348 tab2_content = append_plot_blocks(tab2_content, [threshold_plot])
1032 f"<div class='plotly-center'>{plot['html']}</div>" 1349 tab2_content = append_plot_blocks(tab2_content, tv_plots)
1033 ) 1350 if threshold_plot or tv_plots:
1034 if tv_plots: 1351 logger.info(
1035 logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots") 1352 f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots"
1353 )
1354 # Mark the threshold plot as consumed so it is not re-added by later sections
1355 threshold_plot = None
1036 except Exception as e: 1356 except Exception as e:
1037 logger.warning(f"Could not generate train/val plots: {e}") 1357 logger.warning(f"Could not generate train/val plots: {e}")
1358
1359 # Only include training PNGs for regression; classification is handled by filtered Plotly plots
1360 if output_type == "regression":
1361 tab2_content += render_img_section(
1362 "Training and Validation Visualizations",
1363 train_viz_dir,
1364 output_type,
1365 exclude_names={
1366 "compare_classifiers_performance_from_prob.png",
1367 "roc_curves_from_prediction_statistics.png",
1368 "precision_recall_curves_from_prediction_statistics.png",
1369 "precision_recall_curve.png",
1370 },
1371 )
1372
1373 # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1
1374 if output_type in ("binary", "category") and predictions_csv_path.exists():
1375 try:
1376 val_diag_plots = build_prediction_diagnostics(
1377 str(predictions_csv_path),
1378 label_data_path=str(config.get("label_column_data_path"))
1379 if config.get("label_column_data_path")
1380 else None,
1381 split_value=1,
1382 )
1383 val_conf_plots = [
1384 p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
1385 ]
1386 tab2_content = append_plot_blocks(
1387 tab2_content, val_conf_plots, " (Validation)"
1388 )
1389 except Exception as e:
1390 logger.warning(f"Could not generate validation diagnostics: {e}")
1038 1391
1039 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- 1392 # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
1040 preds_section = "" 1393 preds_section = ""
1041 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME 1394 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
1042 if output_type == "regression" and parquet_path.exists(): 1395 if output_type == "regression" and parquet_path.exists():
1075 ) 1428 )
1076 except Exception as e: 1429 except Exception as e:
1077 logger.warning(f"Could not build Predictions vs GT table: {e}") 1430 logger.warning(f"Could not build Predictions vs GT table: {e}")
1078 1431
1079 tab3_content = test_metrics_html + preds_section 1432 tab3_content = test_metrics_html + preds_section
1080 test_plotly_added = False
1081 1433
1082 if output_type == "regression" and train_stats_path.exists(): 1434 if output_type == "regression" and train_stats_path.exists():
1083 try: 1435 try:
1084 test_plots = build_regression_test_plots(str(train_stats_path)) 1436 test_plots = build_regression_test_plots(str(train_stats_path))
1085 for plot in test_plots: 1437 tab3_content = append_plot_blocks(tab3_content, test_plots)
1086 tab3_content += (
1087 f"<h2 style='text-align: center;'>{plot['title']}</h2>"
1088 f"<div class='plotly-center'>{plot['html']}</div>"
1089 )
1090 if test_plots: 1438 if test_plots:
1091 test_plotly_added = True
1092 logger.info(f"Generated {len(test_plots)} regression test plots") 1439 logger.info(f"Generated {len(test_plots)} regression test plots")
1093 except Exception as e: 1440 except Exception as e:
1094 logger.warning(f"Could not generate regression test plots: {e}") 1441 logger.warning(f"Could not generate regression test plots: {e}")
1095 1442
1096 if output_type in ("binary", "category") and test_stats_path.exists(): 1443 if output_type in ("binary", "category") and test_stats_path.exists():
1102 if label_metadata_path and label_metadata_path.exists() 1449 if label_metadata_path and label_metadata_path.exists()
1103 else None, 1450 else None,
1104 train_set_metadata_path=str(train_set_metadata_path) 1451 train_set_metadata_path=str(train_set_metadata_path)
1105 if train_set_metadata_path.exists() 1452 if train_set_metadata_path.exists()
1106 else None, 1453 else None,
1107 ) 1454 threshold=threshold_value,
1108 for plot in interactive_plots: 1455 )
1109 tab3_content += ( 1456 tab3_content = append_plot_blocks(tab3_content, interactive_plots)
1110 f"<h2 style='text-align: center;'>{plot['title']}</h2>"
1111 f"<div class='plotly-center'>{plot['html']}</div>"
1112 )
1113 if interactive_plots: 1457 if interactive_plots:
1114 test_plotly_added = True 1458 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
1115 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
1116 except Exception as e: 1459 except Exception as e:
1117 logger.warning(f"Could not generate Plotly plots: {e}") 1460 logger.warning(f"Could not generate Plotly plots: {e}")
1118 1461
1119 # Add prediction diagnostics from predictions.csv 1462 # Multi-class transparency plots from test stats (replace ROC/PR for multi-class)
1120 predictions_csv_path = exp_dir / "predictions.csv" 1463 if output_type == "category" and test_stats_path.exists():
1121 try: 1464 try:
1122 diag_plots = build_prediction_diagnostics( 1465 multi_curves = build_multiclass_metric_plots(str(test_stats_path))
1123 str(predictions_csv_path), 1466 tab3_content = append_plot_blocks(tab3_content, multi_curves)
1124 label_data_path=str(config.get("label_column_data_path")) 1467 if multi_curves:
1125 if config.get("label_column_data_path") 1468 logger.info("Added multi-class per-class metric plots to test tab")
1126 else None, 1469 except Exception as e:
1127 threshold=config.get("threshold"), 1470 logger.warning(f"Could not generate multi-class metric plots: {e}")
1128 ) 1471
1129 for plot in diag_plots: 1472 # Test diagnostics (confidence histogram) from predictions.csv, using split=2
1130 tab3_content += ( 1473 if predictions_csv_path.exists():
1131 f"<h2 style='text-align: center;'>{plot['title']}</h2>" 1474 try:
1132 f"<div class='plotly-center'>{plot['html']}</div>" 1475 test_diag_plots = build_prediction_diagnostics(
1476 str(predictions_csv_path),
1477 label_data_path=str(config.get("label_column_data_path"))
1478 if config.get("label_column_data_path")
1479 else None,
1480 split_value=2,
1133 ) 1481 )
1134 if diag_plots: 1482 test_conf_plots = [
1135 test_plotly_added = True 1483 p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
1136 logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots") 1484 ]
1137 except Exception as e: 1485 if test_conf_plots:
1138 logger.warning(f"Could not generate prediction diagnostics: {e}") 1486 tab3_content = append_plot_blocks(tab3_content, test_conf_plots)
1139 1487 logger.info("Added test prediction confidence plot")
1140 # Fallback: include static PNGs if no interactive plots were added 1488 except Exception as e:
1141 if not test_plotly_added: 1489 logger.warning(f"Could not generate test diagnostics: {e}")
1142 tab3_content += render_img_section(
1143 "Test Visualizations (PNG fallback)",
1144 test_viz_dir,
1145 output_type,
1146 )
1147 1490
1148 # Add static TEST PNGs (with default dedupe/exclusions) 1491 # Add static TEST PNGs (with default dedupe/exclusions)
1149 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) 1492 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
1150 modal_html = get_metrics_help_modal() 1493 modal_html = get_metrics_help_modal()
1151 html += tabbed_html + modal_html + get_html_closing() 1494 html += tabbed_html + modal_html + get_html_closing()