changeset 15:d17e3a1b8659 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
author goeckslab
date Fri, 28 Nov 2025 15:45:49 +0000
parents 94cd9ac4a9b1
children
files html_structure.py image_learner.xml image_learner_cli.py image_workflow.py ludwig_backend.py plotly_plots.py
diffstat 6 files changed, 1486 insertions(+), 293 deletions(-) [+]
line wrap: on
line diff
--- a/html_structure.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/html_structure.py	Fri Nov 28 15:45:49 2025 +0000
@@ -22,21 +22,31 @@
     output_type: Optional[str] = None,
 ) -> str:
     display_keys = [
+        "architecture",
+        "pretrained",
+        "trainable",
+        "target_column",
         "task_type",
-        "model_name",
+        "validation_metric",
+        "loss_function",
+        "threshold",
         "epochs",
+        "total_epochs",
         "batch_size",
         "fine_tune",
         "use_pretrained",
         "learning_rate",
+        "optimizer",
         "random_seed",
         "early_stop",
-        "threshold",
+        "use_mixed_precision",
+        "gradient_clipping",
     ]
 
     rows = []
 
     for key in display_keys:
+        val_str = "N/A"
         val = config.get(key, None)
         if key == "threshold":
             if output_type != "binary":
@@ -49,7 +59,7 @@
             if key == "task_type":
                 val_str = val.title() if isinstance(val, str) else "N/A"
             elif key == "batch_size":
-                if val is not None:
+                if isinstance(val, (int, float)):
                     val_str = int(val)
                 else:
                     val = "auto"
@@ -120,6 +130,21 @@
                         )
                     else:
                         val_str = val
+            elif key == "pretrained":
+                if isinstance(val, bool):
+                    val_str = "Yes (ImageNet)" if val else "No"
+                else:
+                    val_str = val if val is not None else "N/A"
+            elif key == "trainable":
+                if isinstance(val, bool):
+                    val_str = "Trainable" if val else "Frozen"
+                else:
+                    val_str = val if val is not None else "N/A"
+            elif key == "use_mixed_precision":
+                if isinstance(val, bool):
+                    val_str = "Yes" if val else "No"
+                else:
+                    val_str = val if val is not None else "N/A"
             else:
                 val_str = val if val is not None else "N/A"
             if val_str == "N/A" and key not in ["task_type"]:
@@ -155,7 +180,7 @@
         )
 
     html = f"""
-        <h2 style="text-align: center;">Model and Training Summary</h2>
+        <h2 style="text-align: center;">Training Configuration (Model, Data, Metrics)</h2>
         <div style="display: flex; justify-content: center;">
           <table style="border-collapse: collapse; width: 100%; table-layout: fixed;">
             <thead><tr>
@@ -519,15 +544,15 @@
 def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str:
     """
     Build a 3-tab interface:
-      - Config and Results Summary
+      - Config and Overall Performance Summary
       - Train/Validation Results
       - Test Results
     Includes a persistent "Help" button that toggles the metrics modal.
     """
     return f"""
 <div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')">Config and Results Summary</div>
-  <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div>
+  <div class="tab active" onclick="showTab('metrics')">Config and Overall Performance Summary</div>
+  <div class="tab" onclick="showTab('trainval')">Training and Validation Results</div>
   <div class="tab" onclick="showTab('test')">Test Results</div>
   <button id="openMetricsHelp" class="help-btn" title="Open metrics help">Help</button>
 </div>
--- a/image_learner.xml	Wed Nov 26 22:00:32 2025 +0000
+++ b/image_learner.xml	Fri Nov 28 15:45:49 2025 +0000
@@ -1,4 +1,4 @@
-<tool id="image_learner" name="Image Learner" version="0.1.3" profile="22.05">
+<tool id="image_learner" name="Image Learner" version="0.1.4" profile="22.05">
     <description>trains and evaluates an image classification/regression model</description>
     <requirements>
         <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1</container>
@@ -29,6 +29,16 @@
             ln -sf '$input_csv' "./${sanitized_input_csv}";
             #end if
 
+            #if $task_selection.task == "binary"
+                #set $selected_validation_metric = $task_selection.validation_metric_binary
+            #elif $task_selection.task == "classification"
+                #set $selected_validation_metric = $task_selection.validation_metric_multiclass
+            #elif $task_selection.task == "regression"
+                #set $selected_validation_metric = $task_selection.validation_metric_regression
+            #else
+                #set $selected_validation_metric = None
+            #end if
+
             python '$__tool_directory__/image_learner_cli.py'
                 --csv-file "./${sanitized_input_csv}"
                 --image-zip "$image_zip"
@@ -39,27 +49,38 @@
                         --fine-tune
                     #end if
                 #end if
-                #if $customize_defaults == "true"
-                    #if $epochs
-                        --epochs "$epochs"
+                #if $advanced_settings.customize_defaults == "true"
+                    #if $advanced_settings.epochs
+                        --epochs "$advanced_settings.epochs"
                     #end if
-                    #if $early_stop
-                        --early-stop "$early_stop"
+                    #if $advanced_settings.early_stop
+                        --early-stop "$advanced_settings.early_stop"
                     #end if
-                    #if $learning_rate_define == "true"
-                        --learning-rate "$learning_rate"
+                    #if $advanced_settings.learning_rate_condition.learning_rate_define == "true"
+                        --learning-rate "$advanced_settings.learning_rate_condition.learning_rate"
                     #end if
-                    #if $batch_size_define == "true"
-                        --batch-size "$batch_size"
+                    #if $advanced_settings.batch_size_condition.batch_size_define == "true"
+                        --batch-size "$advanced_settings.batch_size_condition.batch_size"
                     #end if
-                    --split-probabilities "$train_split" "$val_split" "$test_split"
-                    #if $threshold
-                        --threshold "$threshold"
+                    --split-probabilities "$advanced_settings.train_split" "$advanced_settings.val_split" "$advanced_settings.test_split"
+                    #if $advanced_settings.threshold
+                        --threshold "$advanced_settings.threshold"
                     #end if
                 #end if
                 #if $augmentation
                     --augmentation "$augmentation"
                 #end if
+                #if $selected_validation_metric
+                    --validation-metric "$selected_validation_metric"
+                #end if
+                #if $column_override.override_columns == "true"
+                    #if $column_override.target_column
+                        --target-column "$column_override.target_column"
+                    #end if
+                    #if $column_override.image_column
+                        --image-column "$column_override.image_column"
+                    #end if
+                #end if
                 --image-resize "$image_resize"
                 --random-seed "$random_seed"
                 --output-dir "." &&
@@ -74,6 +95,68 @@
     <inputs>
         <param name="input_csv" type="data" format="csv" optional="false" label="the metadata csv containing image_path column, label column and optional split column" />
         <param name="image_zip" type="data" format="zip" optional="false" label="Image zip" help="Image zip file containing your image data"/>
+        <conditional name="task_selection">
+            <param name="task" type="select" label="Task type" help="Pick task to see only metrics Ludwig accepts for that task; Auto lets the tool infer task and metric.">
+                <option value="auto" selected="true">Auto (infer and use defaults)</option>
+                <option value="binary">Binary Classification</option>
+                <option value="classification">Multi-class Classification</option>
+                <option value="regression">Regression</option>
+            </param>
+            <when value="binary">
+                <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig for binary outputs.">
+                    <option value="roc_auc" selected="true">ROC-AUC</option>
+                    <option value="accuracy">Accuracy</option>
+                    <option value="balanced_accuracy">Balanced Accuracy</option>
+                    <option value="precision">Precision</option>
+                    <option value="recall">Recall</option>
+                    <option value="f1">F1</option>
+                    <option value="specificity">Specificity</option>
+                    <option value="log_loss">Log Loss</option>
+                    <option value="loss">Loss</option>
+                </param>
+            </when>
+            <when value="classification">
+                <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig for multi-class outputs.">
+                    <option value="accuracy" selected="true">Accuracy</option>
+                    <option value="roc_auc">ROC-AUC</option>
+                    <option value="loss">Loss</option>
+                    <option value="balanced_accuracy">Balanced Accuracy</option>
+                    <option value="precision">Precision</option>
+                    <option value="recall">Recall</option>
+                    <option value="f1">F1</option>
+                    <option value="specificity">Specificity</option>
+                    <option value="log_loss">Log Loss</option>
+                </param>
+            </when>
+            <when value="regression">
+                <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
+                    <option value="pearson_r" selected="true">Pearson r</option>
+                    <option value="mae">MAE</option>
+                    <option value="mse">MSE</option>
+                    <option value="rmse">RMSE</option>
+                    <option value="mape">MAPE</option>
+                    <option value="r2">R²</option>
+                    <option value="explained_variance">Explained Variance</option>
+                    <option value="loss">Loss</option>
+                </param>
+            </when>
+            <when value="auto">
+                <!-- No validation metric selection; tool will infer task and metric. -->
+            </when>
+        </conditional>
+        <conditional name="column_override">
+            <param name="override_columns" type="select" label="Overwrite label and/or image column names?" help="Select yes to specify custom column names instead of the defaults 'label' and 'image_path'.">
+                <option value="false" selected="true">No</option>
+                <option value="true">Yes</option>
+            </param>
+            <when value="true">
+                <param name="target_column" type="text" optional="true" label="Target/label column name" help="Overrides the default 'label' column name in the metadata CSV." />
+                <param name="image_column" type="text" optional="true" label="Image column name" help="Overrides the default 'image_path' column name in the metadata CSV." />
+            </when>
+            <when value="false">
+                <!-- No additional parameters -->
+            </when>
+        </conditional>
         <param name="model_name" type="select" optional="false" label="Select a model for your experiment" >
 
             <option value="resnet18">Resnet18</option>
@@ -325,10 +408,12 @@
             <param name="image_zip" value="mnist_subset.zip" ftype="zip" />
             <param name="model_name" value="resnet18" />
             <param name="augmentation" value="random_horizontal_flip,random_vertical_flip,random_rotate" />
+            <param name="task_selection|task" value="classification" />
+            <param name="task_selection|validation_metric_multiclass" value="accuracy" />
             <output name="output_report">
                 <assert_contents>
-                    <has_text text="Results Summary" />
-                    <has_text text="Train/Validation Results" />
+                    <has_text text="Config and Overall Performance Summary" />
+                    <has_text text="Training and Validation Results" />
                     <has_text text="Test Results" />
                 </assert_contents>
             </output>
@@ -347,8 +432,8 @@
             <param name="model_name" value="resnet18" />
             <output name="output_report">
                 <assert_contents>
-                    <has_text text="Results Summary" />
-                    <has_text text="Train/Validation Results" />
+                    <has_text text="Config and Overall Performance Summary" />
+                    <has_text text="Training and Validation Results" />
                     <has_text text="Test Results" />
                 </assert_contents>
             </output>
@@ -366,8 +451,8 @@
             <param name="model_name" value="caformer_s18" />
             <output name="output_report">
                 <assert_contents>
-                    <has_text text="Results Summary" />
-                    <has_text text="Train/Validation Results" />
+                    <has_text text="Config and Overall Performance Summary" />
+                    <has_text text="Training and Validation Results" />
                     <has_text text="Test Results" />
                 </assert_contents>
             </output>
@@ -388,8 +473,8 @@
             <param name="advanced_settings|epochs" value="5" />
             <output name="output_report">
                 <assert_contents>
-                    <has_text text="Results Summary" />
-                    <has_text text="Train/Validation Results" />
+                    <has_text text="Test Performance Summary" />
+                    <has_text text="Training and Validation Results" />
                     <has_text text="Test Results" />
                 </assert_contents>
             </output>
@@ -410,8 +495,8 @@
             <param name="image_resize" value="384x384" />
             <output name="output_report">
                 <assert_contents>
-                    <has_text text="Results Summary" />
-                    <has_text text="Train/Validation Results" />
+                    <has_text text="Config and Overall Performance Summary" />
+                    <has_text text="Training and Validation Results" />
                     <has_text text="Test Results" />
                 </assert_contents>
             </output>
@@ -433,8 +518,8 @@
             <param name="input_csv" value="binary_classification.csv" ftype="csv" />
             <param name="image_zip" value="binary_images.zip" ftype="zip" />
             <param name="model_name" value="resnet18" />
-            <param name="customize_defaults" value="true" />
-            <param name="threshold" value="0.6" />
+            <param name="advanced_settings|customize_defaults" value="true" />
+            <param name="advanced_settings|threshold" value="0.6" />
             <output name="output_report">
                 <assert_contents>
                     <has_text text="Results Summary" />
@@ -462,8 +547,8 @@
             <param name="input_csv" value="mnist_subset.csv" ftype="csv" />
             <param name="image_zip" value="mnist_subset.zip" ftype="zip" />
             <param name="model_name" value="resnet18" />
-            <param name="customize_defaults" value="true" />
-            <param name="epochs" value="3" />
+            <param name="advanced_settings|customize_defaults" value="true" />
+            <param name="advanced_settings|epochs" value="3" />
             <output name="output_report">
                 <assert_contents>
                     <has_text text="Results Summary" />
@@ -487,11 +572,15 @@
             <param name="model_name" value="resnet18" />
             <param name="advanced_settings|customize_defaults" value="true" />
             <param name="advanced_settings|threshold" value="0.6" />
+            <param name="task_selection|task" value="classification" />
+            <param name="task_selection|validation_metric_multiclass" value="balanced_accuracy" />
             <output name="output_report">
                 <assert_contents>
-                    <has_text text="Accuracy" />
-                    <has_text text="Precision" />
-                    <has_text text="Learning Curves Label Accuracy" />
+                    <has_text text="Config and Overall Performance Summary" />
+                    <has_text text="Training and Validation Results" />
+                    <has_text text="Test Results" />
+                    <has_text text="Threshold" />
+                    <has_text text="0.60" />
                 </assert_contents>
             </output>
             <output_collection name="output_pred_csv" type="list" >
--- a/image_learner_cli.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/image_learner_cli.py	Fri Nov 28 15:45:49 2025 +0000
@@ -142,6 +142,42 @@
             "Overrides default 0.5."
         ),
     )
+    parser.add_argument(
+        "--validation-metric",
+        type=str,
+        default="roc_auc",
+        choices=[
+            "accuracy",
+            "loss",
+            "roc_auc",
+            "balanced_accuracy",
+            "precision",
+            "recall",
+            "f1",
+            "specificity",
+            "log_loss",
+            "pearson_r",
+            "mae",
+            "mse",
+            "rmse",
+            "mape",
+            "r2",
+            "explained_variance",
+        ],
+        help="Metric Ludwig uses to select the best model during training/validation.",
+    )
+    parser.add_argument(
+        "--target-column",
+        type=str,
+        default=None,
+        help="Name of the target/label column in the metadata file (defaults to 'label').",
+    )
+    parser.add_argument(
+        "--image-column",
+        type=str,
+        default=None,
+        help="Name of the image column in the metadata file (defaults to 'image_path').",
+    )
 
     args = parser.parse_args()
 
--- a/image_workflow.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/image_workflow.py	Fri Nov 28 15:45:49 2025 +0000
@@ -127,16 +127,31 @@
             logger.error("Error loading metadata file", exc_info=True)
             raise
 
-        required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
-        missing = required - set(df.columns)
-        if missing:
-            raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+        label_col = self.args.target_column or LABEL_COLUMN_NAME
+        image_col = self.args.image_column or IMAGE_PATH_COLUMN_NAME
+
+        # Remember the user-specified columns for reporting
+        self.args.report_target_column = label_col
+        self.args.report_image_column = image_col
+
+        missing_cols = []
+        if label_col not in df.columns:
+            missing_cols.append(label_col)
+        if image_col not in df.columns:
+            missing_cols.append(image_col)
+        if missing_cols:
+            raise ValueError(
+                f"Missing required column(s) in metadata: {', '.join(missing_cols)}. "
+                "Update the XML selections or rename your columns."
+            )
+
+        if label_col != LABEL_COLUMN_NAME:
+            df = df.rename(columns={label_col: LABEL_COLUMN_NAME})
+        if image_col != IMAGE_PATH_COLUMN_NAME:
+            df = df.rename(columns={image_col: IMAGE_PATH_COLUMN_NAME})
 
         try:
-            # Use relative paths that Ludwig can resolve from its internal working directory
-            df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
-                lambda p: str(Path("images") / p)
-            )
+            df = self._map_image_paths_with_search(df)
         except Exception:
             logger.error("Error updating image paths", exc_info=True)
             raise
@@ -205,6 +220,71 @@
         self.label_metadata = metadata
         self.output_type_hint = "binary" if metadata.get("is_binary") else None
 
+    def _map_image_paths_with_search(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Map image identifiers to actual files by searching the extracted directory."""
+        if not self.image_extract_dir:
+            raise RuntimeError("Image directory is not initialized.")
+
+        # Build lookup maps for fast resolution by stem or full name
+        lookup_by_stem = {}
+        lookup_by_name = {}
+        for fpath in self.image_extract_dir.rglob("*"):
+            if fpath.is_file():
+                stem_key = fpath.stem.lower()
+                name_key = fpath.name.lower()
+                # Prefer first encounter; warn on collisions
+                if stem_key in lookup_by_stem and lookup_by_stem[stem_key] != fpath:
+                    logger.warning(
+                        "Multiple files share the same stem '%s'. Using '%s'.",
+                        stem_key,
+                        lookup_by_stem[stem_key],
+                    )
+                else:
+                    lookup_by_stem[stem_key] = fpath
+                if name_key in lookup_by_name and lookup_by_name[name_key] != fpath:
+                    logger.warning(
+                        "Multiple files share the same name '%s'. Using '%s'.",
+                        name_key,
+                        lookup_by_name[name_key],
+                    )
+                else:
+                    lookup_by_name[name_key] = fpath
+
+        resolved_paths = []
+        missing_count = 0
+        missing_samples = []
+
+        for raw in df[IMAGE_PATH_COLUMN_NAME]:
+            raw_str = str(raw)
+            name_key = Path(raw_str).name.lower()
+            stem_key = Path(raw_str).stem.lower()
+            resolved = lookup_by_name.get(name_key) or lookup_by_stem.get(stem_key)
+
+            if resolved is None:
+                missing_count += 1
+                missing_samples.append(raw_str)
+                resolved_paths.append(pd.NA)
+                continue
+
+            try:
+                rel_path = resolved.relative_to(self.image_extract_dir)
+            except ValueError:
+                rel_path = resolved
+            resolved_paths.append(str(Path("images") / rel_path))
+
+        if missing_count:
+            logger.warning(
+                "Unable to locate %d image(s) from the metadata in the extracted images directory.",
+                missing_count,
+            )
+            preview = ", ".join(missing_samples[:5])
+            logger.warning("Missing samples (showing up to 5): %s", preview)
+
+        df = df.copy()
+        df[IMAGE_PATH_COLUMN_NAME] = resolved_paths
+        df = df.dropna(subset=[IMAGE_PATH_COLUMN_NAME]).reset_index(drop=True)
+        return df
+
 # Removed duplicate method
 
     def _detect_image_dimensions(self) -> Tuple[int, int]:
@@ -275,6 +355,9 @@
                 "threshold": self.args.threshold,
                 "label_metadata": self.label_metadata,
                 "output_type_hint": self.output_type_hint,
+                "validation_metric": self.args.validation_metric,
+                "target_column": getattr(self.args, "report_target_column", LABEL_COLUMN_NAME),
+                "image_column": getattr(self.args, "report_image_column", IMAGE_PATH_COLUMN_NAME),
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
 
@@ -297,6 +380,9 @@
 
             if ran_ok:
                 logger.info("Workflow completed successfully.")
+                # Convert predictions parquet → csv
+                self.backend.convert_parquet_to_csv(self.args.output_dir)
+                logger.info("Converted Parquet to CSV.")
                 # Generate a very small set of plots to conserve disk space
                 self.backend.generate_plots(self.args.output_dir)
                 # Build HTML report (robust to missing metrics)
@@ -307,9 +393,6 @@
                     split_info,
                 )
                 logger.info(f"HTML report generated at: {report_file}")
-                # Convert predictions parquet → csv
-                self.backend.convert_parquet_to_csv(self.args.output_dir)
-                logger.info("Converted Parquet to CSV.")
                 # Post-process cleanup to reduce disk footprint for subsequent tests
                 try:
                     self._postprocess_cleanup(self.args.output_dir)
--- a/ludwig_backend.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/ludwig_backend.py	Fri Nov 28 15:45:49 2025 +0000
@@ -31,7 +31,13 @@
 )
 from ludwig.utils.data_utils import get_split_path
 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
-from plotly_plots import build_classification_plots
+from plotly_plots import (
+    build_classification_plots,
+    build_prediction_diagnostics,
+    build_regression_test_plots,
+    build_regression_train_val_plots,
+    build_train_validation_plots,
+)
 from utils import detect_output_type, extract_metrics_from_json
 
 logger = logging.getLogger("ImageLearner")
@@ -72,6 +78,8 @@
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
 
+    _torchvision_patched = False
+
     def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]:
         """Detect image dimensions from the first image in the dataset."""
         try:
@@ -344,6 +352,72 @@
                     logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
             except (ValueError, IndexError):
                 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+
+        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
+            """Pick a validation metric that Ludwig will accept for the resolved task."""
+            default_map = {
+                "regression": "pearson_r",
+                "binary": "roc_auc",
+                "category": "accuracy",
+            }
+            allowed_map = {
+                "regression": {
+                    "pearson_r",
+                    "mean_absolute_error",
+                    "mean_squared_error",
+                    "root_mean_squared_error",
+                    "mean_absolute_percentage_error",
+                    "r2",
+                    "explained_variance",
+                    "loss",
+                },
+                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
+                "binary": {
+                    "roc_auc",
+                    "accuracy",
+                    "precision",
+                    "recall",
+                    "specificity",
+                    "log_loss",
+                    "loss",
+                },
+                "category": {
+                    "accuracy",
+                    "balanced_accuracy",
+                    "precision",
+                    "recall",
+                    "f1",
+                    "specificity",
+                    "log_loss",
+                    "loss",
+                },
+            }
+            alias_map = {
+                "regression": {
+                    "mae": "mean_absolute_error",
+                    "mse": "mean_squared_error",
+                    "rmse": "root_mean_squared_error",
+                    "mape": "mean_absolute_percentage_error",
+                },
+            }
+
+            default_metric = default_map.get(task)
+            allowed = allowed_map.get(task, set())
+            metric = requested or default_metric
+
+            if metric is None:
+                return None
+
+            metric = alias_map.get(task, {}).get(metric, metric)
+
+            if metric not in allowed:
+                if requested:
+                    logger.warning(
+                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                    )
+                metric = default_metric
+            return metric
+
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
@@ -351,7 +425,7 @@
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = config_params.get("validation_metric", "mean_squared_error")
+            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
 
         else:
             if num_unique_labels == 2:
@@ -368,7 +442,10 @@
                     "type": "category",
                     "loss": {"type": "softmax_cross_entropy"},
                 }
-            val_metric = None
+            val_metric = _resolve_validation_metric(
+                "binary" if num_unique_labels == 2 else "category",
+                config_params.get("validation_metric"),
+            )
 
         conf: Dict[str, Any] = {
             "model_type": "ecd",
@@ -380,7 +457,7 @@
                 "early_stop": early_stop,
                 "batch_size": batch_size_cfg,
                 "learning_rate": learning_rate,
-                # only set validation_metric for regression
+                # set validation_metric when provided
                 **({"validation_metric": val_metric} if val_metric else {}),
             },
             "preprocessing": {
@@ -402,6 +479,41 @@
             )
             raise
 
+    def _patch_torchvision_download(self) -> None:
+        """
+        Torchvision weight downloads sometimes fail checksum validation behind
+        corporate proxies that rewrite binaries. Skip hash checking to allow
+        pre-trained weights to load in those environments.
+        """
+        if LudwigDirectBackend._torchvision_patched:
+            return
+        try:
+            import torch.hub as torch_hub
+
+            original = torch_hub.load_state_dict_from_url
+            original_download = torch_hub.download_url_to_file
+
+            def _no_hash(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
+                return original(
+                    url,
+                    model_dir=model_dir,
+                    map_location=map_location,
+                    progress=progress,
+                    check_hash=False,
+                    file_name=file_name,
+                )
+
+            def _download_no_hash(url, dst, hash_prefix=None, progress=True):
+                # Torchvision's download_url_to_file signature does not accept check_hash in older versions.
+                return original_download(url, dst, hash_prefix=None, progress=progress)
+
+            torch_hub.load_state_dict_from_url = _no_hash  # type: ignore[assignment]
+            torch_hub.download_url_to_file = _download_no_hash  # type: ignore[assignment]
+            LudwigDirectBackend._torchvision_patched = True
+            logger.info("Disabled torchvision weight hash verification to avoid proxy-corrupted downloads.")
+        except Exception as exc:
+            logger.warning(f"Could not patch torchvision download hash check: {exc}")
+
     def run_experiment(
         self,
         dataset_path: Path,
@@ -412,6 +524,9 @@
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")
 
+        # Avoid strict hash validation for torchvision weights (common in proxied environments)
+        self._patch_torchvision_download()
+
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
@@ -506,24 +621,10 @@
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
 
+        # Keep only lightweight plots (drop compare_performance/roc_curves)
         test_plots = {
-            "compare_performance",
-            "compare_classifiers_performance_from_prob",
-            "compare_classifiers_performance_from_pred",
-            "compare_classifiers_performance_changing_k",
-            "compare_classifiers_multiclass_multimetric",
-            "compare_classifiers_predictions",
-            "confidence_thresholding_2thresholds_2d",
-            "confidence_thresholding_2thresholds_3d",
-            "confidence_thresholding",
-            "confidence_thresholding_data_vs_acc",
-            "binary_threshold_vs_metric",
-            "roc_curves",
             "roc_curves_from_test_statistics",
-            "calibration_1_vs_all",
-            "calibration_multiclass",
             "confusion_matrix",
-            "frequency_vs_f1",
         }
         train_plots = {
             "learning_curves",
@@ -627,6 +728,70 @@
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
+        train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
+        label_metadata_path = config.get("label_column_data_path")
+        if label_metadata_path:
+            label_metadata_path = Path(label_metadata_path)
+
+        # Pull additional config details from description.json if available
+        config_for_summary = dict(config)
+        if "target_column" not in config_for_summary or not config_for_summary.get("target_column"):
+            config_for_summary["target_column"] = LABEL_COLUMN_NAME
+        desc_path = exp_dir / DESCRIPTION_FILE_NAME
+        if desc_path.exists():
+            try:
+                with open(desc_path, "r") as f:
+                    desc_cfg = json.load(f).get("config", {})
+                encoder_cfg = (
+                    desc_cfg.get("input_features", [{}])[0].get("encoder", {})
+                    if isinstance(desc_cfg.get("input_features", [{}]), list)
+                    else {}
+                )
+                output_cfg = (
+                    desc_cfg.get("output_features", [{}])[0]
+                    if isinstance(desc_cfg.get("output_features", [{}]), list)
+                    else {}
+                )
+                trainer_cfg = desc_cfg.get("trainer", {}) if isinstance(desc_cfg, dict) else {}
+                loss_cfg = output_cfg.get("loss", {}) if isinstance(output_cfg, dict) else {}
+                opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {}
+                clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {}
+
+                arch_type = encoder_cfg.get("type")
+                arch_variant = encoder_cfg.get("model_variant")
+                arch_name = None
+                if arch_type:
+                    arch_base = str(arch_type).replace("_", " ").title()
+                    arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+
+                summary_fields = {
+                    "architecture": arch_name,
+                    "model_variant": arch_variant,
+                    "pretrained": encoder_cfg.get("use_pretrained"),
+                    "trainable": encoder_cfg.get("trainable"),
+                    "target_column": output_cfg.get("column"),
+                    "task_type": output_cfg.get("type"),
+                    "validation_metric": trainer_cfg.get("validation_metric"),
+                    "loss_function": loss_cfg.get("type"),
+                    "threshold": output_cfg.get("threshold"),
+                    "total_epochs": trainer_cfg.get("epochs"),
+                    "early_stop": trainer_cfg.get("early_stop"),
+                    "batch_size": trainer_cfg.get("batch_size"),
+                    "optimizer": opt_cfg.get("type"),
+                    "learning_rate": trainer_cfg.get("learning_rate"),
+                    "random_seed": desc_cfg.get("random_seed") or config.get("random_seed"),
+                    "use_mixed_precision": trainer_cfg.get("use_mixed_precision"),
+                    "gradient_clipping": clip_cfg.get("clipglobalnorm"),
+                }
+                for k, v in summary_fields.items():
+                    if v is None:
+                        continue
+                    # Do not override user-passed target/image column names in config
+                    if k in {"target_column", "image_column"} and config_for_summary.get(k):
+                        continue
+                    config_for_summary.setdefault(k, v)
+            except Exception as e:  # pragma: no cover - defensive
+                logger.warning(f"Could not merge description.json into config summary: {e}")
 
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
@@ -698,9 +863,10 @@
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
+        output_type = None
+        train_stats_path = exp_dir / "training_statistics.json"
+        test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
         try:
-            train_stats_path = exp_dir / "training_statistics.json"
-            test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME
             if train_stats_path.exists() and test_stats_path.exists():
                 with open(train_stats_path) as f:
                     train_stats = json.load(f)
@@ -725,10 +891,19 @@
         training_progress = self.get_training_process(output_dir)
         try:
             config_html = format_config_table_html(
-                config, split_info, training_progress, output_type
+                config_for_summary, split_info, training_progress, output_type
             )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
+            config_html = (
+                "<h2 style='text-align: center;'>Model and Training Summary</h2>"
+                "<p style='text-align:center; color:#666;'>Configuration details unavailable.</p>"
+            )
+        if not config_html:
+            config_html = (
+                "<h2 style='text-align: center;'>Model and Training Summary</h2>"
+                "<p style='text-align:center; color:#666;'>No configuration details found.</p>"
+            )
 
         # ---------- image rendering with exclusions ----------
         def render_img_section(
@@ -776,6 +951,11 @@
                 for img in imgs
                 if img.name not in default_exclude
                 and img.name not in exclude_names
+                and not (
+                    "learning_curves" in img.stem
+                    and "loss" in img.stem
+                    and "label" in img.stem
+                )
             ]
 
             if not imgs:
@@ -802,7 +982,8 @@
                 )
             return html_section
 
-        tab1_content = config_html + metrics_html
+        # Show performance first, then config
+        tab1_content = metrics_html + config_html
 
         tab2_content = train_val_metrics_html + render_img_section(
             "Training and Validation Visualizations",
@@ -815,6 +996,21 @@
                 "precision_recall_curve.png",
             },
         )
+        if train_stats_path.exists():
+            try:
+                if output_type == "regression":
+                    tv_plots = build_regression_train_val_plots(str(train_stats_path))
+                else:
+                    tv_plots = build_train_validation_plots(str(train_stats_path))
+                for plot in tv_plots:
+                    tab2_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if tv_plots:
+                    logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+            except Exception as e:
+                logger.warning(f"Could not generate train/val plots: {e}")
 
         # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
         preds_section = ""
@@ -849,7 +1045,7 @@
                     "<div class='preds-controls'>"
                     "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
                     "</div>"
-                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
+                    "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:350px; margin-bottom:20px;'>"
                     + preds_html
                     + "</div>"
                 )
@@ -857,27 +1053,75 @@
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
 
         tab3_content = test_metrics_html + preds_section
+        test_plotly_added = False
+
+        if output_type == "regression" and train_stats_path.exists():
+            try:
+                test_plots = build_regression_test_plots(str(train_stats_path))
+                for plot in test_plots:
+                    tab3_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if test_plots:
+                    test_plotly_added = True
+                    logger.info(f"Generated {len(test_plots)} regression test plots")
+            except Exception as e:
+                logger.warning(f"Could not generate regression test plots: {e}")
 
         if output_type in ("binary", "category") and test_stats_path.exists():
             try:
                 interactive_plots = build_classification_plots(
                     str(test_stats_path),
                     str(train_stats_path) if train_stats_path.exists() else None,
+                    metadata_csv_path=str(label_metadata_path)
+                    if label_metadata_path and label_metadata_path.exists()
+                    else None,
+                    train_set_metadata_path=str(train_set_metadata_path)
+                    if train_set_metadata_path.exists()
+                    else None,
                 )
                 for plot in interactive_plots:
                     tab3_content += (
                         f"<h2 style='text-align: center;'>{plot['title']}</h2>"
                         f"<div class='plotly-center'>{plot['html']}</div>"
                     )
+                if interactive_plots:
+                    test_plotly_added = True
                 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
             except Exception as e:
                 logger.warning(f"Could not generate Plotly plots: {e}")
 
+            # Add prediction diagnostics from predictions.csv
+            predictions_csv_path = exp_dir / "predictions.csv"
+            try:
+                diag_plots = build_prediction_diagnostics(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    threshold=config.get("threshold"),
+                )
+                for plot in diag_plots:
+                    tab3_content += (
+                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                        f"<div class='plotly-center'>{plot['html']}</div>"
+                    )
+                if diag_plots:
+                    test_plotly_added = True
+                    logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
+            except Exception as e:
+                logger.warning(f"Could not generate prediction diagnostics: {e}")
+
+        # Fallback: include static PNGs if no interactive plots were added
+        if not test_plotly_added:
+            tab3_content += render_img_section(
+                "Test Visualizations (PNG fallback)",
+                test_viz_dir,
+                output_type,
+            )
+
         # Add static TEST PNGs (with default dedupe/exclusions)
-        tab3_content += render_img_section(
-            "Test Visualizations", test_viz_dir, output_type
-        )
-
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
--- a/plotly_plots.py	Wed Nov 26 22:00:32 2025 +0000
+++ b/plotly_plots.py	Fri Nov 28 15:45:49 2025 +0000
@@ -7,13 +7,105 @@
 import plotly.graph_objects as go
 import plotly.io as pio
 from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME
-from sklearn.metrics import auc, roc_curve
-from sklearn.preprocessing import label_binarize
+
+
+def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure:
+    """Apply consistent styling across Plotly figures."""
+    fig.update_layout(
+        font=dict(size=font_size),
+        plot_bgcolor="#ffffff",
+        paper_bgcolor="#ffffff",
+    )
+    fig.update_xaxes(gridcolor="#e8e8e8")
+    fig.update_yaxes(gridcolor="#e8e8e8")
+    return fig
+
+
+def _labels_from_metadata_dict(meta_dict: dict) -> List[str]:
+    """Extract ordered label names from Ludwig train_set_metadata."""
+    if not isinstance(meta_dict, dict):
+        return []
+
+    for key in ("idx2str", "idx2label", "vocab"):
+        seq = meta_dict.get(key)
+        if isinstance(seq, list) and seq:
+            return [str(v) for v in seq]
+
+    str2idx = meta_dict.get("str2idx")
+    if isinstance(str2idx, dict) and str2idx:
+        int_indices = [v for v in str2idx.values() if isinstance(v, int)]
+        if int_indices:
+            max_idx = max(int_indices)
+            ordered = [None] * (max_idx + 1)
+            for name, idx in str2idx.items():
+                if isinstance(idx, int) and 0 <= idx < len(ordered):
+                    ordered[idx] = name
+            return [str(v) for v in ordered if v is not None]
+
+    return []
+
+
+def _resolve_confusion_labels(
+    label_stats: dict,
+    n_classes: int,
+    metadata_csv_path: Optional[str],
+    train_set_metadata_path: Optional[str],
+) -> List[str]:
+    """Prefer original labels from metadata; fall back to stats if unavailable."""
+    if train_set_metadata_path:
+        try:
+            meta_path = Path(train_set_metadata_path)
+            if meta_path.exists():
+                with open(meta_path, "r") as f:
+                    meta_json = json.load(f)
+                label_meta = meta_json.get(LABEL_COLUMN_NAME)
+                if not isinstance(label_meta, dict):
+                    label_meta = next(
+                        (
+                            v
+                            for v in meta_json.values()
+                            if isinstance(v, dict)
+                            and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab"))
+                        ),
+                        None,
+                    )
+                labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else []
+                if labels_from_meta and len(labels_from_meta) >= n_classes:
+                    return [str(label) for label in labels_from_meta[:n_classes]]
+        except Exception as exc:
+            print(f"Warning: Unable to read labels from train_set_metadata: {exc}")
+
+    if metadata_csv_path:
+        try:
+            csv_path = Path(metadata_csv_path)
+            if csv_path.exists():
+                df_meta = pd.read_csv(csv_path)
+                if LABEL_COLUMN_NAME in df_meta.columns:
+                    uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist()
+                    if uniques and len(uniques) >= n_classes:
+                        return [str(u) for u in uniques[:n_classes]]
+        except Exception as exc:
+            print(f"Warning: Unable to read labels from metadata CSV: {exc}")
+
+    pcs = label_stats.get("per_class_stats", {})
+    if pcs:
+        pcs_labels = [str(k) for k in pcs.keys()]
+        if len(pcs_labels) >= n_classes:
+            return pcs_labels[:n_classes]
+
+    labels = label_stats.get("labels")
+    if not labels:
+        labels = [str(i) for i in range(n_classes)]
+    if len(labels) < n_classes:
+        labels = labels + [str(i) for i in range(len(labels), n_classes)]
+    return [str(label) for label in labels[:n_classes]]
 
 
 def build_classification_plots(
     test_stats_path: str,
     training_stats_path: Optional[str] = None,
+    metadata_csv_path: Optional[str] = None,
+    train_set_metadata_path: Optional[str] = None,
 ) -> List[Dict[str, str]]:
     """
     Read Ludwig’s test_statistics.json and build three interactive Plotly panels:
@@ -21,6 +113,9 @@
       - ROC-AUC
       - Classification Report Heatmap
 
+    If metadata paths are provided, the confusion matrix axes will use the original
+    label values from the training metadata rather than integer-encoded labels.
+
     Returns a list of dicts, each with:
       {
         "title": <plot title>,
@@ -42,12 +137,12 @@
 
     # 0) Confusion Matrix
     cm = np.array(label_stats["confusion_matrix"], dtype=int)
-    # Try to get actual class names from per_class_stats keys (which contain the real labels)
-    pcs = label_stats.get("per_class_stats", {})
-    if pcs:
-        labels = list(pcs.keys())
-    else:
-        labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+    labels = _resolve_confusion_labels(
+        label_stats,
+        n_classes,
+        metadata_csv_path=metadata_csv_path,
+        train_set_metadata_path=train_set_metadata_path,
+    )
     total = cm.sum()
 
     fig_cm = go.Figure(
@@ -70,6 +165,7 @@
         height=side_px,
         margin=dict(t=100, l=80, r=80, b=80),
     )
+    _style_fig(fig_cm)
 
     # annotate counts and percentages
     mval = cm.max() if cm.size else 0
@@ -110,16 +206,28 @@
         )
     })
 
-    # 1) ROC-AUC Curves (Multi-class)
-    roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg)
+    # 1) ROC Curve (from test_statistics)
+    roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels)
     if roc_plot:
         plots.append(roc_plot)
 
+    # 2) Precision-Recall Curve (from test_statistics)
+    pr_plot = _build_precision_recall_plot(label_stats, common_cfg)
+    if pr_plot:
+        plots.append(pr_plot)
+
     # 2) Classification Report Heatmap
     pcs = label_stats.get("per_class_stats", {})
     if pcs:
         classes = list(pcs.keys())
-        metrics = ["precision", "recall", "f1_score"]
+        metrics = [
+            "precision",
+            "recall",
+            "f1_score",
+            "accuracy",
+            "matthews_correlation_coefficient",
+            "specificity",
+        ]
         z, txt = [], []
         for c in classes:
             row, trow = [], []
@@ -133,7 +241,7 @@
         fig_cr = go.Figure(
             go.Heatmap(
                 z=z,
-                x=metrics,
+                x=[m.replace("_", " ") for m in metrics],
                 y=[str(c) for c in classes],
                 text=txt,
                 texttemplate="%{text}",
@@ -143,15 +251,16 @@
             )
         )
         fig_cr.update_layout(
-            title="Classification Report",
+            title="Per-Class metrics",
             xaxis_title="",
             yaxis_title="Class",
             width=side_px,
             height=side_px,
             margin=dict(t=80, l=80, r=80, b=80),
         )
+        _style_fig(fig_cr)
         plots.append({
-            "title": "Classification Report",
+            "title": "Per-Class metrics",
             "html": pio.to_html(
                 fig_cr,
                 full_html=False,
@@ -160,68 +269,667 @@
             )
         })
 
+    # 3) Prediction Diagnostics (from predictions.csv)
+    # Note: appended separately in generate_html_report, not returned here.
+
+    return plots
+
+
+def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]:
+    """Generate Train/Validation learning curve plots from training_statistics.json."""
+    if not train_stats_path or not Path(train_stats_path).exists():
+        return []
+    try:
+        with open(train_stats_path, "r") as f:
+            train_stats = json.load(f)
+    except Exception as exc:
+        print(f"Warning: Unable to read training statistics: {exc}")
+        return []
+
+    label_train = (train_stats.get("training") or {}).get("label", {})
+    label_val = (train_stats.get("validation") or {}).get("label", {})
+    if not label_train and not label_val:
+        return []
+    plots: List[Dict[str, str]] = []
+    include_js = True  # Load Plotly.js once for this group
+
+    def _get_series(stats: dict, metric: str) -> List[float]:
+        if metric not in stats:
+            return []
+        vals = stats.get(metric, [])
+        if isinstance(vals, list):
+            return [float(v) for v in vals]
+        try:
+            return [float(vals)]
+        except Exception:
+            return []
+
+    def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]:
+        train_series = _get_series(label_train, metric_key)
+        val_series = _get_series(label_val, metric_key)
+        if not train_series and not val_series:
+            return None
+        epochs_train = list(range(1, len(train_series) + 1))
+        epochs_val = list(range(1, len(val_series) + 1))
+        fig = go.Figure()
+        if train_series:
+            fig.add_trace(
+                go.Scatter(
+                    x=epochs_train,
+                    y=train_series,
+                    mode="lines+markers",
+                    name="Train",
+                    line=dict(width=4),
+                )
+            )
+        if val_series:
+            fig.add_trace(
+                go.Scatter(
+                    x=epochs_val,
+                    y=val_series,
+                    mode="lines+markers",
+                    name="Validation",
+                    line=dict(width=4),
+                )
+            )
+        fig.update_layout(
+            title=dict(text=title, x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title=yaxis_title,
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig)
+        return {
+            "title": title,
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        }
+
+    # Core learning curves
+    for key, title in [
+        ("roc_auc", "ROC-AUC across epochs"),
+        ("precision", "Precision across epochs"),
+        ("recall", "Recall/Sensitivity across epochs"),
+        ("specificity", "Specificity across epochs"),
+    ]:
+        plot = _line_plot(key, title, title.replace("Learning Curve", "").strip())
+        if plot:
+            plots.append(plot)
+            include_js = False
+
+    # Precision vs Recall evolution (validation)
+    val_prec = _get_series(label_val, "precision")
+    val_rec = _get_series(label_val, "recall")
+    if val_prec and val_rec:
+        epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1))
+        fig_pr = go.Figure()
+        fig_pr.add_trace(
+            go.Scatter(
+                x=epochs,
+                y=val_prec[: len(epochs)],
+                mode="lines+markers",
+                name="Precision",
+            )
+        )
+        fig_pr.add_trace(
+            go.Scatter(
+                x=epochs,
+                y=val_rec[: len(epochs)],
+                mode="lines+markers",
+                name="Recall",
+            )
+        )
+        fig_pr.update_layout(
+            title=dict(text="Validation Precision and Recall by Epoch", x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title="Value",
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig_pr)
+        plots.append({
+            "title": "Precision vs Recall Evolution",
+            "html": pio.to_html(
+                fig_pr,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
+    # F1-score derived
+    def _compute_f1(p: List[float], r: List[float]) -> List[float]:
+        f1_vals = []
+        for prec, rec in zip(p, r):
+            if (prec + rec) == 0:
+                f1_vals.append(0.0)
+            else:
+                f1_vals.append(2 * prec * rec / (prec + rec))
+        return f1_vals
+
+    f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall"))
+    f1_val = _compute_f1(val_prec, val_rec)
+    if f1_train or f1_val:
+        fig = go.Figure()
+        if f1_train:
+            fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4)))
+        if f1_val:
+            fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4)))
+        fig.update_layout(
+            title=dict(text="F1-Score across epochs (derived)", x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title="F1-Score",
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig)
+        plots.append({
+            "title": "F1-Score across epochs (derived)",
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
+    # Overfitting Gap: Train vs Val ROC-AUC (gap)
+    roc_train = _get_series(label_train, "roc_auc")
+    roc_val = _get_series(label_val, "roc_auc")
+    if roc_train and roc_val:
+        epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1))
+        gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])]
+        fig_gap = go.Figure()
+        fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4)))
+        fig_gap.update_layout(
+            title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title="Gap",
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig_gap)
+        plots.append({
+            "title": "Overfitting gap: ROC-AUC across epochs",
+            "html": pio.to_html(
+                fig_gap,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
+    # Best Epoch Dashboard (based on max val ROC-AUC)
+    if roc_val:
+        best_idx = int(np.argmax(roc_val))
+        best_epoch = best_idx + 1
+        spec_val = _get_series(label_val, "specificity")
+        metrics_at_best = {
+            "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None,
+            "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None,
+            "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None,
+            "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None,
+            "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None,
+        }
+        fig_best = go.Figure()
+        for name, value in metrics_at_best.items():
+            if value is not None:
+                fig_best.add_trace(go.Bar(name=name, x=[name], y=[value]))
+        fig_best.update_layout(
+            title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5),
+            xaxis_title="Metric",
+            yaxis_title="Value",
+            width=760,
+            height=520,
+            showlegend=False,
+        )
+        _style_fig(fig_best)
+        plots.append({
+            "title": "Best Validation Epoch Snapshot (Metrics)",
+            "html": pio.to_html(
+                fig_best,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+
     return plots
 
 
-def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]:
-    """
-    Build an interactive ROC-AUC curve plot for multi-class classification.
-    Following sklearn's ROC example with micro-average and per-class curves.
+def _get_regression_series(split_stats: dict, metric: str) -> List[float]:
+    if metric not in split_stats:
+        return []
+    vals = split_stats.get(metric, [])
+    if isinstance(vals, list):
+        return [float(v) for v in vals]
+    try:
+        return [float(vals)]
+    except Exception:
+        return []
+
 
-    Args:
-        test_stats_path: Path to test_statistics.json
-        class_labels: List of class label names
-        config: Plotly config dict
+def _regression_line_plot(
+    train_split: dict,
+    val_split: dict,
+    metric_key: str,
+    title: str,
+    yaxis_title: str,
+    include_js: bool,
+) -> Optional[Dict[str, str]]:
+    train_series = _get_regression_series(train_split, metric_key)
+    val_series = _get_regression_series(val_split, metric_key)
+    if not train_series and not val_series:
+        return None
+    epochs_train = list(range(1, len(train_series) + 1))
+    epochs_val = list(range(1, len(val_series) + 1))
+    fig = go.Figure()
+    if train_series:
+        fig.add_trace(
+            go.Scatter(
+                x=epochs_train,
+                y=train_series,
+                mode="lines+markers",
+                name="Train",
+                line=dict(width=4),
+            )
+        )
+    if val_series:
+        fig.add_trace(
+            go.Scatter(
+                x=epochs_val,
+                y=val_series,
+                mode="lines+markers",
+                name="Validation",
+                line=dict(width=4),
+            )
+        )
+    fig.update_layout(
+        title=dict(text=title, x=0.5),
+        xaxis_title="Epoch",
+        yaxis_title=yaxis_title,
+        width=760,
+        height=520,
+        hovermode="x unified",
+    )
+    _style_fig(fig)
+    return {
+        "title": title,
+        "html": pio.to_html(
+            fig,
+            full_html=False,
+            include_plotlyjs="cdn" if include_js else False,
+        ),
+    }
+
+
+def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]:
+    """Generate regression Train/Validation learning curve plots from training_statistics.json."""
+    if not train_stats_path or not Path(train_stats_path).exists():
+        return []
+    try:
+        with open(train_stats_path, "r") as f:
+            train_stats = json.load(f)
+    except Exception as exc:
+        print(f"Warning: Unable to read training statistics: {exc}")
+        return []
+
+    label_train = (train_stats.get("training") or {}).get("label", {})
+    label_val = (train_stats.get("validation") or {}).get("label", {})
+    if not label_train and not label_val:
+        return []
+
+    plots: List[Dict[str, str]] = []
+    include_js = True
+    for metric_key, title, ytitle in [
+        ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"),
+        ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"),
+        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"),
+        ("r2", "R² across epochs", "R²"),
+        ("loss", "Loss across epochs", "Loss"),
+    ]:
+        plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js)
+        if plot:
+            plots.append(plot)
+            include_js = False
+    return plots
+
 
-    Returns:
-        Dict with title and HTML, or None if data unavailable
-    """
+def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]:
+    """Generate regression Test learning curves from training_statistics.json."""
+    if not train_stats_path or not Path(train_stats_path).exists():
+        return []
     try:
-        # Get the experiment directory from test_stats_path
-        exp_dir = Path(test_stats_path).parent
+        with open(train_stats_path, "r") as f:
+            train_stats = json.load(f)
+    except Exception as exc:
+        print(f"Warning: Unable to read training statistics: {exc}")
+        return []
+
+    label_test = (train_stats.get("test") or {}).get("label", {})
+    if not label_test:
+        return []
 
-        # Load predictions with probabilities
-        predictions_path = exp_dir / "predictions.csv"
-        if not predictions_path.exists():
-            return None
+    plots: List[Dict[str, str]] = []
+    include_js = True
+    metrics = [
+        ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"),
+        ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"),
+        ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"),
+        ("r2", "R² Across Epochs", "R²"),
+        ("loss", "Loss Across Epochs", "Loss"),
+    ]
+    epochs = None
+    for metric_key, title, ytitle in metrics:
+        series = _get_regression_series(label_test, metric_key)
+        if not series:
+            continue
+        if epochs is None:
+            epochs = list(range(1, len(series) + 1))
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=epochs,
+                y=series[: len(epochs)],
+                mode="lines+markers",
+                name="Test",
+                line=dict(width=4),
+            )
+        )
+        fig.update_layout(
+            title=dict(text=title, x=0.5),
+            xaxis_title="Epoch",
+            yaxis_title=ytitle,
+            width=760,
+            height=520,
+            hovermode="x unified",
+        )
+        _style_fig(fig)
+        plots.append({
+            "title": title,
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs="cdn" if include_js else False,
+            ),
+        })
+        include_js = False
+    return plots
 
-        df_pred = pd.read_csv(predictions_path)
+
+def _build_static_roc_plot(
+    label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None
+) -> Optional[Dict[str, str]]:
+    """Build ROC curve directly from test_statistics.json (single curve)."""
+    roc_data = label_stats.get("roc_curve")
+    if not isinstance(roc_data, dict):
+        return None
+
+    fpr = roc_data.get("false_positive_rate")
+    tpr = roc_data.get("true_positive_rate")
+    if not fpr or not tpr or len(fpr) != len(tpr):
+        return None
+
+    try:
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=fpr,
+                y=tpr,
+                mode="lines+markers",
+                name="ROC Curve",
+                line=dict(color="#1f77b4", width=4),
+                hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>",
+            )
+        )
+        fig.add_trace(
+            go.Scatter(
+                x=[0, 1],
+                y=[0, 1],
+                mode="lines",
+                name="Random Classifier",
+                line=dict(color="gray", width=2, dash="dash"),
+                hovertemplate="Random Classifier<extra></extra>",
+            )
+        )
+
+        auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro")
+        auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else ""
 
-        if SPLIT_COLUMN_NAME in df_pred.columns:
-            split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower()
-            test_mask = split_series.isin({"2", "test", "testing"})
-            if test_mask.any():
-                df_pred = df_pred[test_mask].reset_index(drop=True)
+        # Determine which label is treated as positive for the curve
+        label_list: List = []
+        pcs = label_stats.get("per_class_stats", {})
+        if pcs:
+            label_list = list(pcs.keys())
+        if not label_list:
+            labels_from_stats = label_stats.get("labels")
+            if isinstance(labels_from_stats, list):
+                label_list = labels_from_stats
+
+        # Try to resolve index of the positive label explicitly provided by Ludwig
+        pos_label_raw = (
+            roc_data.get("positive_label")
+            or roc_data.get("positive_class")
+            or label_stats.get("positive_label")
+        )
+        pos_label_idx = None
+        if pos_label_raw is not None and isinstance(label_list, list):
+            try:
+                pos_label_idx = label_list.index(pos_label_raw)
+            except ValueError:
+                pos_label_idx = None
+
+        # Fallback: use the second label if available, otherwise the first
+        if pos_label_idx is None:
+            if isinstance(label_list, list) and len(label_list) >= 2:
+                pos_label_idx = 1
+            elif isinstance(label_list, list) and label_list:
+                pos_label_idx = 0
+
+        if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None:
+            pos_label_raw = label_list[pos_label_idx]
+
+        # Map to friendly label if we have one from metadata/CSV
+        pos_label_display = pos_label_raw
+        if (
+            friendly_labels
+            and isinstance(pos_label_idx, int)
+            and 0 <= pos_label_idx < len(friendly_labels)
+        ):
+            pos_label_display = friendly_labels[pos_label_idx]
+
+        pos_label_txt = (
+            f"Positive class: {pos_label_display}"
+            if pos_label_display is not None
+            else "Positive class: (not available)"
+        )
+
+        title_label = f"ROC Curve{auc_txt}"
+        if pos_label_display is not None:
+            title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}"
 
-        if df_pred.empty:
-            return None
+        fig.update_layout(
+            title=dict(text=title_label, x=0.5),
+            xaxis_title="False Positive Rate",
+            yaxis_title="True Positive Rate",
+            width=700,
+            height=600,
+            margin=dict(t=80, l=80, r=80, b=110),
+            hovermode="closest",
+            legend=dict(
+                x=0.6,
+                y=0.1,
+                bgcolor="rgba(255,255,255,0.9)",
+                bordercolor="rgba(0,0,0,0.2)",
+                borderwidth=1,
+            ),
+        )
+        _style_fig(fig)
+        fig.update_xaxes(range=[0, 1.0])
+        fig.update_yaxes(range=[0, 1.05])
 
-        # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.)
-        # or label_probabilities_<class_name> for string labels
-        prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities']
+        fig.add_annotation(
+            x=0.5,
+            y=-0.15,
+            xref="paper",
+            yref="paper",
+            showarrow=False,
+            text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>",
+            xanchor="center",
+        )
+
+        return {
+            "title": "ROC Curve",
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs=False,
+                config=config,
+            ),
+        }
+    except Exception as e:
+        print(f"Error building ROC plot: {e}")
+        return None
+
+
+def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]:
+    """Build Precision-Recall curve directly from test_statistics.json."""
+    pr_data = label_stats.get("precision_recall_curve")
+    if not isinstance(pr_data, dict):
+        return None
+
+    precisions = pr_data.get("precisions")
+    recalls = pr_data.get("recalls")
+    if not precisions or not recalls or len(precisions) != len(recalls):
+        return None
 
-        # Sort by class number if numeric, otherwise keep alphabetical order
-        if prob_cols and prob_cols[0].split('_')[-1].isdigit():
-            prob_cols.sort(key=lambda x: int(x.split('_')[-1]))
-        else:
-            prob_cols.sort()  # Alphabetical sort for string class names
+    try:
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=recalls,
+                y=precisions,
+                mode="lines+markers",
+                name="Precision-Recall",
+                line=dict(color="#d62728", width=4),
+                hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>",
+            )
+        )
+
+        ap_val = (
+            label_stats.get("average_precision_macro")
+            or label_stats.get("average_precision_micro")
+            or label_stats.get("average_precision_samples")
+        )
+        ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else ""
+
+        fig.update_layout(
+            title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5),
+            xaxis_title="Recall",
+            yaxis_title="Precision",
+            width=700,
+            height=600,
+            margin=dict(t=80, l=80, r=80, b=80),
+            hovermode="closest",
+            legend=dict(
+                x=0.6,
+                y=0.1,
+                bgcolor="rgba(255,255,255,0.9)",
+                bordercolor="rgba(0,0,0,0.2)",
+                borderwidth=1,
+            ),
+        )
+        _style_fig(fig)
+        fig.update_xaxes(range=[0, 1.0])
+        fig.update_yaxes(range=[0, 1.05])
+
+        return {
+            "title": "Precision-Recall Curve",
+            "html": pio.to_html(
+                fig,
+                full_html=False,
+                include_plotlyjs=False,
+                config=config,
+            ),
+        }
+    except Exception as e:
+        print(f"Error building Precision-Recall plot: {e}")
+        return None
+
 
-        if not prob_cols:
-            return None
+def build_prediction_diagnostics(
+    predictions_path: str,
+    label_data_path: Optional[str] = None,
+    split_value: int = 2,
+    threshold: Optional[float] = None,
+) -> List[Dict[str, str]]:
+    """Generate diagnostic plots from predictions.csv for classification tasks."""
+    preds_file = Path(predictions_path)
+    if not preds_file.exists():
+        return []
+
+    try:
+        df_pred = pd.read_csv(predictions_path)
+    except Exception as exc:
+        print(f"Warning: Unable to read predictions CSV: {exc}")
+        return []
+
+    plots: List[Dict[str, str]] = []
+
+    # Identify probability columns
+    prob_cols = [
+        c for c in df_pred.columns
+        if c.startswith("label_probabilities_") and c != "label_probabilities"
+    ]
+    prob_cols_sorted = sorted(prob_cols)
 
-        # Get probabilities matrix (n_samples x n_classes)
-        y_score = df_pred[prob_cols].values
-        n_classes = len(prob_cols)
+    def _select_positive_prob():
+        if not prob_cols_sorted:
+            return None, None
+        # Prefer a column indicating positive/event/true/1
+        preferred_keys = ("event", "true", "positive", "pos", "1")
+        for col in prob_cols_sorted:
+            suffix = col.replace("label_probabilities_", "").lower()
+            if any(k in suffix for k in preferred_keys):
+                return col, suffix
+        if len(prob_cols_sorted) == 2:
+            col = prob_cols_sorted[1]
+            return col, col.replace("label_probabilities_", "")
+        col = prob_cols_sorted[0]
+        return col, col.replace("label_probabilities_", "")
 
-        y_true = None
-        candidate_cols = [
+    pos_prob_col, pos_label_hint = _select_positive_prob()
+    pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None
+
+    # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob
+    confidence_series = None
+    if "label_probability" in df_pred.columns:
+        confidence_series = df_pred["label_probability"]
+    elif pos_prob_series is not None:
+        confidence_series = pos_prob_series
+    elif prob_cols_sorted:
+        confidence_series = df_pred[prob_cols_sorted].max(axis=1)
+
+    # True labels
+    def _extract_labels():
+        candidates = [
             LABEL_COLUMN_NAME,
             f"{LABEL_COLUMN_NAME}_ground_truth",
             f"{LABEL_COLUMN_NAME}__ground_truth",
             f"{LABEL_COLUMN_NAME}_target",
             f"{LABEL_COLUMN_NAME}__target",
+            "label",
+            "label_true",
         ]
-        candidate_cols.extend(
+        candidates.extend(
             [
                 col
                 for col in df_pred.columns
@@ -230,174 +938,182 @@
                 and "predictions" not in col
             ]
         )
-        for col in candidate_cols:
-            if col in df_pred.columns and col not in prob_cols:
-                y_true = df_pred[col].values
-                break
+        for col in candidates:
+            if col in df_pred.columns and col not in prob_cols_sorted:
+                return df_pred[col]
+        if label_data_path and Path(label_data_path).exists():
+            try:
+                df_all = pd.read_csv(label_data_path)
+                if SPLIT_COLUMN_NAME in df_all.columns:
+                    df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True)
+                if LABEL_COLUMN_NAME in df_all.columns:
+                    return df_all[LABEL_COLUMN_NAME].reset_index(drop=True)
+            except Exception as exc:
+                print(f"Warning: Unable to load labels from dataset: {exc}")
+        return None
 
-        if y_true is None:
-            desc_path = exp_dir / "description.json"
-            if desc_path.exists():
-                try:
-                    with open(desc_path, 'r') as f:
-                        desc = json.load(f)
-                    dataset_path = desc.get('dataset', '')
-                    if dataset_path and Path(dataset_path).exists():
-                        df_orig = pd.read_csv(dataset_path)
-                        if SPLIT_COLUMN_NAME in df_orig.columns:
-                            df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True)
-                        if LABEL_COLUMN_NAME in df_orig.columns:
-                            y_true = df_orig[LABEL_COLUMN_NAME].values
-                            if len(y_true) != len(df_pred):
-                                print(
-                                    f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot."
-                                )
-                                y_true = y_true[:len(df_pred)]
-                    else:
-                        print("Warning: Original dataset referenced in description.json is unavailable.")
-                except Exception as exc:  # pragma: no cover - defensive
-                    print(f"Warning: Failed to recover labels from dataset: {exc}")
-
-        if y_true is None or len(y_true) == 0:
-            print("Warning: Unable to locate ground-truth labels for ROC plot.")
-            return None
-
-        if len(y_true) != len(y_score):
-            limit = min(len(y_true), len(y_score))
-            if limit == 0:
-                return None
-            print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.")
-            y_true = y_true[:limit]
-            y_score = y_score[:limit]
+    labels_series = _extract_labels()
 
-        # Get actual class names from probability column names
-        actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols]
-        display_classes = class_labels if len(class_labels) == n_classes else actual_classes
-
-        # Binarize the output following sklearn example
-        # Use actual class names if they're strings, otherwise use range
-        if isinstance(y_true[0], str):
-            y_test = label_binarize(y_true, classes=actual_classes)
-        else:
-            y_test = label_binarize(y_true, classes=list(range(n_classes)))
-
-        # Handle binary classification case
-        if y_test.ndim != 2:
-            y_test = np.atleast_2d(y_test)
+    # Plot 1: Confidence Histogram
+    if confidence_series is not None:
+        fig_conf = go.Figure()
+        fig_conf.add_trace(
+            go.Histogram(
+                x=confidence_series,
+                nbinsx=20,
+                marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)),
+                opacity=0.8,
+                histnorm="percent",
+            )
+        )
+        fig_conf.update_layout(
+            title=dict(text="Prediction Confidence Distribution", x=0.5),
+            xaxis_title="Predicted probability (confidence)",
+            yaxis_title="Percentage (%)",
+            bargap=0.05,
+            width=700,
+            height=500,
+        )
+        _style_fig(fig_conf)
+        plots.append({
+            "title": "Prediction Confidence Distribution",
+            "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False),
+        })
 
-        if n_classes == 2:
-            if y_test.shape[1] == 1:
-                y_test = np.hstack([1 - y_test, y_test])
-            elif y_test.shape[1] != 2:
-                print("Warning: Unexpected label binarization shape for binary ROC plot.")
-                return None
-        elif y_test.shape[1] != n_classes:
-            print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.")
-            return None
+    # The remaining plots require true labels and a positive-class probability
+    if labels_series is None or pos_prob_series is None:
+        return plots
+
+    # Align lengths
+    min_len = min(len(labels_series), len(pos_prob_series))
+    if min_len == 0:
+        return plots
+    y_true_raw = labels_series.iloc[:min_len]
+    y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float)
 
-        # Compute ROC curve and ROC area for each class (following sklearn example)
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
+    # Determine positive label
+    unique_labels = pd.unique(y_true_raw)
+    unique_labels_list = list(unique_labels)
+    positive_label = None
+    if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]:
+        positive_label = pos_label_hint
+    elif len(unique_labels_list) == 2:
+        positive_label = unique_labels_list[1]
+    else:
+        positive_label = unique_labels_list[0]
 
-        for i in range(n_classes):
-            if np.sum(y_test[:, i]) > 0:  # Check if class exists in test set
-                fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
-                roc_auc[i] = auc(fpr[i], tpr[i])
-
-        # Compute micro-average ROC curve and ROC area (sklearn example)
-        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
-        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
-
-        # Create ROC curve plot
-        fig_roc = go.Figure()
+    y_true = (y_true_raw == positive_label).astype(int).values
 
-        # Colors for different classes
-        colors = [
-            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
-            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
-        ]
-
-        # Plot micro-average ROC curve first (most important)
-        fig_roc.add_trace(go.Scatter(
-            x=fpr["micro"],
-            y=tpr["micro"],
-            mode='lines',
-            name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
-            line=dict(color='deeppink', width=3, dash='dot'),
-            hovertemplate=('<b>Micro-average ROC</b><br>'
-                           'FPR: %{x:.3f}<br>'
-                           'TPR: %{y:.3f}<br>'
-                           f'AUC: {roc_auc["micro"]:.3f}<extra></extra>')
-        ))
-
-        # Plot ROC curve for each class
-        for i in range(n_classes):
-            if i in roc_auc:  # Only plot if class exists in test set
-                class_name = display_classes[i] if i < len(display_classes) else f"Class {i}"
-                color = colors[i % len(colors)]
-
-                fig_roc.add_trace(go.Scatter(
-                    x=fpr[i],
-                    y=tpr[i],
-                    mode='lines',
-                    name=f'{class_name} (AUC = {roc_auc[i]:.3f})',
-                    line=dict(color=color, width=2),
-                    hovertemplate=(f'<b>{class_name}</b><br>'
-                                   'FPR: %{x:.3f}<br>'
-                                   'TPR: %{y:.3f}<br>'
-                                   f'AUC: {roc_auc[i]:.3f}<extra></extra>')
-                ))
+    # Plot 2: Calibration Curve
+    bins = np.linspace(0.0, 1.0, 11)
+    bin_ids = np.digitize(y_score, bins, right=True)
+    bin_centers = []
+    frac_positives = []
+    for b in range(1, len(bins)):
+        mask = bin_ids == b
+        if not np.any(mask):
+            continue
+        bin_centers.append(y_score[mask].mean())
+        frac_positives.append(y_true[mask].mean())
+    if bin_centers and frac_positives:
+        fig_cal = go.Figure()
+        fig_cal.add_trace(
+            go.Scatter(
+                x=bin_centers,
+                y=frac_positives,
+                mode="lines+markers",
+                name="Calibration",
+                line=dict(color="#2ca02c", width=4),
+            )
+        )
+        fig_cal.add_trace(
+            go.Scatter(
+                x=[0, 1],
+                y=[0, 1],
+                mode="lines",
+                name="Perfect Calibration",
+                line=dict(color="gray", width=2, dash="dash"),
+            )
+        )
+        fig_cal.update_layout(
+            title=dict(text="Calibration Curve", x=0.5),
+            xaxis_title="Predicted probability",
+            yaxis_title="Observed frequency",
+            width=700,
+            height=500,
+        )
+        _style_fig(fig_cal)
+        plots.append({
+            "title": "Calibration Curve (Predicted Probability vs Observed Frequency)",
+            "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False),
+        })
 
-        # Add diagonal line (random classifier)
-        fig_roc.add_trace(go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Random Classifier',
-            line=dict(color='gray', width=1, dash='dash'),
-            hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>'
-        ))
-
-        # Calculate macro-average AUC
-        class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc]
-        if class_aucs:
-            macro_auc = np.mean(class_aucs)
-            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})"
-        else:
-            title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})"
+    # Plot 3: Threshold vs Metrics
+    thresholds = np.linspace(0.0, 1.0, 21)
+    accs, f1s, sens, specs = [], [], [], []
+    for t in thresholds:
+        y_pred = (y_score >= t).astype(int)
+        tp = np.sum((y_true == 1) & (y_pred == 1))
+        tn = np.sum((y_true == 0) & (y_pred == 0))
+        fp = np.sum((y_true == 0) & (y_pred == 1))
+        fn = np.sum((y_true == 1) & (y_pred == 0))
+        acc = (tp + tn) / max(len(y_true), 1)
+        prec = tp / max(tp + fp, 1e-9)
+        rec = tp / max(tp + fn, 1e-9)
+        f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec)
+        sensitivity = rec
+        specificity = tn / max(tn + fp, 1e-9)
+        accs.append(acc)
+        f1s.append(f1)
+        sens.append(sensitivity)
+        specs.append(specificity)
 
-        fig_roc.update_layout(
-            title=dict(text=title_text, x=0.5),
-            xaxis_title="False Positive Rate",
-            yaxis_title="True Positive Rate",
-            width=700,
-            height=600,
-            margin=dict(t=80, l=80, r=80, b=80),
-            legend=dict(
-                x=0.6,
-                y=0.1,
-                bgcolor="rgba(255,255,255,0.9)",
-                bordercolor="rgba(0,0,0,0.2)",
-                borderwidth=1
-            ),
-            hovermode='closest'
-        )
+    fig_thresh = go.Figure()
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4)))
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4)))
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4)))
+    fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4)))
+    fig_thresh.update_layout(
+        title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5),
+        xaxis_title="Decision threshold",
+        yaxis_title="Metric value",
+        width=700,
+        height=500,
+        legend=dict(
+            x=0.7,
+            y=0.2,
+            bgcolor="rgba(255,255,255,0.9)",
+            bordercolor="rgba(0,0,0,0.2)",
+            borderwidth=1,
+        ),
+        shapes=[
+            dict(
+                type="line",
+                x0=threshold,
+                x1=threshold,
+                y0=0,
+                y1=1,
+                xref="x",
+                yref="paper",
+                line=dict(color="#d62728", width=2, dash="dash"),
+            )
+        ] if isinstance(threshold, (int, float)) else [],
+        annotations=[
+            dict(
+                x=threshold,
+                y=1.02,
+                xref="x",
+                yref="paper",
+                showarrow=False,
+                text=f"Threshold = {threshold:.2f}",
+                font=dict(size=11, color="#d62728"),
+            )
+        ] if isinstance(threshold, (int, float)) else [],
+    )
+    _style_fig(fig_thresh)
+    plots.append({
+        "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity",
+        "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False),
+    })
 
-        # Set equal aspect ratio and proper range
-        fig_roc.update_xaxes(range=[0, 1.0])
-        fig_roc.update_yaxes(range=[0, 1.05])
-
-        return {
-            "title": "ROC-AUC Curves",
-            "html": pio.to_html(
-                fig_roc,
-                full_html=False,
-                include_plotlyjs=False,
-                config=config
-            )
-        }
-
-    except Exception as e:
-        print(f"Error building ROC-AUC plot: {e}")
-        return None
+    return plots