# HG changeset patch
# User goeckslab
# Date 1765326253 0
# Node ID db9be962dc13398731eaad271704660ed1629b9a
# Parent 8729f69e9207efd51e8a42f8c3e584b775bd0dd1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
diff -r 8729f69e9207 -r db9be962dc13 html_structure.py
--- a/html_structure.py Wed Dec 03 01:28:52 2025 +0000
+++ b/html_structure.py Wed Dec 10 00:24:13 2025 +0000
@@ -1,6 +1,6 @@
import base64
import json
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from constants import METRIC_DISPLAY_NAMES
from utils import detect_output_type, extract_metrics_from_json
@@ -23,6 +23,7 @@
) -> str:
display_keys = [
"architecture",
+ "image_size",
"pretrained",
"trainable",
"target_column",
@@ -58,6 +59,15 @@
else:
if key == "task_type":
val_str = val.title() if isinstance(val, str) else "N/A"
+ elif key == "image_size":
+ if val is None:
+ val_str = "N/A"
+ elif isinstance(val, (list, tuple)) and len(val) == 2:
+ val_str = f"{val[0]}x{val[1]}"
+ elif isinstance(val, str) and val.lower() == "original":
+ val_str = "Original (no resize)"
+ else:
+ val_str = str(val)
elif key == "batch_size":
if isinstance(val, (int, float)):
val_str = int(val)
@@ -115,6 +125,11 @@
"Ludwig Trainer Parameters for details."
""
)
+ elif key == "validation_metric":
+ if val is not None:
+ val_str = METRIC_DISPLAY_NAMES.get(str(val), str(val))
+ else:
+ val_str = "N/A"
elif key == "epochs":
if val is None:
val_str = "N/A"
@@ -729,6 +744,64 @@
)
return modal_html + modal_js
+
+def format_dataset_overview_table(rows: List[Dict[str, Any]], regression_mode: bool = False) -> str:
+ """Render a dataset overview table.
+
+ - Classification: per-label distribution across train/val/test.
+ - Regression: split counts (train/val/test).
+ """
+ heading = "
Dataset Overview
"
+ if not rows:
+ return heading + "Dataset overview unavailable.
"
+
+ if regression_mode:
+ headers = ["Split", "Count"]
+ html = (
+ heading
+ + "
"
+ else:
+ html = (
+ heading
+ + "
"
+ return html
+
# -----------------------------------------
# MODEL PERFORMANCE (Train/Val/Test) TABLE
# -----------------------------------------
diff -r 8729f69e9207 -r db9be962dc13 image_learner.xml
--- a/image_learner.xml Wed Dec 03 01:28:52 2025 +0000
+++ b/image_learner.xml Wed Dec 10 00:24:13 2025 +0000
@@ -130,13 +130,9 @@
-
-
+
-
-
-
diff -r 8729f69e9207 -r db9be962dc13 image_workflow.py
--- a/image_workflow.py Wed Dec 03 01:28:52 2025 +0000
+++ b/image_workflow.py Wed Dec 10 00:24:13 2025 +0000
@@ -5,7 +5,7 @@
import tempfile
import zipfile
from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import pandas as pd
import pandas.api.types as ptypes
@@ -35,6 +35,8 @@
self.image_extract_dir: Optional[Path] = None
self.label_metadata: Dict[str, Any] = {}
self.output_type_hint: Optional[str] = None
+ self.label_split_counts: List[Dict[str, int]] = []
+ self.split_counts: Dict[int, int] = {}
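+        # e.g. split_counts -> {0: 700, 1: 100, 2: 200} (0=train, 1=validation, 2=test);
+        # label_split_counts -> [{"label": "cat", "train": 350, "validation": 50, "test": 100}, ...]
+        # (illustrative values, populated during data preparation)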
logger.info(f"Orchestrator initialized with backend: {type(backend).__name__}")
def _create_temp_dirs(self) -> None:
@@ -186,6 +188,34 @@
logger.error("Error saving prepared CSV", exc_info=True)
raise
+ # Capture actual split counts for downstream reporting (avoids heuristic 70/10/20 tables)
+ try:
+ split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce")
+ split_series = split_series.dropna().astype(int)
+ self.split_counts = {int(k): int(v) for k, v in split_series.value_counts().to_dict().items()}
+ if LABEL_COLUMN_NAME in df.columns:
+ counts = (
+ df.dropna(subset=[LABEL_COLUMN_NAME])
+ .assign(**{SPLIT_COLUMN_NAME: split_series})
+ .groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+ .size()
+ .unstack(fill_value=0)
+ .sort_index()
+ )
+ self.label_split_counts = [
+ {
+ "label": str(lbl),
+ "train": int(row.get(0, 0)),
+ "validation": int(row.get(1, 0)),
+ "test": int(row.get(2, 0)),
+ }
+ for lbl, row in counts.iterrows()
+ ]
+ except Exception:
+ logger.warning("Unable to capture split counts for reporting", exc_info=True)
+ self.label_split_counts = []
+ self.split_counts = {}
+
self._capture_label_metadata(df)
return final_csv, split_config, split_info
@@ -349,6 +379,8 @@
"random_seed": self.args.random_seed,
"early_stop": self.args.early_stop,
"label_column_data_path": csv_path,
+ "label_split_counts": self.label_split_counts,
+ "split_counts": self.split_counts,
"augmentation": self.args.augmentation,
"image_resize": self.args.image_resize,
"image_zip": self.args.image_zip,
diff -r 8729f69e9207 -r db9be962dc13 ludwig_backend.py
--- a/ludwig_backend.py Wed Dec 03 01:28:52 2025 +0000
+++ b/ludwig_backend.py Wed Dec 10 00:24:13 2025 +0000
@@ -1,8 +1,9 @@
+import inspect
import json
import logging
import os
from pathlib import Path
-from typing import Any, Dict, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Optional, Protocol, Tuple
import pandas as pd
import pandas.api.types as ptypes
@@ -17,6 +18,7 @@
build_tabbed_html,
encode_image_to_base64,
format_config_table_html,
+ format_dataset_overview_table,
format_stats_table_html,
format_test_merged_stats_table_html,
format_train_val_stats_table_html,
@@ -33,7 +35,9 @@
from ludwig.utils.data_utils import get_split_path
from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
from plotly_plots import (
+ build_binary_threshold_plot,
build_classification_plots,
+ build_multiclass_metric_plots,
build_prediction_diagnostics,
build_regression_test_plots,
build_regression_train_val_plots,
@@ -267,6 +271,23 @@
else:
encoder_config = {"type": raw_encoder}
+ # Set a human-friendly architecture string for reporting
+ arch_display = None
+ if is_metaformer and custom_model:
+ arch_display = str(custom_model)
+ elif isinstance(raw_encoder, dict):
+ enc_type = raw_encoder.get("type")
+ enc_variant = raw_encoder.get("model_variant")
+ if enc_type:
+ base = str(enc_type).replace("_", " ").title()
+ arch_display = f"{base} {enc_variant}" if enc_variant is not None else base
+ else:
+ arch_display = str(raw_encoder).replace("_", " ").title()
+
+ if not arch_display:
+ arch_display = str(model_name)
+ config_params["architecture"] = arch_display
+
batch_size_cfg = batch_size or "auto"
label_column_path = config_params.get("label_column_data_path")
@@ -343,6 +364,7 @@
# Force Ludwig to respect our dimensions by setting additional parameters
image_feat["preprocessing"]["requires_equal_dimensions"] = False
logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)")
+ config_params["image_size"] = f"{height}x{width}"
# Now set the encoder configuration
image_feat["encoder"] = encoder_config
@@ -374,8 +396,12 @@
image_feat["preprocessing"]["infer_image_max_height"] = height
image_feat["preprocessing"]["infer_image_max_width"] = width
logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
+ config_params["image_size"] = f"{height}x{width}"
except (ValueError, IndexError):
logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+ elif not is_metaformer:
+        # No explicit resize provided; record the original size for reporting purposes
+ config_params.setdefault("image_size", "original")
def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
"""Pick a validation metric that Ludwig will accept for the resolved task."""
@@ -471,6 +497,9 @@
config_params.get("validation_metric"),
)
+ # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
+ config_params["validation_metric"] = val_metric
+
conf: Dict[str, Any] = {
"model_type": "ecd",
"input_features": [image_feat],
@@ -641,18 +670,62 @@
except Exception as e:
logger.error(f"Error converting Parquet to CSV: {e}")
- def generate_plots(self, output_dir: Path) -> None:
- """Generate all registered Ludwig visualizations for the latest experiment run."""
- logger.info("Generating all Ludwig visualizations…")
+ @staticmethod
+ def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
+ """Pull the first numeric metric list we can find for the requested split."""
+ if not isinstance(stats, dict):
+ return None, None
+
+ split_stats = stats.get(split, {})
+ ordered_metrics: List[Tuple[str, List[float]]] = []
+
+ def _append_metrics(metric_map: Dict[str, Any]) -> None:
+ for metric_name, values in metric_map.items():
+ if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values):
+ ordered_metrics.append((metric_name, values))
- # Keep only lightweight plots (drop compare_performance/roc_curves)
- test_plots = {
- "roc_curves_from_test_statistics",
- "confusion_matrix",
- }
+ if isinstance(split_stats, dict):
+ combined = split_stats.get("combined")
+ if isinstance(combined, dict):
+ _append_metrics(combined)
+
+ for feature_name, feature_metrics in split_stats.items():
+ if feature_name == "combined" or not isinstance(feature_metrics, dict):
+ continue
+ _append_metrics(feature_metrics)
+
+ if prefer:
+ for metric_name, values in ordered_metrics:
+ if metric_name == prefer:
+ return metric_name, values
+
+ return ordered_metrics[0] if ordered_metrics else (None, None)
+
+ def generate_plots(self, output_dir: Path) -> None:
+ """Generate Ludwig visualizations (train/val + test) for the latest experiment run."""
+ logger.info("Generating Ludwig visualizations (train/val + test)…")
+
+ # Train/validation visualizations
train_plots = {
"learning_curves",
- "compare_classifiers_performance_subset",
+ }
+
+ # Test visualizations (multi-class transparency)
+ test_plots = {
+ "confusion_matrix",
+ "compare_performance",
+ "compare_classifiers_multiclass_multimetric",
+ "frequency_vs_f1",
+ "confidence_thresholding",
+ "confidence_thresholding_data_vs_acc",
+ "confidence_thresholding_data_vs_acc_subset",
+ "confidence_thresholding_data_vs_acc_subset_per_class",
+ # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere
+ "binary_threshold_vs_metric",
+ "roc_curves",
+ "precision_recall_curves",
+ "calibration_1_vs_all",
+ "calibration_multiclass",
}
output_dir = Path(output_dir)
@@ -677,7 +750,6 @@
training_stats = _check(exp_dir / "training_statistics.json")
test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
- probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
dataset_path = None
@@ -688,6 +760,9 @@
cfg = json.load(f)
dataset_path = _check(Path(cfg.get("dataset", "")))
split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+ model_name = cfg.get("model_name", "model")
+ else:
+ model_name = "model"
output_feature = ""
if desc.exists():
@@ -700,7 +775,44 @@
stats = json.load(f)
output_feature = next(iter(stats.keys()), "")
+ probs_path = None
+ prob_candidates = [
+ exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv",
+ exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None,
+ exp_dir / "probabilities.csv",
+ exp_dir / "predictions.csv",
+ exp_dir / PREDICTIONS_PARQUET_FILE_NAME,
+ ]
+ for cand in prob_candidates:
+ if cand and Path(cand).exists():
+ probs_path = str(cand)
+ break
+
viz_registry = get_visualizations_registry()
+ if not viz_registry:
+ logger.warning(
+ "Ludwig visualizations registry not available; train/test PNGs will be skipped."
+ )
+ return
+
+ base_kwargs = {
+ "training_statistics": [training_stats] if training_stats else [],
+ "test_statistics": [test_stats] if test_stats else [],
+ "probabilities": [probs_path] if probs_path else [],
+ "output_feature_name": output_feature,
+ "ground_truth_split": 2,
+ "top_n_classes": [20],
+ "top_k": 3,
+ "metrics": ["f1", "precision", "recall", "accuracy"],
+ "positive_label": 0,
+ "ground_truth_metadata": gt_metadata,
+ "ground_truth": dataset_path,
+ "split_file": split_file,
+ "output_directory": None, # set per plot below
+ "normalize": False,
+ "file_format": "png",
+ "model_names": [model_name],
+ }
for viz_name, viz_func in viz_registry.items():
if viz_name in train_plots:
viz_dir_plot = train_viz
@@ -710,25 +822,22 @@
continue
try:
+ # Build per-viz kwargs based on the function signature to avoid unexpected args
+ sig_params = set(inspect.signature(viz_func).parameters.keys())
+ call_kwargs = {
+ k: v
+ for k, v in base_kwargs.items()
+ if k in sig_params and v is not None
+ }
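+            # e.g. a visualization accepting only (training_statistics, output_directory,
+            # file_format) receives just those keys; the rest of base_kwargs is dropped.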
+ if "output_directory" in sig_params:
+ call_kwargs["output_directory"] = str(viz_dir_plot)
+
viz_func(
- training_statistics=[training_stats] if training_stats else [],
- test_statistics=[test_stats] if test_stats else [],
- probabilities=[probs_path] if probs_path else [],
- output_feature_name=output_feature,
- ground_truth_split=2,
- top_n_classes=[0],
- top_k=3,
- ground_truth_metadata=gt_metadata,
- ground_truth=dataset_path,
- split_file=split_file,
- output_directory=str(viz_dir_plot),
- normalize=False,
- file_format="png",
+ **call_kwargs,
)
logger.info(f"✔ Generated {viz_name}")
except Exception as e:
logger.warning(f"✘ Skipped {viz_name}: {e}")
-
logger.info(f"All visualizations written to {viz_dir}")
def generate_html_report(
@@ -756,6 +865,7 @@
label_metadata_path = config.get("label_column_data_path")
if label_metadata_path:
label_metadata_path = Path(label_metadata_path)
+ dataset_path_from_desc: Optional[Path] = None
# Pull additional config details from description.json if available
config_for_summary = dict(config)
@@ -765,7 +875,8 @@
if desc_path.exists():
try:
with open(desc_path, "r") as f:
- desc_cfg = json.load(f).get("config", {})
+ desc_json = json.load(f)
+ desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {}
encoder_cfg = (
desc_cfg.get("input_features", [{}])[0].get("encoder", {})
if isinstance(desc_cfg.get("input_features", [{}]), list)
@@ -783,10 +894,20 @@
arch_type = encoder_cfg.get("type")
arch_variant = encoder_cfg.get("model_variant")
+ arch_custom = encoder_cfg.get("custom_model")
arch_name = None
+ if arch_custom:
+ arch_name = str(arch_custom)
if arch_type:
arch_base = str(arch_type).replace("_", " ").title()
- arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+ arch_type_name = (
+ f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+ )
+ # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type
+ arch_name = arch_name or arch_type_name
+ if not arch_name and config.get("model_name"):
+ # As a last resort, show the user-selected model name (handles custom/MetaFormer cases)
+ arch_name = str(config.get("model_name"))
summary_fields = {
"architecture": arch_name,
@@ -814,12 +935,22 @@
if k in {"target_column", "image_column"} and config_for_summary.get(k):
continue
config_for_summary.setdefault(k, v)
+
+ dataset_field = None
+ if isinstance(desc_json, dict):
+ dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset")
+ if dataset_field:
+ try:
+ dataset_path_from_desc = Path(dataset_field)
+ except TypeError:
+ dataset_path_from_desc = None
+ if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()):
+ label_metadata_path = dataset_path_from_desc
except Exception as e: # pragma: no cover - defensive
logger.warning(f"Could not merge description.json into config summary: {e}")
base_viz_dir = exp_dir / "visualizations"
train_viz_dir = base_viz_dir / "train"
- test_viz_dir = base_viz_dir / "test"
html = get_html_template()
@@ -880,10 +1011,164 @@
});
}
});
-
+
"""
html += f"{title}
"
+ def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str:
+ """Append Plotly blocks to a tab with consistent markup."""
+ if not plots:
+ return tab_html
+ suffix = title_suffix or ""
+ for plot in plots:
+ tab_html += (
+ f"{plot['title']}{suffix}
"
+ f"{plot['html']}
"
+ )
+ return tab_html
+
+ def build_dataset_overview(
+ label_metadata: Optional[Path],
+ output_type: Optional[str],
+ split_probabilities: Optional[List[float]],
+ label_split_counts: Optional[List[Dict[str, int]]] = None,
+ split_counts: Optional[Dict[int, int]] = None,
+ fallback_dataset: Optional[Path] = None,
+ ) -> str:
+ """Summarize dataset distribution across splits using the actual split config."""
+ if label_split_counts:
+ # Use the actual counts captured during data prep instead of heuristics.
+ return format_dataset_overview_table(label_split_counts, regression_mode=False)
+
+ if output_type == "regression" and split_counts:
+ rows = [
+ {"split": "train", "count": int(split_counts.get(0, 0))},
+ {"split": "validation", "count": int(split_counts.get(1, 0))},
+ {"split": "test", "count": int(split_counts.get(2, 0))},
+ ]
+ return format_dataset_overview_table(rows, regression_mode=True)
+
+ candidate_paths: List[Path] = []
+ if label_metadata and label_metadata.exists():
+ candidate_paths.append(label_metadata)
+ if fallback_dataset and fallback_dataset.exists():
+ candidate_paths.append(fallback_dataset)
+ if not candidate_paths:
+ return format_dataset_overview_table([])
+
+ def _normalize_split_probabilities(
+ probs: Optional[List[float]],
+ ) -> Optional[List[float]]:
+ if not probs or len(probs) != 3:
+ return None
+ try:
+ probs = [float(p) for p in probs]
+ except (TypeError, ValueError):
+ return None
+ total = sum(probs)
+ if total <= 0:
+ return None
+ return [p / total for p in probs]
+
+ def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]:
+ if SPLIT_COLUMN_NAME not in df.columns:
+ return {}
+ split_series = pd.to_numeric(
+ df[SPLIT_COLUMN_NAME], errors="coerce"
+ ).dropna()
+ if split_series.empty:
+ return {}
+ split_series = split_series.astype(int)
+ return split_series.value_counts().to_dict()
+
+ def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
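+            # int() floors the train/validation counts; the remainder goes to test so the
+            # three counts always sum to `total`.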
+ train_n = int(total * probs[0])
+ val_n = int(total * probs[1])
+ test_n = max(0, total - train_n - val_n)
+ return {0: train_n, 1: val_n, 2: test_n}
+
+ fallback_rows: Optional[List[Dict[str, int]]] = None
+ for meta_path in candidate_paths:
+ try:
+ df_labels = pd.read_csv(meta_path)
+ probs = _normalize_split_probabilities(split_probabilities)
+
+ # Regression (or missing label column): only need split counts
+ if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns:
+ split_counts_found = _split_counts_from_column(df_labels)
+ if split_counts_found:
+ rows = [
+ {"split": "train", "count": int(split_counts_found.get(0, 0))},
+ {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+ {"split": "test", "count": int(split_counts_found.get(2, 0))},
+ ]
+ return format_dataset_overview_table(rows, regression_mode=True)
+ if probs and fallback_rows is None:
+ split_counts_found = _split_counts_from_probs(len(df_labels), probs)
+ fallback_rows = [
+ {"split": "train", "count": int(split_counts_found.get(0, 0))},
+ {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+ {"split": "test", "count": int(split_counts_found.get(2, 0))},
+ ]
+ continue
+
+ # Classification: prefer actual split assignments; fall back to configured probabilities
+ if SPLIT_COLUMN_NAME in df_labels.columns:
+ df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy()
+ df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric(
+ df_counts[SPLIT_COLUMN_NAME], errors="coerce"
+ )
+ df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME])
+ if df_counts.empty:
+ continue
+
+ df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int)
+ df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME])
+ counts = (
+ df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+ .size()
+ .unstack(fill_value=0)
+ .sort_index()
+ )
+ rows = []
+ for lbl, row in counts.iterrows():
+ rows.append(
+ {
+ "label": str(lbl),
+ "train": int(row.get(0, 0)),
+ "validation": int(row.get(1, 0)),
+ "test": int(row.get(2, 0)),
+ }
+ )
+ return format_dataset_overview_table(rows)
+
+ if probs:
+ label_series = df_labels[LABEL_COLUMN_NAME].dropna()
+ label_counts = label_series.value_counts().sort_index()
+ if label_counts.empty:
+ continue
+ rows = []
+ for lbl, count in label_counts.items():
+ train_n = int(count * probs[0])
+ val_n = int(count * probs[1])
+ test_n = max(0, count - train_n - val_n)
+ rows.append(
+ {
+ "label": str(lbl),
+ "train": train_n,
+ "validation": val_n,
+ "test": test_n,
+ }
+ )
+ fallback_rows = fallback_rows or rows
+ except Exception as exc:
+ logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc)
+ continue
+
+ if fallback_rows:
+ return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression")
+ return format_dataset_overview_table([])
+
metrics_html = ""
train_val_metrics_html = ""
test_metrics_html = ""
@@ -911,6 +1196,23 @@
f"Could not load stats for HTML report: {type(e).__name__}: {e}"
)
+ if not output_type:
+ # Fallback to configured task type when stats are unavailable (e.g., failed run).
+ output_type = (
+ str(config_for_summary.get("task_type")).lower()
+ if config_for_summary.get("task_type")
+ else None
+ )
+
+ dataset_overview_html = build_dataset_overview(
+ label_metadata_path,
+ output_type,
+ config.get("split_probabilities"),
+ config.get("label_split_counts"),
+ config.get("split_counts"),
+ dataset_path_from_desc,
+ )
+
config_html = ""
training_progress = self.get_training_process(output_dir)
try:
@@ -937,11 +1239,12 @@
exclude_names: Optional[set] = None,
) -> str:
if not dir_path.exists():
- return f"{title}
Directory not found.
"
+ return ""
exclude_names = exclude_names or set()
- imgs = list(dir_path.glob("*.png"))
+ # Search recursively because Ludwig can nest figures under per-feature folders
+ imgs = list(dir_path.rglob("*.png"))
# Exclude ROC curves and standard confusion matrices (keep only entropy version)
default_exclude = {
@@ -983,7 +1286,7 @@
]
if not imgs:
- return f"{title}
No plots found.
"
+ return ""
# Sort images by name for consistent ordering (works with string and numeric labels)
imgs = sorted(imgs, key=lambda x: x.name)
@@ -1006,36 +1309,86 @@
)
return html_section
- # Show performance first, then config
- tab1_content = metrics_html + config_html
+ # Show dataset overview, performance first, then config
+ predictions_csv_path = exp_dir / "predictions.csv"
+
+ tab1_content = dataset_overview_html + metrics_html + config_html
- tab2_content = train_val_metrics_html + render_img_section(
- "Training and Validation Visualizations",
- train_viz_dir,
- output_type,
- exclude_names={
- "compare_classifiers_performance_from_prob.png",
- "roc_curves_from_prediction_statistics.png",
- "precision_recall_curves_from_prediction_statistics.png",
- "precision_recall_curve.png",
- },
+ tab2_content = train_val_metrics_html
+ # Preload binary threshold plot so it appears first in Train/Val tab
+ threshold_plot = None
+ threshold_value = (
+ config_for_summary.get("threshold")
+ if config_for_summary.get("threshold") is not None
+ else config.get("threshold")
)
+ if threshold_value is None and output_type == "binary":
+ threshold_value = 0.5
+ if output_type == "binary" and predictions_csv_path.exists():
+ try:
+ threshold_plot = build_binary_threshold_plot(
+ str(predictions_csv_path),
+ label_data_path=str(config.get("label_column_data_path"))
+ if config.get("label_column_data_path")
+ else None,
+ split_value=1,
+ )
+ except Exception as e:
+ logger.warning(f"Could not generate validation threshold plot: {e}")
+
if train_stats_path.exists():
try:
if output_type == "regression":
tv_plots = build_regression_train_val_plots(str(train_stats_path))
+ tab2_content = append_plot_blocks(tab2_content, tv_plots)
else:
tv_plots = build_train_validation_plots(str(train_stats_path))
- for plot in tv_plots:
- tab2_content += (
- f"{plot['title']}
"
- f"{plot['html']}
"
- )
- if tv_plots:
- logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+ # Add threshold plot first, then other train/val plots
+ if threshold_plot:
+ tab2_content = append_plot_blocks(tab2_content, [threshold_plot])
+ # Only append once; avoid duplicates if added elsewhere
+ threshold_plot = None
+ tab2_content = append_plot_blocks(tab2_content, tv_plots)
+ if threshold_plot or tv_plots:
+ logger.info(
+ f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots"
+ )
except Exception as e:
logger.warning(f"Could not generate train/val plots: {e}")
+ # Only include training PNGs for regression; classification is handled by filtered Plotly plots
+ if output_type == "regression":
+ tab2_content += render_img_section(
+ "Training and Validation Visualizations",
+ train_viz_dir,
+ output_type,
+ exclude_names={
+ "compare_classifiers_performance_from_prob.png",
+ "roc_curves_from_prediction_statistics.png",
+ "precision_recall_curves_from_prediction_statistics.png",
+ "precision_recall_curve.png",
+ },
+ )
+
+ # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1
+ if output_type in ("binary", "category") and predictions_csv_path.exists():
+ try:
+ val_diag_plots = build_prediction_diagnostics(
+ str(predictions_csv_path),
+ label_data_path=str(config.get("label_column_data_path"))
+ if config.get("label_column_data_path")
+ else None,
+ split_value=1,
+ )
+ val_conf_plots = [
+ p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+ ]
+ tab2_content = append_plot_blocks(
+ tab2_content, val_conf_plots, " (Validation)"
+ )
+ except Exception as e:
+ logger.warning(f"Could not generate validation diagnostics: {e}")
+
# --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
preds_section = ""
parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
@@ -1077,18 +1430,12 @@
logger.warning(f"Could not build Predictions vs GT table: {e}")
tab3_content = test_metrics_html + preds_section
- test_plotly_added = False
if output_type == "regression" and train_stats_path.exists():
try:
test_plots = build_regression_test_plots(str(train_stats_path))
- for plot in test_plots:
- tab3_content += (
- f"{plot['title']}
"
- f"{plot['html']}
"
- )
+ tab3_content = append_plot_blocks(tab3_content, test_plots)
if test_plots:
- test_plotly_added = True
logger.info(f"Generated {len(test_plots)} regression test plots")
except Exception as e:
logger.warning(f"Could not generate regression test plots: {e}")
@@ -1104,46 +1451,42 @@
train_set_metadata_path=str(train_set_metadata_path)
if train_set_metadata_path.exists()
else None,
+ threshold=threshold_value,
)
- for plot in interactive_plots:
- tab3_content += (
- f"{plot['title']}
"
- f"{plot['html']}
"
- )
+ tab3_content = append_plot_blocks(tab3_content, interactive_plots)
if interactive_plots:
- test_plotly_added = True
- logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
+ logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
except Exception as e:
logger.warning(f"Could not generate Plotly plots: {e}")
- # Add prediction diagnostics from predictions.csv
- predictions_csv_path = exp_dir / "predictions.csv"
- try:
- diag_plots = build_prediction_diagnostics(
- str(predictions_csv_path),
- label_data_path=str(config.get("label_column_data_path"))
- if config.get("label_column_data_path")
- else None,
- threshold=config.get("threshold"),
- )
- for plot in diag_plots:
- tab3_content += (
- f"{plot['title']}
"
- f"{plot['html']}
"
+ # Multi-class transparency plots from test stats (replace ROC/PR for multi-class)
+ if output_type == "category" and test_stats_path.exists():
+ try:
+ multi_curves = build_multiclass_metric_plots(str(test_stats_path))
+ tab3_content = append_plot_blocks(tab3_content, multi_curves)
+ if multi_curves:
+ logger.info("Added multi-class per-class metric plots to test tab")
+ except Exception as e:
+ logger.warning(f"Could not generate multi-class metric plots: {e}")
+
+ # Test diagnostics (confidence histogram) from predictions.csv, using split=2
+ if predictions_csv_path.exists():
+ try:
+ test_diag_plots = build_prediction_diagnostics(
+ str(predictions_csv_path),
+ label_data_path=str(config.get("label_column_data_path"))
+ if config.get("label_column_data_path")
+ else None,
+ split_value=2,
)
- if diag_plots:
- test_plotly_added = True
- logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
- except Exception as e:
- logger.warning(f"Could not generate prediction diagnostics: {e}")
-
- # Fallback: include static PNGs if no interactive plots were added
- if not test_plotly_added:
- tab3_content += render_img_section(
- "Test Visualizations (PNG fallback)",
- test_viz_dir,
- output_type,
- )
+ test_conf_plots = [
+ p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+ ]
+ if test_conf_plots:
+ tab3_content = append_plot_blocks(tab3_content, test_conf_plots)
+ logger.info("Added test prediction confidence plot")
+ except Exception as e:
+ logger.warning(f"Could not generate test diagnostics: {e}")
# Add static TEST PNGs (with default dedupe/exclusions)
tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
diff -r 8729f69e9207 -r db9be962dc13 plotly_plots.py
--- a/plotly_plots.py Wed Dec 03 01:28:52 2025 +0000
+++ b/plotly_plots.py Wed Dec 10 00:24:13 2025 +0000
@@ -7,6 +7,17 @@
import plotly.graph_objects as go
import plotly.io as pio
from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
+from sklearn.metrics import (
+ accuracy_score,
+ auc,
+ average_precision_score,
+ f1_score,
+ precision_recall_curve,
+ precision_score,
+ recall_score,
+ roc_curve,
+)
+from sklearn.preprocessing import label_binarize
def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
@@ -21,6 +32,64 @@
return fig
+def _fig_to_html(
+ fig: go.Figure, *, include_js: bool = False, config: Optional[dict] = None
+) -> str:
+ """Render a Plotly figure to a lightweight HTML fragment."""
+ include_plotlyjs = "cdn" if include_js else False
+ return pio.to_html(
+ fig,
+ full_html=False,
+ include_plotlyjs=include_plotlyjs,
+ config=config,
+ )
+
+
+def _wrap_plot(
+ title: str,
+ fig: go.Figure,
+ *,
+ include_js: bool = False,
+ config: Optional[dict] = None,
+) -> Dict[str, str]:
+ """Package a figure with its title for downstream HTML rendering."""
+ return {"title": title, "html": _fig_to_html(fig, include_js=include_js, config=config)}
+
+
+def _line_chart(
+ traces: List[tuple],
+ *,
+ title: str,
+ yaxis_title: str,
+) -> go.Figure:
+ """Build a basic epoch-indexed line chart for train/val/test curves."""
+ fig = go.Figure()
+ for name, series in traces:
+ if not series:
+ continue
+ epochs = list(range(1, len(series) + 1))
+ fig.add_trace(
+ go.Scatter(
+ x=epochs,
+ y=series,
+ mode="lines+markers",
+ name=name,
+ line=dict(width=4),
+ )
+ )
+
+ fig.update_layout(
+ title=dict(text=title, x=0.5),
+ xaxis_title="Epoch",
+ yaxis_title=yaxis_title,
+ width=760,
+ height=520,
+ hovermode="x unified",
+ )
+ _style_fig(fig)
+ return fig
+
+
def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
"""Extract ordered label names from Ludwig train_set_metadata."""
if not isinstance(meta_dict, dict):
@@ -106,6 +175,7 @@
training_stats_path: Optional[str] = None,
metadata_csv_path: Optional[str] = None,
train_set_metadata_path: Optional[str] = None,
+ threshold: Optional[float] = None,
) -> List[Dict[str, str]]:
"""
Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -156,8 +226,11 @@
)
)
fig_cm.update_traces(xgap=2, ygap=2)
+ cm_title = "Confusion Matrix"
+ if threshold is not None:
+ cm_title = f"Confusion Matrix (Threshold: {threshold})"
fig_cm.update_layout(
- title=dict(text="Confusion Matrix", x=0.5),
+ title=dict(text=cm_title, x=0.5),
xaxis_title="Predicted",
yaxis_title="Observed",
yaxis_autorange="reversed",
@@ -196,25 +269,19 @@
yshift=-2,
)
- plots.append({
- "title": "Confusion Matrix",
- "html": pio.to_html(
- fig_cm,
- full_html=False,
- include_plotlyjs="cdn",
- config=common_cfg
- )
- })
+ plots.append(
+ _wrap_plot("Confusion Matrix", fig_cm, include_js=True, config=common_cfg)
+ )
- # 1) ROC Curve (from test_statistics)
- roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
- if roc_plot:
- plots.append(roc_plot)
+ # 1) ROC / PR curves only for binary tasks
+ if n_classes == 2:
+ roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
+ if roc_plot:
+ plots.append(roc_plot)
- # 2) Precision-Recall Curve (from test_statistics)
- pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
- if pr_plot:
- plots.append(pr_plot)
+ pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
+ if pr_plot:
+ plots.append(pr_plot)
# 2) Classification Report Heatmap
pcs = label_stats.get("per_class_stats", {})
@@ -259,15 +326,9 @@
margin=dict(t=80, l=80, r=80, b=80),
)
_style_fig(fig_cr)
- plots.append({
- "title": "Per-Class metrics",
- "html": pio.to_html(
- fig_cr,
- full_html=False,
- include_plotlyjs=False,
- config=common_cfg
- )
- })
+ plots.append(
+ _wrap_plot("Per-Class metrics", fig_cr, config=common_cfg)
+ )
# 3) Prediction Diagnostics (from predictions.csv)
# Note: appended separately in generate_html_report, not returned here.
@@ -294,8 +355,6 @@
include_js = True # Load Plotly.js once for this group
def _get_series(stats: dict, metric: str) -> List[float]:
- if metric not in stats:
- return []
vals = stats.get(metric, [])
if isinstance(vals, list):
return [float(v) for v in vals]
@@ -304,181 +363,98 @@
except Exception:
return []
- def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
- train_series = _get_series(label_train, metric_key)
- val_series = _get_series(label_val, metric_key)
+ metric_specs = [
+ ("loss", "Loss across epochs", "Loss"),
+ ("accuracy", "Accuracy across epochs", "Accuracy"),
+ ("roc_auc", "ROC-AUC across epochs", "ROC-AUC"),
+ ("precision", "Precision across epochs", "Precision"),
+ ("recall", "Recall/Sensitivity across epochs", "Recall"),
+ ("specificity", "Specificity across epochs", "Specificity"),
+ ]
+
+ for key, title, yaxis in metric_specs:
+ train_series = _get_series(label_train, key)
+ val_series = _get_series(label_val, key)
if not train_series and not val_series:
- return None
- epochs_train = list(range(1, len(train_series) + 1))
- epochs_val = list(range(1, len(val_series) + 1))
- fig = go.Figure()
- if train_series:
- fig.add_trace(
- go.Scatter(
- x=epochs_train,
- y=train_series,
- mode="lines+markers",
- name="Train",
- line=dict(width=4),
- )
- )
- if val_series:
- fig.add_trace(
- go.Scatter(
- x=epochs_val,
- y=val_series,
- mode="lines+markers",
- name="Validation",
- line=dict(width=4),
- )
- )
- fig.update_layout(
- title=dict(text=title, x=0.5),
- xaxis_title="Epoch",
- yaxis_title=yaxis_title,
- width=760,
- height=520,
- hovermode="x unified",
+ continue
+ fig = _line_chart(
+ [("Train", train_series), ("Validation", val_series)],
+ title=title,
+ yaxis_title=yaxis,
)
- _style_fig(fig)
- return {
- "title": title,
- "html": pio.to_html(
- fig,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- }
-
- # Core learning curves
- for key, title in [
- ("roc_auc", "ROC-AUC across epochs"),
- ("precision", "Precision across epochs"),
- ("recall", "Recall/Sensitivity across epochs"),
- ("specificity", "Specificity across epochs"),
- ]:
- plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
- if plot:
- plots.append(plot)
- include_js = False
+ plots.append(_wrap_plot(title, fig, include_js=include_js))
+ include_js = False
# Precision vs Recall evolution (validation)
val_prec = _get_series(label_val, "precision")
val_rec = _get_series(label_val, "recall")
if val_prec and val_rec:
- epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
- fig_pr = go.Figure()
- fig_pr.add_trace(
- go.Scatter(
- x=epochs,
- y=val_prec[: len(epochs)],
- mode="lines+markers",
- name="Precision",
- )
- )
- fig_pr.add_trace(
- go.Scatter(
- x=epochs,
- y=val_rec[: len(epochs)],
- mode="lines+markers",
- name="Recall",
- )
+ max_len = min(len(val_prec), len(val_rec))
+ fig_pr = _line_chart(
+ [
+ ("Precision", val_prec[:max_len]),
+ ("Recall", val_rec[:max_len]),
+ ],
+ title="Validation Precision and Recall by Epoch",
+ yaxis_title="Value",
)
- fig_pr.update_layout(
- title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
- xaxis_title="Epoch",
- yaxis_title="Value",
- width=760,
- height=520,
- hovermode="x unified",
- )
- _style_fig(fig_pr)
- plots.append({
- "title": "Precision vs Recall Evolution",
- "html": pio.to_html(
- fig_pr,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- })
+ plots.append(_wrap_plot("Precision vs Recall Evolution", fig_pr, include_js=include_js))
include_js = False
- # F1-score derived
def _compute_f1(p: List[float], r: List[float]) -> List[float]:
- f1_vals = []
- for prec, rec in zip(p, r):
- if (prec + rec) == 0:
- f1_vals.append(0.0)
- else:
- f1_vals.append(2 * prec * rec / (prec + rec))
- return f1_vals
+ return [
+ 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+ for prec, rec in zip(p, r)
+ ]
f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
f1_val = _compute_f1(val_prec, val_rec)
if f1_train or f1_val:
- fig = go.Figure()
- if f1_train:
- fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
- if f1_val:
- fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
- fig.update_layout(
- title=dict(text="F1-Score across epochs (derived)", x=0.5),
- xaxis_title="Epoch",
+ fig_f1 = _line_chart(
+ [("Train", f1_train), ("Validation", f1_val)],
+ title="F1-Score across epochs (derived)",
yaxis_title="F1-Score",
- width=760,
- height=520,
- hovermode="x unified",
)
- _style_fig(fig)
- plots.append({
- "title": "F1-Score across epochs (derived)",
- "html": pio.to_html(
- fig,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- })
+ plots.append(_wrap_plot("F1-Score across epochs (derived)", fig_f1, include_js=include_js))
include_js = False
# Overfitting Gap: Train vs Val ROC-AUC (gap)
roc_train = _get_series(label_train, "roc_auc")
roc_val = _get_series(label_val, "roc_auc")
if roc_train and roc_val:
- epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
- gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
- fig_gap = go.Figure()
- fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
- fig_gap.update_layout(
- title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
- xaxis_title="Epoch",
+ max_len = min(len(roc_train), len(roc_val))
+ gaps = [t - v for t, v in zip(roc_train[:max_len], roc_val[:max_len])]
+ fig_gap = _line_chart(
+ [("Train - Val ROC-AUC", gaps)],
+ title="Overfitting gap: ROC-AUC across epochs",
yaxis_title="Gap",
- width=760,
- height=520,
- hovermode="x unified",
)
- _style_fig(fig_gap)
- plots.append({
- "title": "Overfitting gap: ROC-AUC across epochs",
- "html": pio.to_html(
- fig_gap,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- })
+ plots.append(_wrap_plot("Overfitting gap: ROC-AUC across epochs", fig_gap, include_js=include_js))
include_js = False
# Best Epoch Dashboard (based on max val ROC-AUC)
if roc_val:
best_idx = int(np.argmax(roc_val))
best_epoch = best_idx + 1
- spec_val = _get_series(label_val, "specificity")
- metrics_at_best = {
- "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
- "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
- "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
- "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
- "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+ metrics_at_best: Dict[str, Optional[float]] = {
+ "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None
}
+
+ for metric_key, label in [
+ ("accuracy", "Accuracy"),
+ ("balanced_accuracy", "Balanced Accuracy"),
+ ("precision", "Precision"),
+ ("recall", "Recall"),
+ ("specificity", "Specificity"),
+ ("loss", "Loss"),
+ ]:
+ series = _get_series(label_val, metric_key)
+ if series and best_idx < len(series):
+ metrics_at_best[label] = series[best_idx]
+
+ if f1_val and best_idx < len(f1_val):
+ metrics_at_best["F1-Score (derived)"] = f1_val[best_idx]
+
fig_best = go.Figure()
for name, value in metrics_at_best.items():
if value is not None:
@@ -492,15 +468,7 @@
showlegend=False,
)
_style_fig(fig_best)
- plots.append({
- "title": "Best Validation Epoch Snapshot (Metrics)",
- "html": pio.to_html(
- fig_best,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- })
- include_js = False
+ plots.append(_wrap_plot("Best Validation Epoch Snapshot (Metrics)", fig_best, include_js=include_js))
return plots
@@ -529,46 +497,13 @@
val_series = _get_regression_series(val_split, metric_key)
if not train_series and not val_series:
return None
- epochs_train = list(range(1, len(train_series) + 1))
- epochs_val = list(range(1, len(val_series) + 1))
- fig = go.Figure()
- if train_series:
- fig.add_trace(
- go.Scatter(
- x=epochs_train,
- y=train_series,
- mode="lines+markers",
- name="Train",
- line=dict(width=4),
- )
- )
- if val_series:
- fig.add_trace(
- go.Scatter(
- x=epochs_val,
- y=val_series,
- mode="lines+markers",
- name="Validation",
- line=dict(width=4),
- )
- )
- fig.update_layout(
- title=dict(text=title, x=0.5),
- xaxis_title="Epoch",
+
+ fig = _line_chart(
+ [("Train", train_series), ("Validation", val_series)],
+ title=title,
yaxis_title=yaxis_title,
- width=760,
- height=520,
- hovermode="x unified",
)
- _style_fig(fig)
- return {
- "title": title,
- "html": pio.to_html(
- fig,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- }
+ return _wrap_plot(title, fig, include_js=include_js)
def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
@@ -627,46 +562,25 @@
("r2", "R² Across Epochs", "R²"),
("loss", "Loss Across Epochs", "Loss"),
]
- epochs = None
for metric_key, title, ytitle in metrics:
series = _get_regression_series(label_test, metric_key)
if not series:
continue
- if epochs is None:
- epochs = list(range(1, len(series) + 1))
- fig = go.Figure()
- fig.add_trace(
- go.Scatter(
- x=epochs,
- y=series[: len(epochs)],
- mode="lines+markers",
- name="Test",
- line=dict(width=4),
- )
+ fig = _line_chart(
+ [("Test", series)],
+ title=title,
+ yaxis_title=ytitle,
)
- fig.update_layout(
- title=dict(text=title, x=0.5),
- xaxis_title="Epoch",
- yaxis_title=ytitle,
- width=760,
- height=520,
- hovermode="x unified",
- )
- _style_fig(fig)
- plots.append({
- "title": title,
- "html": pio.to_html(
- fig,
- full_html=False,
- include_plotlyjs="cdn" if include_js else False,
- ),
- })
+ plots.append(_wrap_plot(title, fig, include_js=include_js))
include_js = False
return plots
def _build_static_roc_plot(
- label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+ label_stats: dict,
+ config: dict,
+ friendly_labels: Optional[List[str]] = None,
+ threshold: Optional[float] = None,
) -> Optional[Dict[str, str]]:
"""Build ROC curve directly from test_statistics.json (single curve)."""
roc_data = label_stats.get("roc_curve")
@@ -776,6 +690,42 @@
fig.update_xaxes(range=[0, 1.0])
fig.update_yaxes(range=[0, 1.05])
+ roc_thresholds = roc_data.get("thresholds")
+ if threshold is not None and isinstance(roc_thresholds, list) and len(roc_thresholds) == len(fpr):
+ try:
+ diffs = [abs(th - threshold) for th in roc_thresholds]
+ best_idx = int(np.argmin(diffs))
+ # dashed guides through the chosen point
+ fig.add_shape(
+ type="line",
+ x0=fpr[best_idx],
+ x1=fpr[best_idx],
+ y0=0,
+ y1=tpr[best_idx],
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ fig.add_shape(
+ type="line",
+ x0=0,
+ x1=fpr[best_idx],
+ y0=tpr[best_idx],
+ y1=tpr[best_idx],
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=[fpr[best_idx]],
+ y=[tpr[best_idx]],
+ mode="markers",
+ marker=dict(color="black", size=10, symbol="x"),
+ name=f"Threshold={threshold}",
+ hovertemplate="FPR: %{x:.3f}
TPR: %{y:.3f}
Threshold: %{text}",
+ text=[f"{threshold}"],
+ )
+ )
+ except Exception as exc:
+ print(f"Warning: could not add threshold marker to ROC: {exc}")
+
fig.add_annotation(
x=0.5,
y=-0.15,
@@ -786,21 +736,17 @@
xanchor="center",
)
- return {
- "title": "ROC Curve",
- "html": pio.to_html(
- fig,
- full_html=False,
- include_plotlyjs=False,
- config=config,
- ),
- }
+ return _wrap_plot("ROC Curve", fig, config=config)
except Exception as e:
print(f"Error building ROC plot: {e}")
return None
-def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+def _build_precision_recall_plot(
+ label_stats: dict,
+ config: dict,
+ threshold: Optional[float] = None,
+) -> Optional[Dict[str, str]]:
"""Build Precision-Recall curve directly from test_statistics.json."""
pr_data = label_stats.get("precision_recall_curve")
if not isinstance(pr_data, dict):
@@ -811,6 +757,8 @@
if not precisions or not recalls or len(precisions) != len(recalls):
return None
+ thresholds = pr_data.get("thresholds")
+
try:
fig = go.Figure()
fig.add_trace(
@@ -851,15 +799,41 @@
fig.update_xaxes(range=[0, 1.0])
fig.update_yaxes(range=[0, 1.05])
- return {
- "title": "Precision-Recall Curve",
- "html": pio.to_html(
- fig,
- full_html=False,
- include_plotlyjs=False,
- config=config,
- ),
- }
+ if threshold is not None and isinstance(thresholds, list) and len(thresholds) == len(recalls):
+ try:
+ diffs = [abs(th - threshold) for th in thresholds]
+ best_idx = int(np.argmin(diffs))
+ fig.add_shape(
+ type="line",
+ x0=recalls[best_idx],
+ x1=recalls[best_idx],
+ y0=0,
+ y1=precisions[best_idx],
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ fig.add_shape(
+ type="line",
+ x0=0,
+ x1=recalls[best_idx],
+ y0=precisions[best_idx],
+ y1=precisions[best_idx],
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ fig.add_trace(
+ go.Scatter(
+ x=[recalls[best_idx]],
+ y=[precisions[best_idx]],
+ mode="markers",
+ marker=dict(color="black", size=10, symbol="x"),
+ name=f"Threshold={threshold}",
+ hovertemplate="Recall: %{x:.3f}
Precision: %{y:.3f}
Threshold: %{text}",
+ text=[f"{threshold}"],
+ )
+ )
+ except Exception as exc:
+ print(f"Warning: could not add threshold marker to PR: {exc}")
+
+ return _wrap_plot("Precision-Recall Curve", fig, config=config)
except Exception as e:
print(f"Error building Precision-Recall plot: {e}")
return None
@@ -869,7 +843,6 @@
predictions_path: str,
label_data_path: Optional[str] = None,
split_value: int = 2,
- threshold: Optional[float] = None,
) -> List[Dict[str, str]]:
"""Generate diagnostic plots from predictions.csv for classification tasks."""
preds_file = Path(predictions_path)
@@ -883,12 +856,89 @@
return []
plots: List[Dict[str, str]] = []
+ labels_from_dataset: Optional[pd.Series] = None
+
+ filtered_by_split = False
+
+ # If a split column exists, focus on the requested split (e.g., validation=1, test=2).
+ # If not, but label_data_path is available and matches row count, use it to filter predictions.
+ if SPLIT_COLUMN_NAME in df_pred.columns:
+ df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+ if df_pred.empty:
+ return []
+ filtered_by_split = True
+ elif label_data_path and Path(label_data_path).exists():
+ try:
+ df_labels_all = pd.read_csv(label_data_path)
+ if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_pred):
+ split_mask = pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == split_value
+ labels_from_dataset = df_labels_all.loc[split_mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+ df_pred = df_pred.loc[split_mask].reset_index(drop=True)
+ if df_pred.empty:
+ return []
+ filtered_by_split = True
+ except Exception as exc:
+ print(f"Warning: Unable to filter predictions by split from label data: {exc}")
+
+ # Fallback: no split info available. Assume the predictions file is already filtered
+ # (common for test-only exports) and avoid heuristic slicing that could discard rows.
+ if not filtered_by_split:
+ if split_value != 2:
+ return []
+
+ def _strip_prob_prefix(col: str) -> str:
+ if col.startswith("label_probabilities_"):
+ return col.replace("label_probabilities_", "")
+ if col.startswith("probabilities_"):
+ return col.replace("probabilities_", "")
+ return col
+
+ def _maybe_expand_probabilities_column(df: pd.DataFrame, labels_guess: List[str]) -> List[str]:
+ """If only a single 'probabilities' column exists (list-like), expand it into per-class columns."""
+ if "probabilities" not in df.columns:
+ return []
+ try:
+ # Parse first non-null entry to infer length
+ first_val = df["probabilities"].dropna().iloc[0]
+ parsed = first_val
+ if isinstance(first_val, str):
+ parsed = json.loads(first_val)
+ probs = list(parsed)
+ n = len(probs)
+ if n == 0:
+ return []
+ # Build labels: prefer provided guess; otherwise numeric
+ if labels_guess and len(labels_guess) == n:
+ labels_use = labels_guess
+ else:
+ labels_use = [str(i) for i in range(n)]
+ # Expand column
+ for idx, lbl in enumerate(labels_use):
+ df[f"probabilities_{lbl}"] = df["probabilities"].apply(
+ lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+ )
+ return [f"probabilities_{lbl}" for lbl in labels_use]
+ except Exception:
+ return []
# Identify probability columns
prob_cols = [
- c for c in df_pred.columns
- if c.startswith("label_probabilities_") and c != "label_probabilities"
+ c
+ for c in df_pred.columns
+ if (
+ (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+ and c != "label_probabilities"
+ )
]
+ if not prob_cols and "label_probability" in df_pred.columns:
+ prob_cols = ["label_probability"]
+ if not prob_cols and "probability" in df_pred.columns:
+ prob_cols = ["probability"]
+ if not prob_cols and "prediction_probability" in df_pred.columns:
+ prob_cols = ["prediction_probability"]
+ if not prob_cols and "probabilities" in df_pred.columns:
+ labels_guess = sorted([str(u) for u in pd.unique(df_pred[LABEL_COLUMN_NAME])])
+ prob_cols = _maybe_expand_probabilities_column(df_pred, labels_guess)
prob_cols_sorted = sorted(prob_cols)
def _select_positive_prob():
@@ -897,14 +947,14 @@
# Prefer a column indicating positive/event/true/1
preferred_keys = ("event", "true", "positive", "pos", "1")
for col in prob_cols_sorted:
- suffix = col.replace("label_probabilities_", "").lower()
+ suffix = _strip_prob_prefix(col).lower()
if any(k in suffix for k in preferred_keys):
return col, suffix
if len(prob_cols_sorted) == 2:
col = prob_cols_sorted[1]
- return col, col.replace("label_probabilities_", "")
+ return col, _strip_prob_prefix(col)
col = prob_cols_sorted[0]
- return col, col.replace("label_probabilities_", "")
+ return col, _strip_prob_prefix(col)
pos_prob_col, pos_label_hint = _select_positive_prob()
pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
@@ -920,6 +970,8 @@
# True labels
def _extract_labels():
+ if labels_from_dataset is not None:
+ return labels_from_dataset
candidates = [
LABEL_COLUMN_NAME,
f"{LABEL_COLUMN_NAME}_ground_truth",
@@ -975,10 +1027,7 @@
height=500,
)
_style_fig(fig_conf)
- plots.append({
- "title": "Prediction Confidence Distribution",
- "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
- })
+ plots.append(_wrap_plot("Prediction Confidence Distribution", fig_conf))
# The remaining plots require true labels and a positive-class probability
if labels_series is None or pos_prob_series is None:
@@ -1004,116 +1053,470 @@
y_true = (y_true_raw == positive_label).astype(int).values
- # Plot 2: Calibration Curve
- bins = np.linspace(0.0, 1.0, 11)
- bin_ids = np.digitize(y_score, bins, right=True)
- bin_centers = []
- frac_positives = []
- for b in range(1, len(bins)):
- mask = bin_ids == b
- if not np.any(mask):
- continue
- bin_centers.append(y_score[mask].mean())
- frac_positives.append(y_true[mask].mean())
- if bin_centers and frac_positives:
- fig_cal = go.Figure()
- fig_cal.add_trace(
+ # Utility: compute calibration points
+ def _calibration_points(y_true_bin: np.ndarray, scores: np.ndarray):
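+        # Ten equal-width probability bins; for each non-empty bin return the mean
+        # predicted score and the observed positive rate.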
+ bins = np.linspace(0.0, 1.0, 11)
+ bin_ids = np.digitize(scores, bins, right=True)
+ bin_centers, frac_positives = [], []
+ for b in range(1, len(bins)):
+ mask = bin_ids == b
+ if not np.any(mask):
+ continue
+ bin_centers.append(scores[mask].mean())
+ frac_positives.append(y_true_bin[mask].mean())
+ return bin_centers, frac_positives
+
+ # Plot 2: Calibration Curve (multi-class aware; one-vs-rest per label)
+ label_prob_map = {}
+ for col in prob_cols_sorted:
+ if col.startswith("label_probabilities_"):
+ cls = col.replace("label_probabilities_", "")
+ label_prob_map[cls] = col
+
+ unique_label_strs = [str(u) for u in unique_labels_list]
+ if len(label_prob_map) > 1 and len(unique_label_strs) > 2:
+ # Skip multi-class calibration curve for now (not informative in current report)
+ pass
+ else:
+ # Binary/unknown fallback (previous behavior)
+ bin_centers, frac_positives = _calibration_points(y_true, y_score)
+ if bin_centers and frac_positives:
+ fig_cal = go.Figure()
+ fig_cal.add_trace(
+ go.Scatter(
+ x=bin_centers,
+ y=frac_positives,
+ mode="lines+markers",
+ name="Calibration",
+ line=dict(color="#2ca02c", width=4),
+ )
+ )
+ fig_cal.add_trace(
+ go.Scatter(
+ x=[0, 1],
+ y=[0, 1],
+ mode="lines",
+ name="Perfect Calibration",
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ )
+ fig_cal.update_layout(
+ title=dict(text="Calibration Curve", x=0.5),
+ xaxis_title="Predicted probability",
+ yaxis_title="Observed frequency",
+ width=700,
+ height=500,
+ )
+ _style_fig(fig_cal)
+ plots.append(
+ _wrap_plot(
+ "Calibration Curve (Predicted Probability vs Observed Frequency)",
+ fig_cal,
+ )
+ )
+
+ return plots
+
+
+def build_binary_threshold_plot(
+ predictions_path: str,
+ label_data_path: Optional[str] = None,
+ split_value: int = 1,
+) -> Optional[Dict[str, str]]:
+ """Build a binary threshold sweep plot (accuracy, precision, recall, F1) for a given split."""
+ preds_file = Path(predictions_path)
+ if not preds_file.exists():
+ return None
+
+ try:
+ df_pred = pd.read_csv(predictions_path)
+ except Exception as exc:
+ print(f"Warning: Unable to read predictions CSV for threshold plot: {exc}")
+ return None
+
+ labels_from_dataset: Optional[pd.Series] = None
+ df_full = df_pred.copy()
+
+ def _filter_by_split(df: pd.DataFrame, split_val: int) -> pd.DataFrame:
+ if SPLIT_COLUMN_NAME in df.columns:
+ return df[df[SPLIT_COLUMN_NAME] == split_val].reset_index(drop=True)
+ return df
+
+ # Try preferred split, then fallback to others with data (val -> test -> train)
+ candidate_splits = [split_value, 2, 0, 1] if split_value == 1 else [split_value, 1, 0, 2]
+ df_candidate = pd.DataFrame()
+ used_split: Optional[int] = None
+ for sv in candidate_splits:
+ df_candidate = _filter_by_split(df_full, sv)
+ if not df_candidate.empty:
+ used_split = sv
+ break
+ if used_split is None:
+ df_candidate = df_full
+ df_pred = df_candidate.reset_index(drop=True)
+
+ # If still empty (e.g., split column exists but no rows for candidates), fall back to all rows
+ if df_pred.empty:
+ df_pred = df_full.reset_index(drop=True)
+ labels_from_dataset = None
+
+ if label_data_path and Path(label_data_path).exists():
+ try:
+ df_labels_all = pd.read_csv(label_data_path)
+ if SPLIT_COLUMN_NAME in df_labels_all.columns and len(df_labels_all) == len(df_full):
+ mask = (
+ pd.to_numeric(df_labels_all[SPLIT_COLUMN_NAME], errors="coerce") == used_split
+ if used_split is not None and SPLIT_COLUMN_NAME in df_labels_all.columns
+ else pd.Series([True] * len(df_full))
+ )
+ labels_from_dataset = df_labels_all.loc[mask, LABEL_COLUMN_NAME].reset_index(drop=True)
+ if len(labels_from_dataset) == len(df_pred):
+ labels_from_dataset = labels_from_dataset.reset_index(drop=True)
+ except Exception as exc:
+ print(f"Warning: Unable to align labels for threshold plot: {exc}")
+
+ # Identify probability columns
+ prob_cols = [
+ c
+ for c in df_pred.columns
+ if (
+ (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+ and c != "label_probabilities"
+ )
+ ]
+ if not prob_cols and "probabilities" in df_pred.columns:
+ labels_guess = sorted([str(u) for u in pd.unique(df_pred.get(LABEL_COLUMN_NAME, []))])
+ # reuse expansion logic from diagnostics
+ try:
+ first_val = df_pred["probabilities"].dropna().iloc[0]
+ parsed = json.loads(first_val) if isinstance(first_val, str) else list(first_val)
+ n = len(parsed)
+ if n > 0:
+ if labels_guess and len(labels_guess) == n:
+ labels_use = labels_guess
+ else:
+ labels_use = [str(i) for i in range(n)]
+ for idx, lbl in enumerate(labels_use):
+ df_pred[f"probabilities_{lbl}"] = df_pred["probabilities"].apply(
+ lambda v: (json.loads(v)[idx] if isinstance(v, str) else list(v)[idx]) if pd.notnull(v) else np.nan
+ )
+ prob_cols = [f"probabilities_{lbl}" for lbl in labels_use]
+ except Exception:
+ prob_cols = []
+ prob_cols_sorted = sorted(prob_cols)
+
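+ # Helper to recover the class name from a probability column name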
+ def _strip_prob_prefix(col: str) -> str:
+ if col.startswith("label_probabilities_"):
+ return col.replace("label_probabilities_", "")
+ if col.startswith("probabilities_"):
+ return col.replace("probabilities_", "")
+ return col
+
+ # Ground-truth labels: prefer dataset-aligned labels, then known label columns, finally predictions as a last resort
+ def _extract_labels():
+ if labels_from_dataset is not None:
+ return labels_from_dataset
+ for col in [
+ LABEL_COLUMN_NAME,
+ f"{LABEL_COLUMN_NAME}_ground_truth",
+ f"{LABEL_COLUMN_NAME}__ground_truth",
+ f"{LABEL_COLUMN_NAME}_target",
+ f"{LABEL_COLUMN_NAME}__target",
+ "label",
+ "label_true",
+ "label_predictions",
+ "prediction",
+ ]:
+ if col in df_pred.columns and col not in prob_cols_sorted:
+ return df_pred[col]
+ return None
+
+ labels_series = _extract_labels()
+ if labels_series is None or not prob_cols_sorted:
+ return None
+
+ # Select the positive-class probability column, preferring names that hint at the positive class
+ preferred_keys = ("event", "true", "positive", "pos", "1")
+ pos_prob_col = None
+ for col in prob_cols_sorted:
+ suffix = _strip_prob_prefix(col).lower()
+ if any(k in suffix for k in preferred_keys):
+ pos_prob_col = col
+ break
+ if pos_prob_col is None:
+ pos_prob_col = prob_cols_sorted[-1]
+
+ min_len = min(len(labels_series), len(df_pred[pos_prob_col]))
+ if min_len == 0:
+ return None
+
+ y_true = np.array(labels_series.iloc[:min_len])
+ # Map labels to binary 0/1 against the inferred positive label
+ unique_labels = pd.unique(y_true)
+ if len(unique_labels) < 2:
+ return None
+ positive_label = unique_labels[1]  # at least two classes guaranteed by the check above
+ y_true_bin = (y_true == positive_label).astype(int)
+ y_score = np.array(df_pred[pos_prob_col].iloc[:min_len], dtype=float)
+
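+ # Sweep 101 evenly spaced thresholds and record accuracy, precision, recall and F1 at each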
+ thresholds = np.linspace(0.0, 1.0, 101)
+ accs: List[float] = []
+ precs: List[float] = []
+ recs: List[float] = []
+ f1s: List[float] = []
+ for t in thresholds:
+ preds = (y_score >= t).astype(int)
+ accs.append(accuracy_score(y_true_bin, preds))
+ precs.append(precision_score(y_true_bin, preds, zero_division=0))
+ recs.append(recall_score(y_true_bin, preds, zero_division=0))
+ f1s.append(f1_score(y_true_bin, preds, zero_division=0))
+
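+ # Mark the threshold that maximizes F1 with a dashed vertical line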
+ best_idx = int(np.argmax(f1s))
+ best_thr = thresholds[best_idx]
+
+ fig = go.Figure()
+ fig.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+ fig.add_trace(go.Scatter(x=thresholds, y=precs, mode="lines", name="Precision", line=dict(width=4)))
+ fig.add_trace(go.Scatter(x=thresholds, y=recs, mode="lines", name="Recall", line=dict(width=4)))
+ fig.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1-Score", line=dict(width=4)))
+ fig.add_shape(
+ type="line",
+ x0=best_thr,
+ x1=best_thr,
+ y0=0,
+ y1=1,
+ line=dict(color="gray", width=2, dash="dash"),
+ )
+ fig.update_layout(
+ title=dict(text="Threshold plot", x=0.5),
+ xaxis_title="Threshold",
+ yaxis_title="Metric value",
+ yaxis=dict(range=[0, 1]),
+ width=760,
+ height=520,
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
+ )
+ _style_fig(fig)
+ return _wrap_plot("Threshold Plot", fig, include_js=True)
+
+
+def build_multiclass_roc_pr_plots(
+ predictions_path: str,
+ split_value: int = 2,
+) -> List[Dict[str, str]]:
+ """Build one-vs-rest ROC and PR curves for multi-class classification from predictions."""
+ preds_file = Path(predictions_path)
+ if not preds_file.exists():
+ return []
+ try:
+ df_pred = pd.read_csv(predictions_path)
+ except Exception as exc:
+ print(f"Warning: Unable to read predictions CSV: {exc}")
+ return []
+
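+ # Restrict rows to the requested split when a split column is present (2 = test by default)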
+ if SPLIT_COLUMN_NAME in df_pred.columns:
+ df_pred = df_pred[df_pred[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+ if df_pred.empty:
+ return []
+
+ if LABEL_COLUMN_NAME not in df_pred.columns:
+ return []
+
+ # Identify per-class probability columns
+ prob_cols = [
+ c
+ for c in df_pred.columns
+ if (
+ (c.startswith("label_probabilities_") or c.startswith("probabilities_"))
+ and c != "label_probabilities"
+ )
+ ]
+ if not prob_cols:
+ return []
+ labels = [c.replace("label_probabilities_", "").replace("probabilities_", "") for c in prob_cols]
+ labels_sorted = sorted(labels)
+
+ # Map each class label to its probability column; these plots are only built for three or more classes
+ prob_map = {
+ c.replace("label_probabilities_", "").replace("probabilities_", ""): c
+ for c in prob_cols
+ }
+ if len(labels_sorted) < 3:
+ return []
+
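+ # Build aligned ground-truth and probability matrices, ordered by sorted class label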
+ y_true_raw = df_pred[LABEL_COLUMN_NAME].astype(str)
+ # Drop rows with NaN probabilities across any class to avoid metric errors
+ prob_matrix = df_pred[[prob_map[lbl] for lbl in labels_sorted]].astype(float)
+ mask_valid = ~prob_matrix.isnull().any(axis=1)
+ prob_matrix = prob_matrix[mask_valid]
+ y_true_raw = y_true_raw[mask_valid]
+ if prob_matrix.empty:
+ return []
+
+ y_true_bin = label_binarize(y_true_raw, classes=labels_sorted)
+ y_score = prob_matrix.to_numpy()
+
+ plots: List[Dict[str, str]] = []
+
+ # ROC: one-vs-rest + micro
+ fig_roc = go.Figure()
+ added_any = False
+ for idx, lbl in enumerate(labels_sorted):
+ if y_true_bin[:, idx].sum() == 0 or y_true_bin[:, idx].sum() == len(y_true_bin):
+ continue # skip classes without both positives and negatives
+ fpr, tpr, _ = roc_curve(y_true_bin[:, idx], y_score[:, idx])
+ fig_roc.add_trace(
go.Scatter(
- x=bin_centers,
- y=frac_positives,
- mode="lines+markers",
- name="Calibration",
- line=dict(color="#2ca02c", width=4),
- )
- )
- fig_cal.add_trace(
- go.Scatter(
- x=[0, 1],
- y=[0, 1],
+ x=fpr,
+ y=tpr,
mode="lines",
- name="Perfect Calibration",
- line=dict(color="gray", width=2, dash="dash"),
+ name=f"{lbl} (AUC={auc(fpr, tpr):.3f})",
+ line=dict(width=3),
)
)
- fig_cal.update_layout(
- title=dict(text="Calibration Curve", x=0.5),
- xaxis_title="Predicted probability",
- yaxis_title="Observed frequency",
- width=700,
- height=500,
+ added_any = True
+ # Add the micro-average curve only when both positives and negatives exist overall
+ if y_true_bin.sum() > 0 and y_true_bin.sum() < y_true_bin.size:
+ fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
+ fig_roc.add_trace(
+ go.Scatter(
+ x=fpr_micro,
+ y=tpr_micro,
+ mode="lines",
+ name=f"Micro-average (AUC={auc(fpr_micro, tpr_micro):.3f})",
+ line=dict(width=3, dash="dash"),
+ )
)
- _style_fig(fig_cal)
- plots.append({
- "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
- "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
- })
-
- # Plot 3: Threshold vs Metrics
- thresholds = np.linspace(0.0, 1.0, 21)
- accs, f1s, sens, specs = [], [], [], []
- for t in thresholds:
- y_pred = (y_score >= t).astype(int)
- tp = np.sum((y_true == 1) & (y_pred == 1))
- tn = np.sum((y_true == 0) & (y_pred == 0))
- fp = np.sum((y_true == 0) & (y_pred == 1))
- fn = np.sum((y_true == 1) & (y_pred == 0))
- acc = (tp + tn) / max(len(y_true), 1)
- prec = tp / max(tp + fp, 1e-9)
- rec = tp / max(tp + fn, 1e-9)
- f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
- sensitivity = rec
- specificity = tn / max(tn + fp, 1e-9)
- accs.append(acc)
- f1s.append(f1)
- sens.append(sensitivity)
- specs.append(specificity)
-
- fig_thresh = go.Figure()
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
- fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
- fig_thresh.update_layout(
- title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
- xaxis_title="Decision threshold",
- yaxis_title="Metric value",
- width=700,
- height=500,
+ added_any = True
+ if not added_any:
+ return []
+ fig_roc.add_trace(
+ go.Scatter(
+ x=[0, 1],
+ y=[0, 1],
+ mode="lines",
+ name="Random",
+ line=dict(color="gray", width=2, dash="dot"),
+ )
+ )
+ fig_roc.update_layout(
+ title=dict(text="Multi-class ROC-AUC (one-vs-rest)", x=0.5),
+ xaxis_title="False Positive Rate",
+ yaxis_title="True Positive Rate",
+ width=820,
+ height=620,
legend=dict(
- x=0.7,
- y=0.2,
+ x=0.62,
+ y=0.05,
bgcolor="rgba(255,255,255,0.9)",
bordercolor="rgba(0,0,0,0.2)",
borderwidth=1,
),
- shapes=[
- dict(
- type="line",
- x0=threshold,
- x1=threshold,
- y0=0,
- y1=1,
- xref="x",
- yref="paper",
- line=dict(color="#d62728", width=2, dash="dash"),
+ )
+ _style_fig(fig_roc)
+ plots.append(_wrap_plot("Multi-class ROC-AUC (one-vs-rest)", fig_roc))
+
+ # PR: one-vs-rest + micro AP
+ fig_pr = go.Figure()
+ added_pr = False
+ for idx, lbl in enumerate(labels_sorted):
+ if y_true_bin[:, idx].sum() == 0:
+ continue
+ prec, rec, _ = precision_recall_curve(y_true_bin[:, idx], y_score[:, idx])
+ ap = average_precision_score(y_true_bin[:, idx], y_score[:, idx])
+ fig_pr.add_trace(
+ go.Scatter(
+ x=rec,
+ y=prec,
+ mode="lines",
+ name=f"{lbl} (AP={ap:.3f})",
+ line=dict(width=3),
)
- ] if isinstance(threshold, (int, float)) else [],
- annotations=[
- dict(
- x=threshold,
- y=1.02,
- xref="x",
- yref="paper",
- showarrow=False,
- text=f"Threshold = {threshold:.2f}",
- font=dict(size=11, color="#d62728"),
+ )
+ added_pr = True
+ if y_true_bin.sum() > 0:
+ prec_micro, rec_micro, _ = precision_recall_curve(y_true_bin.ravel(), y_score.ravel())
+ ap_micro = average_precision_score(y_true_bin, y_score, average="micro")
+ fig_pr.add_trace(
+ go.Scatter(
+ x=rec_micro,
+ y=prec_micro,
+ mode="lines",
+ name=f"Micro-average (AP={ap_micro:.3f})",
+ line=dict(width=3, dash="dash"),
)
- ] if isinstance(threshold, (int, float)) else [],
+ )
+ added_pr = True
+ if not added_pr:
+ return plots
+ fig_pr.update_layout(
+ title=dict(text="Multi-class Precision-Recall (one-vs-rest)", x=0.5),
+ xaxis_title="Recall",
+ yaxis_title="Precision",
+ width=820,
+ height=620,
+ legend=dict(
+ x=0.62,
+ y=0.05,
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.2)",
+ borderwidth=1,
+ ),
)
- _style_fig(fig_thresh)
- plots.append({
- "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
- "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
- })
+ _style_fig(fig_pr)
+ plots.append(_wrap_plot("Multi-class Precision-Recall (one-vs-rest)", fig_pr))
return plots
+
+
+def build_multiclass_metric_plots(test_stats_path: str) -> List[Dict[str, str]]:
+ """Alternative multi-class transparency plots using test_statistics.json per-class stats."""
+ ts_path = Path(test_stats_path)
+ if not ts_path.exists():
+ return []
+ try:
+ with open(ts_path, "r") as f:
+ test_stats = json.load(f)
+ except Exception:
+ return []
+
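+ # Per-class statistics live under the "label" output feature in test_statistics.json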
+ label_stats = test_stats.get("label", {})
+ pcs = label_stats.get("per_class_stats", {})
+ if not pcs:
+ return []
+ classes = list(pcs.keys())
+ if not classes:
+ return []
+
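+ # Grouped bar chart: one bar per (class, metric); missing or non-numeric values default to 0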
+ metrics = ["precision", "recall", "f1_score", "specificity", "accuracy"]
+ fig_bar = go.Figure()
+ for metric in metrics:
+ values = []
+ for cls in classes:
+ v = pcs.get(cls, {}).get(metric)
+ values.append(v if isinstance(v, (int, float)) else 0)
+ fig_bar.add_trace(
+ go.Bar(
+ x=classes,
+ y=values,
+ name=metric.replace("_", " ").title(),
+ )
+ )
+ fig_bar.update_layout(
+ title=dict(text="Per-Class Metrics (Test)", x=0.5),
+ xaxis_title="Class",
+ yaxis_title="Metric value",
+ barmode="group",
+ width=900,
+ height=600,
+ legend=dict(
+ x=1.02,
+ y=1.0,
+ bgcolor="rgba(255,255,255,0.9)",
+ bordercolor="rgba(0,0,0,0.2)",
+ borderwidth=1,
+ ),
+ )
+ _style_fig(fig_bar)
+
+ return [_wrap_plot("Per-Class Metrics (Test)", fig_bar)]