changeset 17:db9be962dc13 (phase: draft, branch: default, tag: tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| Field | Value |
|---|---|
| author | goeckslab |
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | 8729f69e9207 |
| children | (none) |
| files | html_structure.py, image_learner.xml, image_workflow.py, ludwig_backend.py, plotly_plots.py |
| diffstat | 5 files changed, 1322 insertions(+), 475 deletions(-) |
--- a/html_structure.py Wed Dec 03 01:28:52 2025 +0000 +++ b/html_structure.py Wed Dec 10 00:24:13 2025 +0000 @@ -1,6 +1,6 @@ import base64 import json -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from constants import METRIC_DISPLAY_NAMES from utils import detect_output_type, extract_metrics_from_json @@ -23,6 +23,7 @@ ) -> str: display_keys = [ "architecture", + "image_size", "pretrained", "trainable", "target_column", @@ -58,6 +59,15 @@ else: if key == "task_type": val_str = val.title() if isinstance(val, str) else "N/A" + elif key == "image_size": + if val is None: + val_str = "N/A" + elif isinstance(val, (list, tuple)) and len(val) == 2: + val_str = f"{val[0]}x{val[1]}" + elif isinstance(val, str) and val.lower() == "original": + val_str = "Original (no resize)" + else: + val_str = str(val) elif key == "batch_size": if isinstance(val, (int, float)): val_str = int(val) @@ -115,6 +125,11 @@ "Ludwig Trainer Parameters</a> for details." "</span>" ) + elif key == "validation_metric": + if val is not None: + val_str = METRIC_DISPLAY_NAMES.get(str(val), str(val)) + else: + val_str = "N/A" elif key == "epochs": if val is None: val_str = "N/A" @@ -729,6 +744,64 @@ ) return modal_html + modal_js + +def format_dataset_overview_table(rows: List[Dict[str, Any]], regression_mode: bool = False) -> str: + """Render a dataset overview table. + + - Classification: per-label distribution across train/val/test. + - Regression: split counts (train/val/test). + """ + heading = "<h2 style='text-align: center;'>Dataset Overview</h2>" + if not rows: + return heading + "<p style='text-align: center; color: #666;'>Dataset overview unavailable.</p><br>" + + if regression_mode: + headers = ["Split", "Count"] + html = ( + heading + + "<div style='display: flex; justify-content: center;'>" + + "<table class='performance-summary' style='border-collapse: collapse; table-layout: fixed;'>" + + "<thead><tr>" + + "".join( + f"<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>{h}</th>" + for h in headers + ) + + "</tr></thead><tbody>" + ) + for row in rows: + html += generate_table_row( + [ + row.get("split", "N/A"), + row.get("count", 0), + ], + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", + ) + html += "</tbody></table></div><br>" + else: + html = ( + heading + + "<div style='display: flex; justify-content: center;'>" + + "<table class='performance-summary' style='border-collapse: collapse; table-layout: fixed;'>" + + "<thead><tr>" + + "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Label</th>" + + "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" + + "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" + + "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" + + "</tr></thead><tbody>" + ) + for row in rows: + html += generate_table_row( + [ + row.get("label", "N/A"), + row.get("train", 0), + row.get("validation", 0), + row.get("test", 0), + ], + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", + ) + html += "</tbody></table></div><br>" + return html + # ----------------------------------------- # MODEL PERFORMANCE (Train/Val/Test) TABLE # -----------------------------------------
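For reference, a minimal sketch of how the new `format_dataset_overview_table` helper could be called, assuming the updated `html_structure.py` is importable; the row keys mirror the ones used in the diff above, and the counts are invented for illustration:

```python
from html_structure import format_dataset_overview_table

# Classification mode: one row per label with per-split counts.
classification_rows = [
    {"label": "cat", "train": 700, "validation": 100, "test": 200},
    {"label": "dog", "train": 650, "validation": 90, "test": 180},
]
classification_html = format_dataset_overview_table(classification_rows)

# Regression mode: one row per split with a plain count.
regression_rows = [
    {"split": "train", "count": 1400},
    {"split": "validation", "count": 200},
    {"split": "test", "count": 400},
]
regression_html = format_dataset_overview_table(regression_rows, regression_mode=True)
```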
```diff
--- a/image_learner.xml	Wed Dec 03 01:28:52 2025 +0000
+++ b/image_learner.xml	Wed Dec 10 00:24:13 2025 +0000
@@ -130,13 +130,9 @@
             </when>
             <when value="regression">
                 <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
-                    <option value="pearson_r" selected="true">Pearson r</option>
-                    <option value="mae">MAE</option>
+                    <option value="mae" selected="true">MAE</option>
                     <option value="mse">MSE</option>
                     <option value="rmse">RMSE</option>
-                    <option value="mape">MAPE</option>
-                    <option value="r2">R²</option>
-                    <option value="explained_variance">Explained Variance</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
```
--- a/image_workflow.py Wed Dec 03 01:28:52 2025 +0000 +++ b/image_workflow.py Wed Dec 10 00:24:13 2025 +0000 @@ -5,7 +5,7 @@ import tempfile import zipfile from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import pandas as pd import pandas.api.types as ptypes @@ -35,6 +35,8 @@ self.image_extract_dir: Optional[Path] = None self.label_metadata: Dict[str, Any] = {} self.output_type_hint: Optional[str] = None + self.label_split_counts: List[Dict[str, int]] = [] + self.split_counts: Dict[int, int] = {} logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}") def _create_temp_dirs(self) -> None: @@ -186,6 +188,34 @@ logger.error("Error saving prepared CSV", exc_info=True) raise + # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables) + try: + split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce") + split_series = split_series.dropna().astype(int) + self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()} + if LABEL_COLUMN_NAME in df.columns: + counts = ( + df.dropna(subset=[LABEL_COLUMN_NAME]) + .assign(**{SPLIT_COLUMN_NAME: split_series}) + .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]) + .size() + .unstack(fill_value=0) + .sort_index() + ) + self.label_split_counts = [ + { + "label": str(lbl), + "train": int(row.get(0, 0)), + "validation": int(row.get(1, 0)), + "test": int(row.get(2, 0)), + } + for lbl, row in counts.iterrows() + ] + except Exception: + logger.warning("Unable to capture split counts for reporting", exc_info=True) + self.label_split_counts = [] + self.split_counts = {} + self._capture_label_metadata(df) return final_csv, split_config, split_info @@ -349,6 +379,8 @@ "random_seed": self.args.random_seed, "early_stop": self.args.early_stop, "label_column_data_path": csv_path, + "label_split_counts": self.label_split_counts, + "split_counts": self.split_counts, "augmentation": self.args.augmentation, "image_resize": self.args.image_resize, "image_zip": self.args.image_zip,
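The split-count capture added to `image_workflow.py` boils down to a `value_counts` plus a `groupby`/`unstack` over the split column; a self-contained sketch of the same computation follows (the column names are stand-ins for the tool's `SPLIT_COLUMN_NAME` and `LABEL_COLUMN_NAME` constants, and the data is invented):

```python
import pandas as pd

SPLIT_COL, LABEL_COL = "split", "label"  # stand-ins for the tool's constants

df = pd.DataFrame(
    {
        LABEL_COL: ["cat", "cat", "dog", "dog", "dog", "cat"],
        SPLIT_COL: [0, 1, 0, 2, 0, 2],  # 0 = train, 1 = validation, 2 = test
    }
)

# Overall counts per split (mirrors self.split_counts).
split_series = pd.to_numeric(df[SPLIT_COL], errors="coerce").dropna().astype(int)
split_counts = {int(k): int(v) for k, v in split_series.value_counts().items()}

# Per-label counts per split (mirrors self.label_split_counts).
counts = (
    df.dropna(subset=[LABEL_COL])
    .groupby([LABEL_COL, SPLIT_COL])
    .size()
    .unstack(fill_value=0)
    .sort_index()
)
label_split_counts = [
    {
        "label": str(lbl),
        "train": int(row.get(0, 0)),
        "validation": int(row.get(1, 0)),
        "test": int(row.get(2, 0)),
    }
    for lbl, row in counts.iterrows()
]

print(split_counts)        # {0: 3, 2: 2, 1: 1}
print(label_split_counts)  # per-label train/validation/test rows
```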
--- a/ludwig_backend.py Wed Dec 03 01:28:52 2025 +0000 +++ b/ludwig_backend.py Wed Dec 10 00:24:13 2025 +0000 @@ -1,8 +1,9 @@ +import inspect import json import logging import os from pathlib import Path -from typing import Any, Dict, Optional, Protocol, Tuple +from typing import Any, Dict, List, Optional, Protocol, Tuple import pandas as pd import pandas.api.types as ptypes @@ -17,6 +18,7 @@ build_tabbed_html, encode_image_to_base64, format_config_table_html, + format_dataset_overview_table, format_stats_table_html, format_test_merged_stats_table_html, format_train_val_stats_table_html, @@ -33,7 +35,9 @@ from ludwig.utils.data_utils import get_split_path from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS from plotly_plots import ( + build_binary_threshold_plot, build_classification_plots, + build_multiclass_metric_plots, build_prediction_diagnostics, build_regression_test_plots, build_regression_train_val_plots, @@ -267,6 +271,23 @@ else: encoder_config = {"type": raw_encoder} + # Set a human-friendly architecture string for reporting + arch_display = None + if is_metaformer and custom_model: + arch_display = str(custom_model) + elif isinstance(raw_encoder, dict): + enc_type = raw_encoder.get("type") + enc_variant = raw_encoder.get("model_variant") + if enc_type: + base = str(enc_type).replace("_", " ").title() + arch_display = f"{base} {enc_variant}" if enc_variant is not None else base + else: + arch_display = str(raw_encoder).replace("_", " ").title() + + if not arch_display: + arch_display = str(model_name) + config_params["architecture"] = arch_display + batch_size_cfg = batch_size or "auto" label_column_path = config_params.get("label_column_data_path") @@ -343,6 +364,7 @@ # Force Ludwig to respect our dimensions by setting additional parameters image_feat["preprocessing"]["requires_equal_dimensions"] = False logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") + config_params["image_size"] = f"{height}x{width}" # Now set the encoder configuration image_feat["encoder"] = encoder_config @@ -374,8 +396,12 @@ image_feat["preprocessing"]["infer_image_max_height"] = height image_feat["preprocessing"]["infer_image_max_width"] = width logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") + config_params["image_size"] = f"{height}x{width}" except (ValueError, IndexError): logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") + elif not is_metaformer: + # No explicit resize provided; keep for reporting purposes + config_params.setdefault("image_size", "original") def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: """Pick a validation metric that Ludwig will accept for the resolved task.""" @@ -471,6 +497,9 @@ config_params.get("validation_metric"), ) + # Propagate the resolved validation metric (including any task-based fallback or alias normalization) + config_params["validation_metric"] = val_metric + conf: Dict[str, Any] = { "model_type": "ecd", "input_features": [image_feat], @@ -641,18 +670,62 @@ except Exception as e: logger.error(f"Error converting Parquet to CSV: {e}") - def generate_plots(self, output_dir: Path) -> None: - """Generate all registered Ludwig visualizations for the latest experiment run.""" - logger.info("Generating all Ludwig visualizations…") + @staticmethod + def 
_extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]: + """Pull the first numeric metric list we can find for the requested split.""" + if not isinstance(stats, dict): + return None, None + + split_stats = stats.get(split, {}) + ordered_metrics: List[Tuple[str, List[float]]] = [] + + def _append_metrics(metric_map: Dict[str, Any]) -> None: + for metric_name, values in metric_map.items(): + if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values): + ordered_metrics.append((metric_name, values)) - # Keep only lightweight plots (drop compare_performance/roc_curves) - test_plots = { - "roc_curves_from_test_statistics", - "confusion_matrix", - } + if isinstance(split_stats, dict): + combined = split_stats.get("combined") + if isinstance(combined, dict): + _append_metrics(combined) + + for feature_name, feature_metrics in split_stats.items(): + if feature_name == "combined" or not isinstance(feature_metrics, dict): + continue + _append_metrics(feature_metrics) + + if prefer: + for metric_name, values in ordered_metrics: + if metric_name == prefer: + return metric_name, values + + return ordered_metrics[0] if ordered_metrics else (None, None) + + def generate_plots(self, output_dir: Path) -> None: + """Generate Ludwig visualizations (train/val + test) for the latest experiment run.""" + logger.info("Generating Ludwig visualizations (train/val + test)…") + + # Train/validation visualizations train_plots = { "learning_curves", - "compare_classifiers_performance_subset", + } + + # Test visualizations (multi-class transparency) + test_plots = { + "confusion_matrix", + "compare_performance", + "compare_classifiers_multiclass_multimetric", + "frequency_vs_f1", + "confidence_thresholding", + "confidence_thresholding_data_vs_acc", + "confidence_thresholding_data_vs_acc_subset", + "confidence_thresholding_data_vs_acc_subset_per_class", + # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere + "binary_threshold_vs_metric", + "roc_curves", + "precision_recall_curves", + "calibration_1_vs_all", + "calibration_multiclass", } output_dir = Path(output_dir) @@ -677,7 +750,6 @@ training_stats = _check(exp_dir / "training_statistics.json") test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) - probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) dataset_path = None @@ -688,6 +760,9 @@ cfg = json.load(f) dataset_path = _check(Path(cfg.get("dataset", ""))) split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) + model_name = cfg.get("model_name", "model") + else: + model_name = "model" output_feature = "" if desc.exists(): @@ -700,7 +775,44 @@ stats = json.load(f) output_feature = next(iter(stats.keys()), "") + probs_path = None + prob_candidates = [ + exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv", + exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None, + exp_dir / "probabilities.csv", + exp_dir / "predictions.csv", + exp_dir / PREDICTIONS_PARQUET_FILE_NAME, + ] + for cand in prob_candidates: + if cand and Path(cand).exists(): + probs_path = str(cand) + break + viz_registry = get_visualizations_registry() + if not viz_registry: + logger.warning( + "Ludwig visualizations registry not available; train/test PNGs will be skipped." 
+ ) + return + + base_kwargs = { + "training_statistics": [training_stats] if training_stats else [], + "test_statistics": [test_stats] if test_stats else [], + "probabilities": [probs_path] if probs_path else [], + "output_feature_name": output_feature, + "ground_truth_split": 2, + "top_n_classes": [20], + "top_k": 3, + "metrics": ["f1", "precision", "recall", "accuracy"], + "positive_label": 0, + "ground_truth_metadata": gt_metadata, + "ground_truth": dataset_path, + "split_file": split_file, + "output_directory": None, # set per plot below + "normalize": False, + "file_format": "png", + "model_names": [model_name], + } for viz_name, viz_func in viz_registry.items(): if viz_name in train_plots: viz_dir_plot = train_viz @@ -710,25 +822,22 @@ continue try: + # Build per-viz kwargs based on the function signature to avoid unexpected args + sig_params = set(inspect.signature(viz_func).parameters.keys()) + call_kwargs = { + k: v + for k, v in base_kwargs.items() + if k in sig_params and v is not None + } + if "output_directory" in sig_params: + call_kwargs["output_directory"] = str(viz_dir_plot) + viz_func( - training_statistics=[training_stats] if training_stats else [], - test_statistics=[test_stats] if test_stats else [], - probabilities=[probs_path] if probs_path else [], - output_feature_name=output_feature, - ground_truth_split=2, - top_n_classes=[0], - top_k=3, - ground_truth_metadata=gt_metadata, - ground_truth=dataset_path, - split_file=split_file, - output_directory=str(viz_dir_plot), - normalize=False, - file_format="png", + **call_kwargs, ) logger.info(f"✔ Generated {viz_name}") except Exception as e: logger.warning(f"✘ Skipped {viz_name}: {e}") - logger.info(f"All visualizations written to {viz_dir}") def generate_html_report( @@ -756,6 +865,7 @@ label_metadata_path = config.get("label_column_data_path") if label_metadata_path: label_metadata_path = Path(label_metadata_path) + dataset_path_from_desc: Optional[Path] = None # Pull additional config details from description.json if available config_for_summary = dict(config) @@ -765,7 +875,8 @@ if desc_path.exists(): try: with open(desc_path, "r") as f: - desc_cfg = json.load(f).get("config", {}) + desc_json = json.load(f) + desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {} encoder_cfg = ( desc_cfg.get("input_features", [{}])[0].get("encoder", {}) if isinstance(desc_cfg.get("input_features", [{}]), list) @@ -783,10 +894,20 @@ arch_type = encoder_cfg.get("type") arch_variant = encoder_cfg.get("model_variant") + arch_custom = encoder_cfg.get("custom_model") arch_name = None + if arch_custom: + arch_name = str(arch_custom) if arch_type: arch_base = str(arch_type).replace("_", " ").title() - arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base + arch_type_name = ( + f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base + ) + # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type + arch_name = arch_name or arch_type_name + if not arch_name and config.get("model_name"): + # As a last resort, show the user-selected model name (handles custom/MetaFormer cases) + arch_name = str(config.get("model_name")) summary_fields = { "architecture": arch_name, @@ -814,12 +935,22 @@ if k in {"target_column", "image_column"} and config_for_summary.get(k): continue config_for_summary.setdefault(k, v) + + dataset_field = None + if isinstance(desc_json, dict): + dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset") + if dataset_field: 
+ try: + dataset_path_from_desc = Path(dataset_field) + except TypeError: + dataset_path_from_desc = None + if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()): + label_metadata_path = dataset_path_from_desc except Exception as e: # pragma: no cover - defensive logger.warning(f"Could not merge description.json into config summary: {e}") base_viz_dir = exp_dir / "visualizations" train_viz_dir = base_viz_dir / "train" - test_viz_dir = base_viz_dir / "test" html = get_html_template() @@ -880,10 +1011,164 @@ }); } }); -</script> + </script> """ html += f"<h1>{title}</h1>" + def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str: + """Append Plotly blocks to a tab with consistent markup.""" + if not plots: + return tab_html + suffix = title_suffix or "" + for plot in plots: + tab_html += ( + f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>" + f"<div class='plotly-center'>{plot['html']}</div>" + ) + return tab_html + + def build_dataset_overview( + label_metadata: Optional[Path], + output_type: Optional[str], + split_probabilities: Optional[List[float]], + label_split_counts: Optional[List[Dict[str, int]]] = None, + split_counts: Optional[Dict[int, int]] = None, + fallback_dataset: Optional[Path] = None, + ) -> str: + """Summarize dataset distribution across splits using the actual split config.""" + if label_split_counts: + # Use the actual counts captured during data prep instead of heuristics. + return format_dataset_overview_table(label_split_counts, regression_mode=False) + + if output_type == "regression" and split_counts: + rows = [ + {"split": "train", "count": int(split_counts.get(0, 0))}, + {"split": "validation", "count": int(split_counts.get(1, 0))}, + {"split": "test", "count": int(split_counts.get(2, 0))}, + ] + return format_dataset_overview_table(rows, regression_mode=True) + + candidate_paths: List[Path] = [] + if label_metadata and label_metadata.exists(): + candidate_paths.append(label_metadata) + if fallback_dataset and fallback_dataset.exists(): + candidate_paths.append(fallback_dataset) + if not candidate_paths: + return format_dataset_overview_table([]) + + def _normalize_split_probabilities( + probs: Optional[List[float]], + ) -> Optional[List[float]]: + if not probs or len(probs) != 3: + return None + try: + probs = [float(p) for p in probs] + except (TypeError, ValueError): + return None + total = sum(probs) + if total <= 0: + return None + return [p / total for p in probs] + + def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]: + if SPLIT_COLUMN_NAME not in df.columns: + return {} + split_series = pd.to_numeric( + df[SPLIT_COLUMN_NAME], errors="coerce" + ).dropna() + if split_series.empty: + return {} + split_series = split_series.astype(int) + return split_series.value_counts().to_dict() + + def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]: + train_n = int(total * probs[0]) + val_n = int(total * probs[1]) + test_n = max(0, total - train_n - val_n) + return {0: train_n, 1: val_n, 2: test_n} + + fallback_rows: Optional[List[Dict[str, int]]] = None + for meta_path in candidate_paths: + try: + df_labels = pd.read_csv(meta_path) + probs = _normalize_split_probabilities(split_probabilities) + + # Regression (or missing label column): only need split counts + if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns: + split_counts_found = _split_counts_from_column(df_labels) + if split_counts_found: + rows = [ + 
{"split": "train", "count": int(split_counts_found.get(0, 0))}, + {"split": "validation", "count": int(split_counts_found.get(1, 0))}, + {"split": "test", "count": int(split_counts_found.get(2, 0))}, + ] + return format_dataset_overview_table(rows, regression_mode=True) + if probs and fallback_rows is None: + split_counts_found = _split_counts_from_probs(len(df_labels), probs) + fallback_rows = [ + {"split": "train", "count": int(split_counts_found.get(0, 0))}, + {"split": "validation", "count": int(split_counts_found.get(1, 0))}, + {"split": "test", "count": int(split_counts_found.get(2, 0))}, + ] + continue + + # Classification: prefer actual split assignments; fall back to configured probabilities + if SPLIT_COLUMN_NAME in df_labels.columns: + df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy() + df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric( + df_counts[SPLIT_COLUMN_NAME], errors="coerce" + ) + df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME]) + if df_counts.empty: + continue + + df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int) + df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME]) + counts = ( + df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]) + .size() + .unstack(fill_value=0) + .sort_index() + ) + rows = [] + for lbl, row in counts.iterrows(): + rows.append( + { + "label": str(lbl), + "train": int(row.get(0, 0)), + "validation": int(row.get(1, 0)), + "test": int(row.get(2, 0)), + } + ) + return format_dataset_overview_table(rows) + + if probs: + label_series = df_labels[LABEL_COLUMN_NAME].dropna() + label_counts = label_series.value_counts().sort_index() + if label_counts.empty: + continue + rows = [] + for lbl, count in label_counts.items(): + train_n = int(count * probs[0]) + val_n = int(count * probs[1]) + test_n = max(0, count - train_n - val_n) + rows.append( + { + "label": str(lbl), + "train": train_n, + "validation": val_n, + "test": test_n, + } + ) + fallback_rows = fallback_rows or rows + except Exception as exc: + logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc) + continue + + if fallback_rows: + return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression") + return format_dataset_overview_table([]) + metrics_html = "" train_val_metrics_html = "" test_metrics_html = "" @@ -911,6 +1196,23 @@ f"Could not load stats for HTML report: {type(e).__name__}: {e}" ) + if not output_type: + # Fallback to configured task type when stats are unavailable (e.g., failed run). 
+ output_type = ( + str(config_for_summary.get("task_type")).lower() + if config_for_summary.get("task_type") + else None + ) + + dataset_overview_html = build_dataset_overview( + label_metadata_path, + output_type, + config.get("split_probabilities"), + config.get("label_split_counts"), + config.get("split_counts"), + dataset_path_from_desc, + ) + config_html = "" training_progress = self.get_training_process(output_dir) try: @@ -937,11 +1239,12 @@ exclude_names: Optional[set] = None, ) -> str: if not dir_path.exists(): - return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" + return "" exclude_names = exclude_names or set() - imgs = list(dir_path.glob("*.png")) + # Search recursively because Ludwig can nest figures under per-feature folders + imgs = list(dir_path.rglob("*.png")) # Exclude ROC curves and standard confusion matrices (keep only entropy version) default_exclude = { @@ -983,7 +1286,7 @@ ] if not imgs: - return f"<h2>{title}</h2><p><em>No plots found.</em></p>" + return "" # Sort images by name for consistent ordering (works with string and numeric labels) imgs = sorted(imgs, key=lambda x: x.name) @@ -1006,36 +1309,86 @@ ) return html_section - # Show performance first, then config - tab1_content = metrics_html + config_html + # Show dataset overview, performance first, then config + predictions_csv_path = exp_dir / "predictions.csv" + + tab1_content = dataset_overview_html + metrics_html + config_html - tab2_content = train_val_metrics_html + render_img_section( - "Training and Validation Visualizations", - train_viz_dir, - output_type, - exclude_names={ - "compare_classifiers_performance_from_prob.png", - "roc_curves_from_prediction_statistics.png", - "precision_recall_curves_from_prediction_statistics.png", - "precision_recall_curve.png", - }, + tab2_content = train_val_metrics_html + # Preload binary threshold plot so it appears first in Train/Val tab + threshold_plot = None + threshold_value = ( + config_for_summary.get("threshold") + if config_for_summary.get("threshold") is not None + else config.get("threshold") ) + if threshold_value is None and output_type == "binary": + threshold_value = 0.5 + if output_type == "binary" and predictions_csv_path.exists(): + try: + threshold_plot = build_binary_threshold_plot( + str(predictions_csv_path), + label_data_path=str(config.get("label_column_data_path")) + if config.get("label_column_data_path") + else None, + split_value=1, + ) + except Exception as e: + logger.warning(f"Could not generate validation threshold plot: {e}") + if train_stats_path.exists(): try: if output_type == "regression": tv_plots = build_regression_train_val_plots(str(train_stats_path)) + tab2_content = append_plot_blocks(tab2_content, tv_plots) else: tv_plots = build_train_validation_plots(str(train_stats_path)) - for plot in tv_plots: - tab2_content += ( - f"<h2 style='text-align: center;'>{plot['title']}</h2>" - f"<div class='plotly-center'>{plot['html']}</div>" - ) - if tv_plots: - logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots") + # Add threshold plot first, then other train/val plots + if threshold_plot: + tab2_content = append_plot_blocks(tab2_content, [threshold_plot]) + # Only append once; avoid duplicates if added elsewhere + threshold_plot = None + tab2_content = append_plot_blocks(tab2_content, tv_plots) + if threshold_plot or tv_plots: + logger.info( + f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots" + ) except Exception as e: logger.warning(f"Could not generate train/val plots: 
{e}") + # Only include training PNGs for regression; classification is handled by filtered Plotly plots + if output_type == "regression": + tab2_content += render_img_section( + "Training and Validation Visualizations", + train_viz_dir, + output_type, + exclude_names={ + "compare_classifiers_performance_from_prob.png", + "roc_curves_from_prediction_statistics.png", + "precision_recall_curves_from_prediction_statistics.png", + "precision_recall_curve.png", + }, + ) + + # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1 + if output_type in ("binary", "category") and predictions_csv_path.exists(): + try: + val_diag_plots = build_prediction_diagnostics( + str(predictions_csv_path), + label_data_path=str(config.get("label_column_data_path")) + if config.get("label_column_data_path") + else None, + split_value=1, + ) + val_conf_plots = [ + p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "") + ] + tab2_content = append_plot_blocks( + tab2_content, val_conf_plots, " (Validation)" + ) + except Exception as e: + logger.warning(f"Could not generate validation diagnostics: {e}") + # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- preds_section = "" parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME @@ -1077,18 +1430,12 @@ logger.warning(f"Could not build Predictions vs GT table: {e}") tab3_content = test_metrics_html + preds_section - test_plotly_added = False if output_type == "regression" and train_stats_path.exists(): try: test_plots = build_regression_test_plots(str(train_stats_path)) - for plot in test_plots: - tab3_content += ( - f"<h2 style='text-align: center;'>{plot['title']}</h2>" - f"<div class='plotly-center'>{plot['html']}</div>" - ) + tab3_content = append_plot_blocks(tab3_content, test_plots) if test_plots: - test_plotly_added = True logger.info(f"Generated {len(test_plots)} regression test plots") except Exception as e: logger.warning(f"Could not generate regression test plots: {e}") @@ -1104,46 +1451,42 @@ train_set_metadata_path=str(train_set_metadata_path) if train_set_metadata_path.exists() else None, + threshold=threshold_value, ) - for plot in interactive_plots: - tab3_content += ( - f"<h2 style='text-align: center;'>{plot['title']}</h2>" - f"<div class='plotly-center'>{plot['html']}</div>" - ) + tab3_content = append_plot_blocks(tab3_content, interactive_plots) if interactive_plots: - test_plotly_added = True - logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") + logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") except Exception as e: logger.warning(f"Could not generate Plotly plots: {e}") - # Add prediction diagnostics from predictions.csv - predictions_csv_path = exp_dir / "predictions.csv" - try: - diag_plots = build_prediction_diagnostics( - str(predictions_csv_path), - label_data_path=str(config.get("label_column_data_path")) - if config.get("label_column_data_path") - else None, - threshold=config.get("threshold"), - ) - for plot in diag_plots: - tab3_content += ( - f"<h2 style='text-align: center;'>{plot['title']}</h2>" - f"<div class='plotly-center'>{plot['html']}</div>" + # Multi-class transparency plots from test stats (replace ROC/PR for multi-class) + if output_type == "category" and test_stats_path.exists(): + try: + multi_curves = build_multiclass_metric_plots(str(test_stats_path)) + tab3_content = append_plot_blocks(tab3_content, multi_curves) + if multi_curves: + logger.info("Added multi-class per-class metric plots to test 
tab") + except Exception as e: + logger.warning(f"Could not generate multi-class metric plots: {e}") + + # Test diagnostics (confidence histogram) from predictions.csv, using split=2 + if predictions_csv_path.exists(): + try: + test_diag_plots = build_prediction_diagnostics( + str(predictions_csv_path), + label_data_path=str(config.get("label_column_data_path")) + if config.get("label_column_data_path") + else None, + split_value=2, ) - if diag_plots: - test_plotly_added = True - logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots") - except Exception as e: - logger.warning(f"Could not generate prediction diagnostics: {e}") - - # Fallback: include static PNGs if no interactive plots were added - if not test_plotly_added: - tab3_content += render_img_section( - "Test Visualizations (PNG fallback)", - test_viz_dir, - output_type, - ) + test_conf_plots = [ + p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "") + ] + if test_conf_plots: + tab3_content = append_plot_blocks(tab3_content, test_conf_plots) + logger.info("Added test prediction confidence plot") + except Exception as e: + logger.warning(f"Could not generate test diagnostics: {e}") # Add static TEST PNGs (with default dedupe/exclusions) tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
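The `inspect.signature` filtering added to `generate_plots` is a general pattern for calling heterogeneous visualization functions without passing unexpected keyword arguments; a standalone sketch of the idea is shown below (`call_with_supported_kwargs` and `fake_viz` are illustrative names, not Ludwig APIs):

```python
import inspect
from typing import Any, Callable, Dict


def call_with_supported_kwargs(func: Callable, base_kwargs: Dict[str, Any], output_directory: str):
    """Forward only the keyword arguments that `func` actually declares."""
    sig_params = set(inspect.signature(func).parameters.keys())
    call_kwargs = {k: v for k, v in base_kwargs.items() if k in sig_params and v is not None}
    if "output_directory" in sig_params:
        call_kwargs["output_directory"] = output_directory
    return func(**call_kwargs)


# Placeholder visualization function, for illustration only.
def fake_viz(training_statistics=None, output_directory=None, file_format="png"):
    return training_statistics, output_directory, file_format


print(
    call_with_supported_kwargs(
        fake_viz,
        {"training_statistics": ["training_statistics.json"], "top_k": 3, "file_format": "png"},
        output_directory="visualizations/train",
    )
)
```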
--- a/plotly_plots.py Wed Dec 03 01:28:52 2025 +0000 +++ b/plotly_plots.py Wed Dec 10 00:24:13 2025 +0000 @@ -7,6 +7,17 @@ import plotly.graph_objects as go import plotly.io as pio from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME +from sklearn.metrics import ( + accuracy_score, + auc, + average_precision_score, + f1_score, + precision_recall_curve, + precision_score, + recall_score, + roc_curve, +) +from sklearn.preprocessing import label_binarize def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: @@ -21,6 +32,64 @@ return fig +def _fig_to_html( + fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None +) -> str: + """Render a Plotly figure to a lightweight HTML fragment.""" + include_plotlyjs = "cdn" if include_js else False + return pio.to_html( + fig, + full_html=False, + include_plotlyjs=include_plotlyjs, + config=config, + ) + + +def _wrap_plot( + title: str, + fig: go.Figure, + *, + include_js: bool = False, + config: Optional[dict] = None, +) -> Dict[str, str]: + """Package a figure with its title for downstream HTML rendering.""" + return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)} + + +def _line_chart( + traces: List[tuple], + *, + title: str, + yaxis_title: str, +) -> go.Figure: + """Build a basic epoch-indexed line chart for train/val/test curves.""" + fig = go.Figure() + for name, series in traces: + if not series: + continue + epochs = list(range(1, len(series) + 1)) + fig.add_trace( + go.Scatter( + x=epochs, + y=series, + mode="lines+markers", + name=name, + line=dict(width=4), + ) + ) + + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=yaxis_title, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + return fig + + def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: """Extract ordered label names from Ludwig train_set_metadata.""" if not isinstance(meta_dict, dict): @@ -106,6 +175,7 @@ training_stats_path: Optional[str] = None, metadata_csv_path: Optional[str] = None, train_set_metadata_path: Optional[str] = None, + threshold: Optional[float] = None, ) -> List[Dict[str, str]]: """ Read Ludwig’s test_statistics.json and build three interactive Plotly panels: @@ -156,8 +226,11 @@ ) ) fig_cm.update_traces(xgap=2, ygap=2) + cm_title = "Confusion Matrix" + if threshold is not None: + cm_title = f"Confusion Matrix (Threshold: {threshold})" fig_cm.update_layout( - title=dict(text="Confusion Matrix", x=0.5), + title=dict(text=cm_title, x=0.5), xaxis_title="Predicted", yaxis_title="Observed", yaxis_autorange="reversed", @@ -196,25 +269,19 @@ yshift=-2, ) - plots.append({ - "title": "Confusion Matrix", - "html": pio.to_html( - fig_cm, - full_html=False, - include_plotlyjs="cdn", - config=common_cfg - ) - }) + plots.append( + _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg) + ) - # 1) ROC Curve (from test_statistics) - roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) - if roc_plot: - plots.append(roc_plot) + # 1) ROC / PR curves only for binary tasks + if n_classes == 2: + roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) + if roc_plot: + plots.append(roc_plot) - # 2) Precision-Recall Curve (from test_statistics) - pr_plot = _build_precision_recall_plot(label_stats, common_cfg) - if pr_plot: - plots.append(pr_plot) + pr_plot = _build_precision_recall_plot(label_stats, common_cfg) + if pr_plot: + plots.append(pr_plot) # 2) Classification 
Report Heatmap pcs = label_stats.get("per_class_stats", {}) @@ -259,15 +326,9 @@ margin=dict(t=80, l=80, r=80, b=80), ) _style_fig(fig_cr) - plots.append({ - "title": "Per-Class metrics", - "html": pio.to_html( - fig_cr, - full_html=False, - include_plotlyjs=False, - config=common_cfg - ) - }) + plots.append( + _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg) + ) # 3) Prediction Diagnostics (from predictions.csv) # Note: appended separately in generate_html_report, not returned here. @@ -294,8 +355,6 @@ include_js = True # Load Plotly.js once for this group def _get_series(stats: dict, metric: str) -> List[float]: - if metric not in stats: - return [] vals = stats.get(metric, []) if isinstance(vals, list): return [float(v) for v in vals] @@ -304,181 +363,98 @@ except Exception: return [] - def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]: - train_series = _get_series(label_train, metric_key) - val_series = _get_series(label_val, metric_key) + metric_specs = [ + ("loss", "Loss across epochs", "Loss"), + ("accuracy", "Accuracy across epochs", "Accuracy"), + ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"), + ("precision", "Precision across epochs", "Precision"), + ("recall", "Recall/Sensitivity across epochs", "Recall"), + ("specificity", "Specificity across epochs", "Specificity"), + ] + + for key, title, yaxis in metric_specs: + train_series = _get_series(label_train, key) + val_series = _get_series(label_val, key) if not train_series and not val_series: - return None - epochs_train = list(range(1, len(train_series) + 1)) - epochs_val = list(range(1, len(val_series) + 1)) - fig = go.Figure() - if train_series: - fig.add_trace( - go.Scatter( - x=epochs_train, - y=train_series, - mode="lines+markers", - name="Train", - line=dict(width=4), - ) - ) - if val_series: - fig.add_trace( - go.Scatter( - x=epochs_val, - y=val_series, - mode="lines+markers", - name="Validation", - line=dict(width=4), - ) - ) - fig.update_layout( - title=dict(text=title, x=0.5), - xaxis_title="Epoch", - yaxis_title=yaxis_title, - width=760, - height=520, - hovermode="x unified", + continue + fig = _line_chart( + [("Train", train_series), ("Validation", val_series)], + title=title, + yaxis_title=yaxis, ) - _style_fig(fig) - return { - "title": title, - "html": pio.to_html( - fig, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - } - - # Core learning curves - for key, title in [ - ("roc_auc", "ROC-AUC across epochs"), - ("precision", "Precision across epochs"), - ("recall", "Recall/Sensitivity across epochs"), - ("specificity", "Specificity across epochs"), - ]: - plot = _line_plot(key, title, title.replace("Learning Curve", "").strip()) - if plot: - plots.append(plot) - include_js = False + plots.append(_wrap_plot(title, fig, include_js=include_js)) + include_js = False # Precision vs Recall evolution (validation) val_prec = _get_series(label_val, "precision") val_rec = _get_series(label_val, "recall") if val_prec and val_rec: - epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1)) - fig_pr = go.Figure() - fig_pr.add_trace( - go.Scatter( - x=epochs, - y=val_prec[: len(epochs)], - mode="lines+markers", - name="Precision", - ) - ) - fig_pr.add_trace( - go.Scatter( - x=epochs, - y=val_rec[: len(epochs)], - mode="lines+markers", - name="Recall", - ) + max_len = min(len(val_prec), len(val_rec)) + fig_pr = _line_chart( + [ + ("Precision", val_prec[:max_len]), + ("Recall", val_rec[:max_len]), + ], + title="Validation Precision and Recall by 
Epoch", + yaxis_title="Value", ) - fig_pr.update_layout( - title=dict(text="Validation Precision and Recall by Epoch", x=0.5), - xaxis_title="Epoch", - yaxis_title="Value", - width=760, - height=520, - hovermode="x unified", - ) - _style_fig(fig_pr) - plots.append({ - "title": "Precision vs Recall Evolution", - "html": pio.to_html( - fig_pr, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - }) + plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js)) include_js = False - # F1-score derived def _compute_f1(p: List[float], r: List[float]) -> List[float]: - f1_vals = [] - for prec, rec in zip(p, r): - if (prec + rec) == 0: - f1_vals.append(0.0) - else: - f1_vals.append(2 * prec * rec / (prec + rec)) - return f1_vals + return [ + 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) + for prec, rec in zip(p, r) + ] f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) f1_val = _compute_f1(val_prec, val_rec) if f1_train or f1_val: - fig = go.Figure() - if f1_train: - fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4))) - if f1_val: - fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4))) - fig.update_layout( - title=dict(text="F1-Score across epochs (derived)", x=0.5), - xaxis_title="Epoch", + fig_f1 = _line_chart( + [("Train", f1_train), ("Validation", f1_val)], + title="F1-Score across epochs (derived)", yaxis_title="F1-Score", - width=760, - height=520, - hovermode="x unified", ) - _style_fig(fig) - plots.append({ - "title": "F1-Score across epochs (derived)", - "html": pio.to_html( - fig, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - }) + plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js)) include_js = False # Overfitting Gap: Train vs Val ROC-AUC (gap) roc_train = _get_series(label_train, "roc_auc") roc_val = _get_series(label_val, "roc_auc") if roc_train and roc_val: - epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1)) - gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])] - fig_gap = go.Figure() - fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4))) - fig_gap.update_layout( - title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5), - xaxis_title="Epoch", + max_len = min(len(roc_train), len(roc_val)) + gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])] + fig_gap = _line_chart( + [("Train - Val ROC-AUC", gaps)], + title="Overfitting gap: ROC-AUC across epochs", yaxis_title="Gap", - width=760, - height=520, - hovermode="x unified", ) - _style_fig(fig_gap) - plots.append({ - "title": "Overfitting gap: ROC-AUC across epochs", - "html": pio.to_html( - fig_gap, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - }) + plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js)) include_js = False # Best Epoch Dashboard (based on max val ROC-AUC) if roc_val: best_idx = int(np.argmax(roc_val)) best_epoch = best_idx + 1 - spec_val = _get_series(label_val, "specificity") - metrics_at_best = { - "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None, - "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None, - "Recall": val_rec[best_idx] 
if best_idx < len(val_rec) else None, - "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None, - "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None, + metrics_at_best: Dict[str, Optional[float]] = { + "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None } + + for metric_key, label in [ + ("accuracy", "Accuracy"), + ("balanced_accuracy", "Balanced Accuracy"), + ("precision", "Precision"), + ("recall", "Recall"), + ("specificity", "Specificity"), + ("loss", "Loss"), + ]: + series = _get_series(label_val, metric_key) + if series and best_idx < len(series): + metrics_at_best[label] = series[best_idx] + + if f1_val and best_idx < len(f1_val): + metrics_at_best["F1-Score (derived)"] = f1_val[best_idx] + fig_best = go.Figure() for name, value in metrics_at_best.items(): if value is not None: @@ -492,15 +468,7 @@ showlegend=False, ) _style_fig(fig_best) - plots.append({ - "title": "Best Validation Epoch Snapshot (Metrics)", - "html": pio.to_html( - fig_best, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - }) - include_js = False + plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js)) return plots @@ -529,46 +497,13 @@ val_series = _get_regression_series(val_split, metric_key) if not train_series and not val_series: return None - epochs_train = list(range(1, len(train_series) + 1)) - epochs_val = list(range(1, len(val_series) + 1)) - fig = go.Figure() - if train_series: - fig.add_trace( - go.Scatter( - x=epochs_train, - y=train_series, - mode="lines+markers", - name="Train", - line=dict(width=4), - ) - ) - if val_series: - fig.add_trace( - go.Scatter( - x=epochs_val, - y=val_series, - mode="lines+markers", - name="Validation", - line=dict(width=4), - ) - ) - fig.update_layout( - title=dict(text=title, x=0.5), - xaxis_title="Epoch", + + fig = _line_chart( + [("Train", train_series), ("Validation", val_series)], + title=title, yaxis_title=yaxis_title, - width=760, - height=520, - hovermode="x unified", ) - _style_fig(fig) - return { - "title": title, - "html": pio.to_html( - fig, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - } + return _wrap_plot(title, fig, include_js=include_js) def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: @@ -627,46 +562,25 @@ ("r2", "R² Across Epochs", "R²"), ("loss", "Loss Across Epochs", "Loss"), ] - epochs = None for metric_key, title, ytitle in metrics: series = _get_regression_series(label_test, metric_key) if not series: continue - if epochs is None: - epochs = list(range(1, len(series) + 1)) - fig = go.Figure() - fig.add_trace( - go.Scatter( - x=epochs, - y=series[: len(epochs)], - mode="lines+markers", - name="Test", - line=dict(width=4), - ) + fig = _line_chart( + [("Test", series)], + title=title, + yaxis_title=ytitle, ) - fig.update_layout( - title=dict(text=title, x=0.5), - xaxis_title="Epoch", - yaxis_title=ytitle, - width=760, - height=520, - hovermode="x unified", - ) - _style_fig(fig) - plots.append({ - "title": title, - "html": pio.to_html( - fig, - full_html=False, - include_plotlyjs="cdn" if include_js else False, - ), - }) + plots.append(_wrap_plot(title, fig, include_js=include_js)) include_js = False return plots def _build_static_roc_plot( - label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None + label_stats: dict, + config: dict, + friendly_labels: Optional[List[str]] = None, + threshold: Optional[float] = None, ) -> Optional[Dict[str, str]]: 
"""Build ROC curve directly from test_statistics.json (single curve).""" roc_data = label_stats.get("roc_curve") @@ -776,6 +690,42 @@ fig.update_xaxes(range=[0, 1.0]) fig.update_yaxes(range=[0, 1.05]) + roc_thresholds = roc_data.get("thresholds") + if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr): + try: + diffs = [abs(th - threshold) for th in roc_thresholds] + best_idx = int(np.argmin(diffs)) + # dashed guides through the chosen point + fig.add_shape( + type="line", + x0=fpr[best_idx], + x1=fpr[best_idx], + y0=0, + y1=tpr[best_idx], + line=dict(color="gray", width=2, dash="dash"), + ) + fig.add_shape( + type="line", + x0=0, + x1=fpr[best_idx], + y0=tpr[best_idx], + y1=tpr[best_idx], + line=dict(color="gray", width=2, dash="dash"), + ) + fig.add_trace( + go.Scatter( + x=[fpr[best_idx]], + y=[tpr[best_idx]], + mode="markers", + marker=dict(color="black", size=10, symbol="x"), + name=f"Threshold={threshold}", + hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<br>Threshold: %{text}<extra></extra>", + text=[f"{threshold}"], + ) + ) + except Exception as exc: + print(f"Warning: could not add threshold marker to ROC: {exc}") + fig.add_annotation( x=0.5, y=-0.15, @@ -786,21 +736,17 @@ xanchor="center", ) - return { - "title": "ROC Curve", - "html": pio.to_html( - fig, - full_html=False, - include_plotlyjs=False, - config=config, - ), - } + return _wrap_plot("ROC Curve", fig, config=config) except Exception as e: print(f"Error building ROC plot: {e}") return None -def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]: +def _build_precision_recall_plot( + label_stats: dict, + config: dict, + threshold: Optional[float] = None, +) -> Optional[Dict[str, str]]: """Build Precision-Recall curve directly from test_statistics.json.""" pr_data = label_stats.get("precision_recall_curve") if not isinstance(pr_data, dict): @@ -811,6 +757,8 @@ if not precisions or not recalls or len(precisions) != len(recalls): return None + thresholds = pr_data.get("thresholds") + try: fig = go.Figure() fig.add_trace( @@ -851,15 +799,41 @@ fig.update_xaxes(range=[0, 1.0]) fig.update_yaxes(range=[0, 1.05]) - return { - "title": "Precision-Recall Curve", - "html": pio.to_html( - fig, - full_html=False, - include_plotlyjs=False, - config=config, - ), - } + if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls): + try: + diffs = [abs(th - threshold) for th in thresholds] + best_idx = int(np.argmin(diffs)) + fig.add_shape( + type="line", + x0=recalls[best_idx], + x1=recalls[best_idx], + y0=0, + y1=precisions[best_idx], + line=dict(color="gray", width=2, dash="dash"), + ) + fig.add_shape( + type="line", + x0=0, + x1=recalls[best_idx], + y0=precisions[best_idx], + y1=precisions[best_idx], + line=dict(color="gray", width=2, dash="dash"), + ) + fig.add_trace( + go.Scatter( + x=[recalls[best_idx]], + y=[precisions[best_idx]], + mode="markers", + marker=dict(color="black", size=10, symbol="x"), + name=f"Threshold={threshold}", + hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<br>Threshold: %{text}<extra></extra>", + text=[f"{threshold}"], + ) + ) + except Exception as exc: + print(f"Warning: could not add threshold marker to PR: {exc}") + + return _wrap_plot("Precision-Recall Curve", fig, config=config) except Exception as e: print(f"Error building Precision-Recall plot: {e}") return None @@ -869,7 +843,6 @@ predictions_path: str, label_data_path: Optional[str] = None, split_value: int = 2, - threshold: 
Optional[float] = None, ) -> List[Dict[str, str]]: """Generate diagnostic plots from predictions.csv for classification tasks.""" preds_file = Path(predictions_path) @@ -883,12 +856,89 @@ return [] plots: List[Dict[str, str]] = [] + labels_from_dataset: Optional[pd.Series] = None + + filtered_by_split = False + + # If a split column exists, focus on the requested split (e.g., validation=1, test=2). + # If not, but label_data_path is available and matches row count, use it to filter predictions. + if SPLIT_COLUMN_NAME in df_pred.columns: + df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) + if df_pred.empty: + return [] + filtered_by_split = True + elif label_data_path and Path(label_data_path).exists(): + try: + df_labels_all = pd.read_csv(label_data_path) + if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred): + split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value + labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True) + df_pred = df_pred.loc[split_mask].reset_index(drop=True) + if df_pred.empty: + return [] + filtered_by_split = True + except Exception as exc: + print(f"Warning: Unable to filter predictions by split from label data: {exc}") + + # Fallback: no split info available. Assume the predictions file is already filtered + # (common for test-only exports) and avoid heuristic slicing that could discard rows. + if not filtered_by_split: + if split_value != 2: + return [] + + def _strip_prob_prefix(col: str) -> str: + if col.startswith("label_probabilities_"): + return col.replace("label_probabilities_", "") + if col.startswith("probabilities_"): + return col.replace("probabilities_", "") + return col + + def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]: + """If only a single 'probabilities' column exists (list-like), expand it into per-class columns.""" + if "probabilities" not in df.columns: + return [] + try: + # Parse first non-null entry to infer length + first_val = df["probabilities"].dropna().iloc[0] + parsed = first_val + if isinstance(first_val, str): + parsed = json.loads(first_val) + probs = list(parsed) + n = len(probs) + if n == 0: + return [] + # Build labels: prefer provided guess; otherwise numeric + if labels_guess and len(labels_guess) == n: + labels_use = labels_guess + else: + labels_use = [str(i) for i in range(n)] + # Expand column + for idx, lbl in enumerate(labels_use): + df[f"probabilities_{lbl}"] = df["probabilities"].apply( + lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan + ) + return [f"probabilities_{lbl}" for lbl in labels_use] + except Exception: + return [] # Identify probability columns prob_cols = [ - c for c in df_pred.columns - if c.startswith("label_probabilities_") and c != "label_probabilities" + c + for c in df_pred.columns + if ( + (c.startswith("label_probabilities_") or c.startswith("probabilities_")) + and c != "label_probabilities" + ) ] + if not prob_cols and "label_probability" in df_pred.columns: + prob_cols = ["label_probability"] + if not prob_cols and "probability" in df_pred.columns: + prob_cols = ["probability"] + if not prob_cols and "prediction_probability" in df_pred.columns: + prob_cols = ["prediction_probability"] + if not prob_cols and "probabilities" in df_pred.columns: + labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])]) + prob_cols = 
_maybe_expand_probabilities_column(df_pred, labels_guess) prob_cols_sorted = sorted(prob_cols) def _select_positive_prob(): @@ -897,14 +947,14 @@ # Prefer a column indicating positive/event/true/1 preferred_keys = ("event", "true", "positive", "pos", "1") for col in prob_cols_sorted: - suffix = col.replace("label_probabilities_", "").lower() + suffix = _strip_prob_prefix(col).lower() if any(k in suffix for k in preferred_keys): return col, suffix if len(prob_cols_sorted) == 2: col = prob_cols_sorted[1] - return col, col.replace("label_probabilities_", "") + return col, _strip_prob_prefix(col) col = prob_cols_sorted[0] - return col, col.replace("label_probabilities_", "") + return col, _strip_prob_prefix(col) pos_prob_col, pos_label_hint = _select_positive_prob() pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None @@ -920,6 +970,8 @@ # True labels def _extract_labels(): + if labels_from_dataset is not None: + return labels_from_dataset candidates = [ LABEL_COLUMN_NAME, f"{LABEL_COLUMN_NAME}_ground_truth", @@ -975,10 +1027,7 @@ height=500, ) _style_fig(fig_conf) - plots.append({ - "title": "Prediction Confidence Distribution", - "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False), - }) + plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf)) # The remaining plots require true labels and a positive-class probability if labels_series is None or pos_prob_series is None: @@ -1004,116 +1053,470 @@ y_true = (y_true_raw == positive_label).astype(int).values - # Plot 2: Calibration Curve - bins = np.linspace(0.0, 1.0, 11) - bin_ids = np.digitize(y_score, bins, right=True) - bin_centers = [] - frac_positives = [] - for b in range(1, len(bins)): - mask = bin_ids == b - if not np.any(mask): - continue - bin_centers.append(y_score[mask].mean()) - frac_positives.append(y_true[mask].mean()) - if bin_centers and frac_positives: - fig_cal = go.Figure() - fig_cal.add_trace( + # Utility: compute calibration points + def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray): + bins = np.linspace(0.0, 1.0, 11) + bin_ids = np.digitize(scores, bins, right=True) + bin_centers, frac_positives = [], [] + for b in range(1, len(bins)): + mask = bin_ids == b + if not np.any(mask): + continue + bin_centers.append(scores[mask].mean()) + frac_positives.append(y_true_bin[mask].mean()) + return bin_centers, frac_positives + + # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label) + label_prob_map = {} + for col in prob_cols_sorted: + if col.startswith("label_probabilities_"): + cls = col.replace("label_probabilities_", "") + label_prob_map[cls] = col + + unique_label_strs = [str(u) for u in unique_labels_list] + if len(label_prob_map) > 1 and len(unique_label_strs) > 2: + # Skip multi-class calibration curve for now (not informative in current report) + pass + else: + # Binary/unknown fallback (previous behavior) + bin_centers, frac_positives = _calibration_points(y_true, y_score) + if bin_centers and frac_positives: + fig_cal = go.Figure() + fig_cal.add_trace( + go.Scatter( + x=bin_centers, + y=frac_positives, + mode="lines+markers", + name="Calibration", + line=dict(color="#2ca02c", width=4), + ) + ) + fig_cal.add_trace( + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Perfect Calibration", + line=dict(color="gray", width=2, dash="dash"), + ) + ) + fig_cal.update_layout( + title=dict(text="Calibration Curve", x=0.5), + xaxis_title="Predicted probability", + yaxis_title="Observed frequency", + width=700, + 
height=500, + ) + _style_fig(fig_cal) + plots.append( + _wrap_plot( + "Calibration Curve (Predicted Probability vs Observed Frequency)", + fig_cal, + ) + ) + + return plots + + +def build_binary_threshold_plot( + predictions_path: str, + label_data_path: Optional[str] = None, + split_value: int = 1, +) -> Optional[Dict[str, str]]: + """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split.""" + preds_file = Path(predictions_path) + if not preds_file.exists(): + return None + + try: + df_pred = pd.read_csv(predictions_path) + except Exception as exc: + print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}") + return None + + labels_from_dataset: Optional[pd.Series] = None + df_full = df_pred.copy() + + def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame: + if SPLIT_COLUMN_NAME in df.columns: + return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True) + return df + + # Try preferred split, then fallback to others with data (val -> test -> train) + candidate_splits = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2] + df_candidate = pd.DataFrame() + used_split: Optional[int] = None + for sv in candidate_splits: + df_candidate = _filter_by_split(df_full, sv) + if not df_candidate.empty: + used_split = sv + break + if used_split is None: + df_candidate = df_full + df_pred = df_candidate.reset_index(drop=True) + + # If still empty (e.g., split column exists but no rows for candidates), fall back to all rows + if df_pred.empty: + df_pred = df_full.reset_index(drop=True) + labels_from_dataset = None + + if label_data_path and Path(label_data_path).exists(): + try: + df_labels_all = pd.read_csv(label_data_path) + if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full): + mask = ( + pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split + if used_split is not None and SPLIT_COLUMN_NAME in df_labels_all.columns + else pd.Series([True] * len(df_full)) + ) + labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True) + if len(labels_from_dataset) == len(df_pred): + labels_from_dataset = labels_from_dataset.reset_index(drop=True) + except Exception as exc: + print(f"Warning: Unable to align labels for threshold plot: {exc}") + + # Identify probability columns + prob_cols = [ + c + for c in df_pred.columns + if ( + (c.startswith("label_probabilities_") or c.startswith("probabilities_")) + and c != "label_probabilities" + ) + ] + if not prob_cols and "probabilities" in df_pred.columns: + labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))]) + # reuse expansion logic from diagnostics + try: + first_val = df_pred["probabilities"].dropna().iloc[0] + parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val) + n = len(parsed) + if n > 0: + if labels_guess and len(labels_guess) == n: + labels_use = labels_guess + else: + labels_use = [str(i) for i in range(n)] + for idx, lbl in enumerate(labels_use): + df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply( + lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan + ) + prob_cols = [f"probabilities_{lbl}" for lbl in labels_use] + except Exception: + prob_cols = [] + prob_cols_sorted = sorted(prob_cols) + + def _strip_prob_prefix(col: str) -> str: + if col.startswith("label_probabilities_"): + return col.replace("label_probabilities_", "") + if 
+ return col.replace("probabilities_", "")
+ return col
+
+ # True labels
+ def _extract_labels():
+ if labels_from_dataset is not None:
+ return labels_from_dataset
+ for col in [
+ LABEL_COLUMN_NAME,
+ f"{LABEL_COLUMN_NAME}_ground_truth",
+ f"{LABEL_COLUMN_NAME}__ground_truth",
+ f"{LABEL_COLUMN_NAME}_target",
+ f"{LABEL_COLUMN_NAME}__target",
+ "label",
+ "label_true",
+ "label_predictions",
+ "prediction",
+ ]:
+ if col in df_pred.columns and col not in prob_cols_sorted:
+ return df_pred[col]
+ return None
+
+ labels_series = _extract_labels()
+ if labels_series is None or not prob_cols_sorted:
+ return None
+
+ # Positive prob column selection
+ preferred_keys = ("event", "true", "positive", "pos", "1")
+ pos_prob_col = None
+ for col in prob_cols_sorted:
+ suffix = _strip_prob_prefix(col).lower()
+ if any(k in suffix for k in preferred_keys):
+ pos_prob_col = col
+ break
+ if pos_prob_col is None:
+ pos_prob_col = prob_cols_sorted[-1]
+
+ min_len = min(len(labels_series), len(df_pred[pos_prob_col]))
+ if min_len == 0:
+ return None
+
+ y_true = np.array(labels_series.iloc[:min_len])
+ # map to binary 0/1
+ unique_labels = pd.unique(y_true)
+ if len(unique_labels) < 2:
+ return None
+ positive_label = unique_labels[1] if len(unique_labels) >= 2 else unique_labels[0]
+ y_true_bin = (y_true == positive_label).astype(int)
+ y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float)
+
+ thresholds = np.linspace(0.0, 1.0, 101)
+ accs: List[float] = []
+ precs: List[float] = []
+ recs: List[float] = []
+ f1s: List[float] = []
+ for t in thresholds:
+ preds = (y_score >= t).astype(int)
+ accs.append(accuracy_score(y_true_bin, preds))
+ precs.append(precision_score(y_true_bin, preds, zero_division=0))
+ recs.append(recall_score(y_true_bin, preds, zero_division=0))
+ f1s.append(f1_score(y_true_bin, preds, zero_division=0))
+
+ best_idx = int(np.argmax(f1s))
+ best_thr = thresholds[best_idx]
+
+ fig = go.Figure()
+ fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+ fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
+ fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
+ fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
+ fig.add_shape(
+ type="line",
+ x0=best_thr,
+ x1=best_thr,
+ y0=0,
+ y1=1,
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ fig.update_layout(
+ title=dict(text="Threshold plot", x=0.5),
+ xaxis_title="Threshold",
+ yaxis_title="Metric value",
+ yaxis=dict(range=[0, 1]),
+ width=760,
+ height=520,
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+ )
+ _style_fig(fig)
+ return _wrap_plot("Threshold plot", fig, include_js=True)
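A minimal usage sketch for build_binary_threshold_plot as added above; the file names and the surrounding report list are illustrative assumptions, not part of this changeset. The helper returns None when the predictions cannot be read or the task is not binary, so the result is checked before use.

    # Hypothetical wiring of the threshold sweep into a report (names are illustrative).
    from typing import Dict, List, Optional

    report_plots: List[Dict[str, str]] = []
    threshold_plot: Optional[Dict[str, str]] = build_binary_threshold_plot(
        predictions_path="predictions.csv",        # assumed predictions CSV with probability columns
        label_data_path="dataset_with_split.csv",  # assumed dataset CSV carrying the split column
        split_value=1,                             # validation split; the helper falls back to other splits
    )
    if threshold_plot is not None:
        report_plots.append(threshold_plot)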
+
+
+def build_multiclass_roc_pr_plots(
+ predictions_path: str,
+ split_value: int = 2,
+) -> List[Dict[str, str]]:
+ """Build one-vs-rest ROC and PR curves for multi-class classification from predictions."""
+ preds_file = Path(predictions_path)
+ if not preds_file.exists():
+ return []
+ try:
+ df_pred = pd.read_csv(predictions_path)
+ except Exception as exc:
+ print(f"Warning: Unable to read predictions CSV: {exc}")
+ return []
+
+ if SPLIT_COLUMN_NAME in df_pred.columns:
+ df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+ if df_pred.empty:
+ return []
+
+ if LABEL_COLUMN_NAME not in df_pred.columns:
+ return []
+
+ # Identify per-class probability columns
+ prob_cols = [
+ c
+ for c in df_pred.columns
+ if (
+ (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+ and c != "label_probabilities"
+ )
+ ]
+ if not prob_cols:
+ return []
+ labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols]
+ labels_sorted = sorted(labels)
+
+ # Ensure all labels are present as probability columns
+ prob_map = {
+ c.replace("label_probabilities_", "").replace("probabilities_", ""): c
+ for c in prob_cols
+ }
+ if len(labels_sorted) < 3:
+ return []
+
+ y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str)
+ # Drop rows with NaN probabilities across any class to avoid metric errors
+ prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float)
+ mask_valid = ~prob_matrix.isnull().any(axis=1)
+ prob_matrix = prob_matrix[mask_valid]
+ y_true_raw = y_true_raw[mask_valid]
+ if prob_matrix.empty:
+ return []
+
+ y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
+ y_score = prob_matrix.to_numpy()
+
+ plots: List[Dict[str, str]] = []
+
+ # ROC: one-vs-rest + micro
+ fig_roc = go.Figure()
+ added_any = False
+ for idx, lbl in enumerate(labels_sorted):
+ if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
+ continue  # skip classes without both positives and negatives
+ fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
+ fig_roc.add_trace(
 go.Scatter(
- x=bin_centers,
- y=frac_positives,
- mode="lines+markers",
- name="Calibration",
- line=dict(color="#2ca02c", width=4),
- )
- )
- fig_cal.add_trace(
- go.Scatter(
- x=[0, 1],
- y=[0, 1],
+ x=fpr,
+ y=tpr,
 mode="lines",
- name="Perfect Calibration",
- line=dict(color="gray", width=2, dash="dash"),
+ name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
+ line=dict(width=3),
 )
 )
- fig_cal.update_layout(
- title=dict(text="Calibration Curve", x=0.5),
- xaxis_title="Predicted probability",
- yaxis_title="Observed frequency",
- width=700,
- height=500,
+ added_any = True
+ # Micro-average only if we have mixed labels
+ if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
+ fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
+ fig_roc.add_trace(
+ go.Scatter(
+ x=fpr_micro,
+ y=tpr_micro,
+ mode="lines",
+ name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
+ line=dict(width=3, dash="dash"),
+ )
 )
- _style_fig(fig_cal)
- plots.append({
- "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
- "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
- })
-
- # Plot 3: Threshold vs Metrics
- thresholds = np.linspace(0.0, 1.0, 21)
- accs, f1s, sens, specs = [], [], [], []
- for t in thresholds:
- y_pred = (y_score >= t).astype(int)
- tp = np.sum((y_true == 1) & (y_pred == 1))
- tn = np.sum((y_true == 0) & (y_pred == 0))
- fp = np.sum((y_true == 0) & (y_pred == 1))
- fn = np.sum((y_true == 1) & (y_pred == 0))
- acc = (tp + tn) / max(len(y_true), 1)
- prec = tp / max(tp + fp, 1e-9)
- rec = tp / max(tp + fn, 1e-9)
- f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
- sensitivity = rec
- specificity = tn / max(tn + fp, 1e-9)
- accs.append(acc)
- f1s.append(f1)
- sens.append(sensitivity)
- specs.append(specificity)
-
- fig_thresh = go.Figure()
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
mode="lines", name="Sensitivity", line=dict(width=4))) - fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4))) - fig_thresh.update_layout( - title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5), - xaxis_title="Decision threshold", - yaxis_title="Metric value", - width=700, - height=500, + added_any = True + if not added_any: + return [] + fig_roc.add_trace( + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Random", + line=dict(color="gray", width=2, dash="dot"), + ) + ) + fig_roc.update_layout( + title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5), + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + width=820, + height=620, legend=dict( - x=0.7, - y=0.2, + x=0.62, + y=0.05, bgcolor="rgba(255,255,255,0.9)", bordercolor="rgba(0,0,0,0.2)", borderwidth=1, ), - shapes=[ - dict( - type="line", - x0=threshold, - x1=threshold, - y0=0, - y1=1, - xref="x", - yref="paper", - line=dict(color="#d62728", width=2, dash="dash"), + ) + _style_fig(fig_roc) + plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc)) + + # PR: one-vs-rest + micro AP + fig_pr = go.Figure() + added_pr = False + for idx, lbl in enumerate(labels_sorted): + if y_true_bin[:, idx].sum() == 0: + continue + prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx]) + ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx]) + fig_pr.add_trace( + go.Scatter( + x=rec, + y=prec, + mode="lines", + name=f"{lbl} (AP={ap:.3f})", + line=dict(width=3), ) - ] if isinstance(threshold, (int, float)) else [], - annotations=[ - dict( - x=threshold, - y=1.02, - xref="x", - yref="paper", - showarrow=False, - text=f"Threshold = {threshold:.2f}", - font=dict(size=11, color="#d62728"), + ) + added_pr = True + if y_true_bin.sum() > 0: + prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel()) + ap_micro = average_precision_score(y_true_bin, y_score, average="micro") + fig_pr.add_trace( + go.Scatter( + x=rec_micro, + y=prec_micro, + mode="lines", + name=f"Micro-average (AP={ap_micro:.3f})", + line=dict(width=3, dash="dash"), ) - ] if isinstance(threshold, (int, float)) else [], + ) + added_pr = True + if not added_pr: + return plots + fig_pr.update_layout( + title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5), + xaxis_title="Recall", + yaxis_title="Precision", + width=820, + height=620, + legend=dict( + x=0.62, + y=0.05, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), ) - _style_fig(fig_thresh) - plots.append({ - "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", - "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False), - }) + _style_fig(fig_pr) + plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr)) return plots + + +def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]: + """Alternative multi-class transparency plots using test_statistics.json per-class stats.""" + ts_path = Path(test_stats_path) + if not ts_path.exists(): + return [] + try: + with open(ts_path, "r") as f: + test_stats = json.load(f) + except Exception: + return [] + + label_stats = test_stats.get("label", {}) + pcs = label_stats.get("per_class_stats", {}) + if not pcs: + return [] + classes = list(pcs.keys()) + if not classes: + return [] + + metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"] + fig_bar = go.Figure() + for metric 
+
+
+def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
+ """Alternative multi-class transparency plots using test_statistics.json per-class stats."""
+ ts_path = Path(test_stats_path)
+ if not ts_path.exists():
+ return []
+ try:
+ with open(ts_path, "r") as f:
+ test_stats = json.load(f)
+ except Exception:
+ return []
+
+ label_stats = test_stats.get("label", {})
+ pcs = label_stats.get("per_class_stats", {})
+ if not pcs:
+ return []
+ classes = list(pcs.keys())
+ if not classes:
+ return []
+
+ metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
+ fig_bar = go.Figure()
+ for metric in metrics:
+ values = []
+ for cls in classes:
+ v = pcs.get(cls, {}).get(metric)
+ values.append(v if isinstance(v, (int, float)) else 0)
+ fig_bar.add_trace(
+ go.Bar(
+ x=classes,
+ y=values,
+ name=metric.replace("_", " ").title(),
+ )
+ )
+ fig_bar.update_layout(
+ title=dict(text="Per-Class Metrics (Test)", x=0.5),
+ xaxis_title="Class",
+ yaxis_title="Metric value",
+ barmode="group",
+ width=900,
+ height=600,
+ legend=dict(
+ x=1.02,
+ y=1.0,
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.2)",
+ borderwidth=1,
+ ),
+ )
+ _style_fig(fig_bar)
+
+ return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]
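For reference, a minimal sketch of the nested structure build_multiclass_metric_plots reads from test_statistics.json; the class names and values are invented, and a real Ludwig test_statistics.json will typically contain additional fields beyond per_class_stats.

    # Invented example of the per_class_stats layout consumed above.
    example_test_stats = {
        "label": {
            "per_class_stats": {
                "cat": {"precision": 0.91, "recall": 0.88, "f1_score": 0.89, "specificity": 0.95, "accuracy": 0.93},
                "dog": {"precision": 0.84, "recall": 0.90, "f1_score": 0.87, "specificity": 0.92, "accuracy": 0.91},
            }
        }
    }
    # build_multiclass_metric_plots("test_statistics.json") would group these five
    # metrics per class into a single grouped bar chart for the report.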
