changeset 8:85e6f4b2ad18 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
author goeckslab
date Thu, 14 Aug 2025 14:53:10 +0000
parents 801a8b6973fb
children
files constants.py image_learner.xml image_learner_cli.py plotly_plots.py utils.py
diffstat 5 files changed, 670 insertions(+), 429 deletions(-)
--- a/constants.py	Fri Aug 08 13:06:28 2025 +0000
+++ b/constants.py	Thu Aug 14 14:53:10 2025 +0000
@@ -87,28 +87,28 @@
 }
 METRIC_DISPLAY_NAMES = {
     "accuracy": "Accuracy",
-    "accuracy_micro": "Accuracy-Micro",
+    "accuracy_micro": "Micro Accuracy",
     "loss": "Loss",
     "roc_auc": "ROC-AUC",
-    "roc_auc_macro": "ROC-AUC-Macro",
-    "roc_auc_micro": "ROC-AUC-Micro",
+    "roc_auc_macro": "Macro ROC-AUC",
+    "roc_auc_micro": "Micro ROC-AUC",
     "hits_at_k": "Hits at K",
     "precision": "Precision",
     "recall": "Recall",
     "specificity": "Specificity",
     "kappa_score": "Cohen's Kappa",
     "token_accuracy": "Token Accuracy",
-    "avg_precision_macro": "Precision-Macro",
-    "avg_recall_macro": "Recall-Macro",
-    "avg_f1_score_macro": "F1-score-Macro",
-    "avg_precision_micro": "Precision-Micro",
-    "avg_recall_micro": "Recall-Micro",
-    "avg_f1_score_micro": "F1-score-Micro",
-    "avg_precision_weighted": "Precision-Weighted",
-    "avg_recall_weighted": "Recall-Weighted",
-    "avg_f1_score_weighted": "F1-score-Weighted",
-    "average_precision_macro": "Precision-Average-Macro",
-    "average_precision_micro": "Precision-Average-Micro",
+    "avg_precision_macro": "Macro Precision",
+    "avg_recall_macro": "Macro Recall",
+    "avg_f1_score_macro": "Macro F1-score",
+    "avg_precision_micro": "Micro Precision",
+    "avg_recall_micro": "Micro Recall",
+    "avg_f1_score_micro": "Micro F1-score",
+    "avg_precision_weighted": "Weighted Precision",
+    "avg_recall_weighted": "Weighted Recall",
+    "avg_f1_score_weighted": "Weighted F1-score",
+    "average_precision_macro": "Macro Precision-Average",
+    "average_precision_micro": "Micro Precision-Average",
     "average_precision_samples": "Precision-Average-Samples",
     "mean_squared_error": "Mean Squared Error",
     "mean_absolute_error": "Mean Absolute Error",
--- a/image_learner.xml	Fri Aug 08 13:06:28 2025 +0000
+++ b/image_learner.xml	Thu Aug 14 14:53:10 2025 +0000
@@ -1,7 +1,7 @@
-<tool id="image_learner" name="Image Learner" version="0.1.1" profile="22.05">
-    <description>trains and evaluates an image classification/regression model</description>
+<tool id="image_learner" name="Image Learner for Classification" version="0.1.2" profile="22.05">
+    <description>trains and evaluates an image classification model</description>
     <requirements>
-        <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1</container>
+        <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:latest</container>
     </requirements>
     <required_files>
         <include path="utils.py" />
@@ -144,13 +144,14 @@
         <conditional name="scratch_fine_tune">
             <param name="use_pretrained" type="select"
                 label="Use pretrained weights?"
-                help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch. (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
+                help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch.  
+               (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
                 <option value="false">No</option>
                 <option value="true" selected="true">Yes</option>
             </param>
             <when value="true">
                 <param name="fine_tune" type="select" label="Fine tune the encoder?"
-                    help="Whether to fine tune the encoder(combiner and decoder will be fine-tued anyway)" >
+                    help="Whether to fine tune the encoder(combiner and decoder will be fine-tuned anyway)" >
                     <option value="false" >No</option>
                     <option value="true" selected="true">Yes</option>
                 </param>
@@ -218,6 +219,7 @@
                         label="Test split proportion (only works if no split column in the metadata csv)"
                         value="0.2"
                         help="Fraction of data for testing (e.g., 0.2) train split + val split + test split should = 1."/>
+                <param name="threshold" type="float" value="0.5" min="0.0" max="1.0" optional="true" label="Decision Threshold (binary only)" help="Set the decision threshold for binary classification (0.0–1.0). Only applies when task is binary; default is 0.5." />
             </when>
             <when value="false">
                 <!-- No additional parameters to show if the user selects 'No' -->
@@ -307,8 +309,6 @@
                     <has_text text="Test Results" />
                 </assert_contents>
             </output>
-            <output name="output_report" file="expected_regression.html" compare="sim_size"/>
-
             <output_collection name="output_pred_csv" type="list" >
                 <element name="predictions.csv" >
                     <assert_contents>
@@ -317,18 +317,16 @@
                 </element>
             </output_collection>
         </test>
-    </tests>
+        </tests>
     <help>
         <![CDATA[
 **What it does**
-Image Learner for Classification/regression: trains and evaluates a image classification/regression model. 
+Image Learner for Classification: trains and evaluates an image classification model.
 It uses the metadata csv to find the image paths and labels. 
 The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
 Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test). 
 If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
 
-**If the selected label column has more than 10 unique values, the tool will automatically treat the task as a regression problem and apply appropriate metrics (e.g., MSE, RMSE, R²).**
-
 
 **Outputs**
 The tool will output a trained model in the form of a ludwig_model file,
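For reference, the metadata CSV described in the help text above has an 'image_path' column, a 'label' column, and an optional integer 'split' column; the split codes 0/1/2 correspond to train/validation/test in image_learner_cli.py. A minimal, illustrative way to produce such a file (the file names and labels are made up):

# Illustrative metadata CSV for the Image Learner tool; columns follow the help text above.
import pandas as pd

metadata = pd.DataFrame(
    {
        "image_path": ["images/cat_001.png", "images/dog_042.png", "images/cat_007.png"],
        "label": ["cat", "dog", "cat"],
        "split": [0, 1, 2],  # optional: 0 = train, 1 = validation, 2 = test
    }
)
metadata.to_csv("metadata.csv", index=False)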
--- a/image_learner_cli.py	Fri Aug 08 13:06:28 2025 +0000
+++ b/image_learner_cli.py	Thu Aug 14 14:53:10 2025 +0000
@@ -31,6 +31,7 @@
 )
 from ludwig.utils.data_utils import get_split_path
 from ludwig.visualize import get_visualizations_registry
+from plotly_plots import build_classification_plots
 from sklearn.model_selection import train_test_split
 from utils import (
     build_tabbed_html,
@@ -52,6 +53,7 @@
     config: dict,
     split_info: Optional[str] = None,
     training_progress: dict = None,
+    output_type: Optional[str] = None,
 ) -> str:
     display_keys = [
         "task_type",
@@ -63,114 +65,119 @@
         "learning_rate",
         "random_seed",
         "early_stop",
+        "threshold",
     ]
-
     rows = []
-
     for key in display_keys:
-        val = config.get(key, "N/A")
-        if key == "task_type":
-            val = val.title() if isinstance(val, str) else val
-        if key == "batch_size":
-            if val is not None:
-                val = int(val)
-            else:
-                if training_progress:
-                    val = "Auto-selected batch size by Ludwig:<br>"
-                    resolved_val = training_progress.get("batch_size")
-                    val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
+        val = config.get(key, None)
+        if key == "threshold":
+            if output_type != "binary":
+                continue
+            val = val if val is not None else 0.5
+            val_str = f"{val:.2f}"
+            if val == 0.5:
+                val_str += " (default)"
+        else:
+            if key == "task_type":
+                val_str = val.title() if isinstance(val, str) else "N/A"
+            elif key == "batch_size":
+                if val is not None:
+                    val_str = int(val)
+                else:
+                    if training_progress:
+                        resolved_val = training_progress.get("batch_size")
+                        val_str = (
+                            "Auto-selected batch size by Ludwig:<br>"
+                            f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
+                        )
+                    else:
+                        val_str = "auto"
+            elif key == "learning_rate":
+                if val is not None and val != "auto":
+                    val_str = f"{val:.6f}"
                 else:
-                    val = "auto"
-        if key == "learning_rate":
-            resolved_val = None
-            if val is None or val == "auto":
-                if training_progress:
-                    resolved_val = training_progress.get("learning_rate")
-                    val = (
-                        "Auto-selected learning rate by Ludwig:<br>"
-                        f"<span style='font-size: 0.85em;'>"
-                        f"{resolved_val if resolved_val else val}</span><br>"
-                        "<span style='font-size: 0.85em;'>"
-                        "Based on model architecture and training setup "
-                        "(e.g., fine-tuning).<br>"
-                        "See <a href='https://ludwig.ai/latest/configuration/trainer/"
-                        "#trainer-parameters' target='_blank'>"
-                        "Ludwig Trainer Parameters</a> for details."
-                        "</span>"
-                    )
+                    if training_progress:
+                        resolved_val = training_progress.get("learning_rate")
+                        val_str = (
+                            "Auto-selected learning rate by Ludwig:<br>"
+                            f"<span style='font-size: 0.85em;'>"
+                            f"{resolved_val if resolved_val else 'auto'}</span><br>"
+                            "<span style='font-size: 0.85em;'>"
+                            "Based on model architecture and training setup "
+                            "(e.g., fine-tuning).<br>"
+                            "</span>"
+                        )
+                    else:
+                        val_str = (
+                            "Auto-selected by Ludwig<br>"
+                            "<span style='font-size: 0.85em;'>"
+                            "Automatically tuned based on architecture and dataset.<br>"
+                            "See <a href='https://ludwig.ai/latest/configuration/trainer/"
+                            "#trainer-parameters' target='_blank'>"
+                            "Ludwig Trainer Parameters</a> for details."
+                            "</span>"
+                        )
+            elif key == "epochs":
+                if val is None:
+                    val_str = "N/A"
                 else:
-                    val = (
-                        "Auto-selected by Ludwig<br>"
-                        "<span style='font-size: 0.85em;'>"
-                        "Automatically tuned based on architecture and dataset.<br>"
-                        "See <a href='https://ludwig.ai/latest/configuration/trainer/"
-                        "#trainer-parameters' target='_blank'>"
-                        "Ludwig Trainer Parameters</a> for details."
-                        "</span>"
-                    )
+                    if (
+                        training_progress
+                        and "epoch" in training_progress
+                        and val > training_progress["epoch"]
+                    ):
+                        val_str = (
+                            f"Because of early stopping: the training "
+                            f"stopped at epoch {training_progress['epoch']}"
+                        )
+                    else:
+                        val_str = val
             else:
-                val = f"{val:.6f}"
-        if key == "epochs":
-            if (
-                training_progress
-                and "epoch" in training_progress
-                and val > training_progress["epoch"]
-            ):
-                val = (
-                    f"Because of early stopping: the training "
-                    f"stopped at epoch {training_progress['epoch']}"
-                )
-
-        if val is None:
-            continue
+                val_str = val if val is not None else "N/A"
+            if val_str == "N/A" and key not in ["task_type"]:  # Skip if N/A for non-essential
+                continue
         rows.append(
             f"<tr>"
             f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
             f"{key.replace('_', ' ').title()}</td>"
             f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
-            f"{val}</td>"
+            f"{val_str}</td>"
             f"</tr>"
         )
-
     aug_cfg = config.get("augmentation")
     if aug_cfg:
         types = [str(a.get("type", "")) for a in aug_cfg]
         aug_val = ", ".join(types)
         rows.append(
-            "<tr>"
-            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
-            "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
-            f"{aug_val}</td>"
-            "</tr>"
+            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>"
         )
-
     if split_info:
         rows.append(
-            f"<tr>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
-            f"Data Split</td>"
-            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
-            f"{split_info}</td>"
-            f"</tr>"
+            f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
+            f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>"
         )
-
-    return (
-        "<h2 style='text-align: center;'>Training Setup</h2>"
-        "<div style='display: flex; justify-content: center;'>"
-        "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>"
-        "<thead><tr>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>"
-        "Parameter</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>"
-        "Value</th>"
-        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>"
-        "<p style='text-align: center; font-size: 0.9em;'>"
-        "Model trained using Ludwig.<br>"
-        "If want to learn more about Ludwig default settings,"
-        "please check their <a href='https://ludwig.ai' target='_blank'>"
-        "website(ludwig.ai)</a>."
-        "</p><hr>"
-    )
+    html = f"""
+        <h2 style="text-align: center;">Model and Training Summary</h2>
+        <div style="display: flex; justify-content: center;">
+          <table style="border-collapse: collapse; width: 100%; table-layout: fixed;">
+            <thead><tr>
+              <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th>
+              <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th>
+            </tr></thead>
+            <tbody>
+              {''.join(rows)}
+            </tbody>
+          </table>
+        </div><br>
+        <p style="text-align: center; font-size: 0.9em;">
+          Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>.
+          <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer">
+            Ludwig documentation provides detailed information about default model and training parameters
+          </a>
+        </p><hr>
+        """
+    return html
 
 
 def detect_output_type(test_stats):
@@ -244,7 +251,6 @@
                 "roc_auc": get_last_value(label_stats, "roc_auc"),
                 "hits_at_k": get_last_value(label_stats, "hits_at_k"),
             }
-
     # Test metrics: dynamic extraction according to exclusions
     test_label_stats = test_stats.get("label", {})
     if not test_label_stats:
@@ -252,13 +258,11 @@
     else:
         combined_stats = test_stats.get("combined", {})
         overall_stats = test_label_stats.get("overall_stats", {})
-
         # Define exclusions
         if output_type == "binary":
             exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
         else:
             exclude = {"per_class_stats", "confusion_matrix"}
-
         # 1. Get all scalar test_label_stats not excluded
         test_metrics = {}
         for k, v in test_label_stats.items():
@@ -268,17 +272,13 @@
                 continue
             if isinstance(v, (int, float, str, bool)):
                 test_metrics[k] = v
-
         # 2. Add overall_stats (flattened)
         for k, v in overall_stats.items():
             test_metrics[k] = v
-
         # 3. Optionally include combined/loss if present and not already
         if "loss" in combined_stats and "loss" not in test_metrics:
             test_metrics["loss"] = combined_stats["loss"]
-
         metrics["test"] = test_metrics
-
     return metrics
 
 
@@ -291,6 +291,11 @@
     )
 
 
+# -----------------------------------------
+# 2) MODEL PERFORMANCE (Train/Val/Test) TABLE
+# -----------------------------------------
+
+
 def format_stats_table_html(train_stats: dict, test_stats: dict) -> str:
     """Formats a combined HTML table for training, validation, and test metrics."""
     output_type = detect_output_type(test_stats)
@@ -310,35 +315,33 @@
             te = all_metrics["test"].get(metric_key)
             if all(x is not None for x in [t, v, te]):
                 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
-
     if not rows:
         return "<table><tr><td>No metric values found.</td></tr></table>"
-
     html = (
         "<h2 style='text-align: center;'>Model Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
-        "<table style='border-collapse: collapse; table-layout: auto;'>"
+        "<table class='performance-summary' style='border-collapse: collapse;'>"
         "<thead><tr>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
-        "white-space: nowrap;'>Metric</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
-        "white-space: nowrap;'>Train</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
-        "white-space: nowrap;'>Validation</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
-        "white-space: nowrap;'>Test</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
         "</tr></thead><tbody>"
     )
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; "
-            "white-space: nowrap;",
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
         )
     html += "</tbody></table></div><br>"
     return html
 
 
+# -------------------------------------------
+# 3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE
+# -------------------------------------------
+
+
 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str:
     """Formats an HTML table for training and validation metrics."""
     output_type = detect_output_type(test_stats)
@@ -354,33 +357,32 @@
             v = all_metrics["validation"].get(metric_key)
             if t is not None and v is not None:
                 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
-
     if not rows:
         return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
-
     html = (
         "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
-        "<table style='border-collapse: collapse; table-layout: auto;'>"
+        "<table class='performance-summary' style='border-collapse: collapse;'>"
         "<thead><tr>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
-        "white-space: nowrap;'>Metric</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
-        "white-space: nowrap;'>Train</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
-        "white-space: nowrap;'>Validation</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
         "</tr></thead><tbody>"
     )
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; "
-            "white-space: nowrap;",
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
         )
     html += "</tbody></table></div><br>"
     return html
 
 
+# -----------------------------------------
+# 4) TEST-ONLY PERFORMANCE SUMMARY TABLE
+# -----------------------------------------
+
+
 def format_test_merged_stats_table_html(
     test_metrics: Dict[str, Optional[float]],
 ) -> str:
@@ -391,26 +393,21 @@
         value = test_metrics[key]
         if value is not None:
             rows.append([display_name, f"{value:.4f}"])
-
     if not rows:
         return "<table><tr><td>No test metric values found.</td></tr></table>"
-
     html = (
         "<h2 style='text-align: center;'>Test Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
-        "<table style='border-collapse: collapse; table-layout: auto;'>"
+        "<table class='performance-summary' style='border-collapse: collapse;'>"
         "<thead><tr>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; "
-        "white-space: nowrap;'>Metric</th>"
-        "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
-        "white-space: nowrap;'>Test</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
+        "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
         "</tr></thead><tbody>"
     )
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; "
-            "white-space: nowrap;",
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
         )
     html += "</tbody></table></div><br>"
     return html
@@ -426,13 +423,10 @@
     """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
     out = df.copy()
     out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
-
     idx_train = out.index[out[split_column] == 0].tolist()
-
     if not idx_train:
         logger.info("No rows with split=0; nothing to do.")
         return out
-
     # Always use stratify if possible
     stratify_arr = None
     if label_column and label_column in out.columns:
@@ -450,7 +444,6 @@
             logger.info("Using stratified split for validation set")
         else:
             logger.warning("Only one label class found; cannot stratify")
-
     if validation_size <= 0:
         logger.info("validation_size <= 0; keeping all as train.")
         return out
@@ -458,7 +451,6 @@
         logger.info("validation_size >= 1; moving all train → validation.")
         out.loc[idx_train, split_column] = 1
         return out
-
     # Always try stratified split first
     try:
         train_idx, val_idx = train_test_split(
@@ -476,7 +468,6 @@
             random_state=random_state,
             stratify=None,
         )
-
     out.loc[train_idx, split_column] = 0
     out.loc[val_idx, split_column] = 1
     out[split_column] = out[split_column].astype(int)
@@ -492,31 +483,24 @@
 ) -> pd.DataFrame:
     """Create a stratified random split when no split column exists."""
     out = df.copy()
-
     # initialize split column
     out[split_column] = 0
-
     if not label_column or label_column not in out.columns:
         logger.warning("No label column found; using random split without stratification")
         # fall back to simple random assignment
         indices = out.index.tolist()
         np.random.seed(random_state)
         np.random.shuffle(indices)
-
         n_total = len(indices)
         n_train = int(n_total * split_probabilities[0])
         n_val = int(n_total * split_probabilities[1])
-
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
         out.loc[indices[n_train + n_val:], split_column] = 2
-
         return out.astype({split_column: int})
-
     # check if stratification is possible
     label_counts = out[label_column].value_counts()
     min_samples_per_class = label_counts.min()
-
     # ensure we have enough samples for stratification:
     # Each class must have at least as many samples as the number of splits,
     # so that each split can receive at least one sample per class.
@@ -529,19 +513,14 @@
         indices = out.index.tolist()
         np.random.seed(random_state)
         np.random.shuffle(indices)
-
         n_total = len(indices)
         n_train = int(n_total * split_probabilities[0])
         n_val = int(n_total * split_probabilities[1])
-
         out.loc[indices[:n_train], split_column] = 0
         out.loc[indices[n_train:n_train + n_val], split_column] = 1
         out.loc[indices[n_train + n_val:], split_column] = 2
-
         return out.astype({split_column: int})
-
     logger.info("Using stratified random split for train/validation/test sets")
-
     # first split: separate test set
     train_val_idx, test_idx = train_test_split(
         out.index.tolist(),
@@ -549,7 +528,6 @@
         random_state=random_state,
         stratify=out[label_column],
     )
-
     # second split: separate training and validation from remaining data
     val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
     train_idx, val_idx = train_test_split(
@@ -558,21 +536,17 @@
         random_state=random_state,
         stratify=out.loc[train_val_idx, label_column],
     )
-
     # assign split values
     out.loc[train_idx, split_column] = 0
     out.loc[val_idx, split_column] = 1
     out.loc[test_idx, split_column] = 2
-
     logger.info("Successfully applied stratified random split")
     logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
-
     return out.astype({split_column: int})
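The stratified split above runs in two stages: it first holds out the test set, then splits the remainder into train and validation, rescaling the validation fraction to the data that is left (val_size_adjusted = p_val / (p_train + p_val)). A self-contained sketch of that arithmetic with toy labels and the default 0.7/0.1/0.2 proportions:

# Toy illustration of the two-stage stratified split used above (not the tool's real data).
from sklearn.model_selection import train_test_split

labels = ["a"] * 50 + ["b"] * 50
indices = list(range(len(labels)))
probs = [0.7, 0.1, 0.2]  # train / validation / test

# Stage 1: hold out the test set, stratified on the labels.
train_val_idx, test_idx = train_test_split(
    indices, test_size=probs[2], stratify=labels, random_state=42
)

# Stage 2: split the remaining 80% into train and validation.
# The validation share must be rescaled: 0.1 / (0.7 + 0.1) = 0.125 of the remainder.
val_size_adjusted = probs[1] / (probs[0] + probs[1])
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=val_size_adjusted,
    stratify=[labels[i] for i in train_val_idx],
    random_state=42,
)
print(len(train_idx), len(val_idx), len(test_idx))  # 70 10 20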
 
 
 class Backend(Protocol):
     """Interface for a machine learning backend."""
-
     def prepare_config(
         self,
         config_params: Dict[str, Any],
@@ -604,14 +578,12 @@
 
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
-
     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")
-
         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
@@ -634,9 +606,7 @@
             }
         else:
             encoder_config = {"type": raw_encoder}
-
         batch_size_cfg = batch_size or "auto"
-
         label_column_path = config_params.get("label_column_data_path")
         label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
@@ -644,7 +614,6 @@
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
             except Exception as e:
                 logger.warning(f"Could not read label column for task detection: {e}")
-
         if (
             label_series is not None
             and ptypes.is_numeric_dtype(label_series.dtype)
@@ -653,9 +622,7 @@
             task_type = "regression"
         else:
             task_type = "classification"
-
         config_params["task_type"] = task_type
-
         image_feat: Dict[str, Any] = {
             "name": IMAGE_PATH_COLUMN_NAME,
             "type": "image",
@@ -663,7 +630,6 @@
         }
         if config_params.get("augmentation") is not None:
             image_feat["augmentation"] = config_params["augmentation"]
-
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
@@ -679,15 +645,15 @@
                 },
             }
             val_metric = config_params.get("validation_metric", "mean_squared_error")
-
         else:
             num_unique_labels = (
                 label_series.nunique() if label_series is not None else 2
             )
             output_type = "binary" if num_unique_labels == 2 else "category"
             output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
+            if output_type == "binary" and config_params.get("threshold") is not None:
+                output_feat["threshold"] = float(config_params["threshold"])
             val_metric = None
-
         conf: Dict[str, Any] = {
             "model_type": "ecd",
             "input_features": [image_feat],
@@ -707,7 +673,6 @@
                 "in_memory": False,
             },
         }
-
         logger.debug("LudwigDirectBackend: Config dict built.")
         try:
             yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
@@ -729,7 +694,6 @@
     ) -> None:
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")
-
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
@@ -738,9 +702,7 @@
                 exc_info=True,
             )
             raise RuntimeError("Ludwig import failed.") from e
-
         output_dir.mkdir(parents=True, exist_ok=True)
-
         try:
             experiment_cli(
                 dataset=str(dataset_path),
@@ -771,16 +733,13 @@
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
-
         if not exp_dirs:
             logger.warning(f"No experiment run directories found in {output_dir}")
             return None
-
         progress_file = exp_dirs[-1] / "model" / "training_progress.json"
         if not progress_file.exists():
             logger.warning(f"No training_progress.json found in {progress_file}")
             return None
-
         try:
             with progress_file.open("r", encoding="utf-8") as f:
                 data = json.load(f)
@@ -816,7 +775,6 @@
     def generate_plots(self, output_dir: Path) -> None:
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
-
         test_plots = {
             "compare_performance",
             "compare_classifiers_performance_from_prob",
@@ -840,7 +798,6 @@
             "learning_curves",
             "compare_classifiers_performance_subset",
         }
-
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
@@ -850,7 +807,6 @@
             logger.warning(f"No experiment run dirs found in {output_dir}")
             return
         exp_dir = exp_dirs[-1]
-
         viz_dir = exp_dir / "visualizations"
         viz_dir.mkdir(exist_ok=True)
         train_viz = viz_dir / "train"
@@ -865,7 +821,6 @@
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
         probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
-
         dataset_path = None
         split_file = None
         desc = exp_dir / DESCRIPTION_FILE_NAME
@@ -874,7 +829,6 @@
                 cfg = json.load(f)
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
-
         output_feature = ""
         if desc.exists():
             try:
@@ -885,7 +839,6 @@
             with open(test_stats, "r") as f:
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")
-
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
             if viz_name in train_plots:
@@ -894,7 +847,6 @@
                 viz_dir_plot = test_viz
             else:
                 continue
-
             try:
                 viz_func(
                     training_statistics=[training_stats] if training_stats else [],
@@ -914,7 +866,6 @@
                 logger.info(f"✔ Generated {viz_name}")
             except Exception as e:
                 logger.warning(f"✘ Skipped {viz_name}: {e}")
-
         logger.info(f"All visualizations written to {viz_dir}")
 
     def generate_html_report(
@@ -930,7 +881,6 @@
         report_path = cwd / report_name
         output_dir = Path(output_dir)
         output_type = None
-
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
@@ -938,14 +888,11 @@
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
-
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
         test_viz_dir = base_viz_dir / "test"
-
         html = get_html_template()
         html += f"<h1>{title}</h1>"
-
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
@@ -971,7 +918,6 @@
             logger.warning(
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )
-
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
@@ -986,93 +932,77 @@
         ) -> str:
             if not dir_path.exists():
                 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
-
+            # collect every PNG
             imgs = list(dir_path.glob("*.png"))
+            # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files ---
+            imgs = [
+                img for img in imgs
+                if not (
+                    img.name == "confusion_matrix.png"
+                    or img.name.startswith("confusion_matrix__label_top")
+                    or img.name == "roc_curves.png"
+                )
+            ]
             if not imgs:
                 return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
-
-            if title == "Test Visualizations" and output_type == "binary":
+            if output_type == "binary":
                 order = [
-                    "confusion_matrix__label_top2.png",
                     "roc_curves_from_prediction_statistics.png",
                     "compare_performance_label.png",
                     "confusion_matrix_entropy__label_top2.png",
+                    # ...you can tweak ordering as needed
                 ]
                 img_names = {img.name: img for img in imgs}
-                ordered_imgs = [
-                    img_names[fname] for fname in order if fname in img_names
-                ]
-                remaining = sorted(
-                    [
-                        img
-                        for img in imgs
-                        if img.name not in order and img.name != "roc_curves.png"
-                    ]
-                )
-                imgs = ordered_imgs + remaining
-
-            elif title == "Test Visualizations" and output_type == "category":
+                ordered = [img_names[n] for n in order if n in img_names]
+                others = sorted(img for img in imgs if img.name not in order)
+                imgs = ordered + others
+            elif output_type == "category":
                 unwanted = {
                     "compare_classifiers_multiclass_multimetric__label_best10.png",
                     "compare_classifiers_multiclass_multimetric__label_top10.png",
                     "compare_classifiers_multiclass_multimetric__label_worst10.png",
                 }
                 display_order = [
-                    "confusion_matrix__label_top10.png",
                     "roc_curves.png",
                     "compare_performance_label.png",
                     "compare_classifiers_performance_from_prob.png",
-                    "compare_classifiers_multiclass_multimetric__label_sorted.png",
                     "confusion_matrix_entropy__label_top10.png",
                 ]
-                img_names = {img.name: img for img in imgs if img.name not in unwanted}
-                ordered_imgs = [
-                    img_names[fname] for fname in display_order if fname in img_names
-                ]
-                remaining = sorted(
-                    [img for img in img_names.values() if img.name not in display_order]
-                )
-                imgs = ordered_imgs + remaining
-
+                # filter and order
+                valid_imgs = [img for img in imgs if img.name not in unwanted]
+                img_map = {img.name: img for img in valid_imgs}
+                ordered = [img_map[n] for n in display_order if n in img_map]
+                others = sorted(img for img in valid_imgs if img.name not in display_order)
+                imgs = ordered + others
             else:
-                if output_type == "category":
-                    unwanted = {
-                        "compare_classifiers_multiclass_multimetric__label_best10.png",
-                        "compare_classifiers_multiclass_multimetric__label_top10.png",
-                        "compare_classifiers_multiclass_multimetric__label_worst10.png",
-                    }
-                    imgs = sorted([img for img in imgs if img.name not in unwanted])
-                else:
-                    imgs = sorted(imgs)
-
-            section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
+                # regression: just sort whatever's left
+                imgs = sorted(imgs)
+            # render each remaining PNG
+            html = ""
             for img in imgs:
                 b64 = encode_image_to_base64(str(img))
-                section_html += (
+                img_title = img.stem.replace("_", " ").title()
+                html += (
+                    f"<h2 style='text-align: center;'>{img_title}</h2>"
                     f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
-                    f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
                     f'<img src="data:image/png;base64,{b64}" '
                     f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
                     f"</div>"
                 )
-            section_html += "</div>"
-            return section_html
+            return html
 
         tab1_content = config_html + metrics_html
-
         tab2_content = train_val_metrics_html + render_img_section(
-            "Training & Validation Visualizations", train_viz_dir
+            "Training and Validation Visualizations", train_viz_dir
         )
-
         # --- Predictions vs Ground Truth table ---
         preds_section = ""
         parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
-        if parquet_path.exists():
+        if output_type == "regression" and parquet_path.exists():
             try:
                 # 1) load predictions from Parquet
                 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
                 # assume the column containing your model's prediction is named "prediction"
-                # or contains that substring:
                 pred_col = next(
                     (c for c in df_preds.columns if "prediction" in c.lower()),
                     None,
@@ -1080,40 +1010,58 @@
                 if pred_col is None:
                     raise ValueError("No prediction column found in Parquet output")
                 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
-
                 # 2) load ground truth for the test split from prepared CSV
                 df_all = pd.read_csv(config["label_column_data_path"])
-                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
-                    LABEL_COLUMN_NAME
-                ].reset_index(drop=True)
-
-                # 3) concatenate side‐by‐side
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True)
+                # 3) concatenate side-by-side
                 df_table = pd.concat([df_gt, df_pred], axis=1)
                 df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
-
                 # 4) render as HTML
                 preds_html = df_table.to_html(index=False, classes="predictions-table")
                 preds_section = (
-                    "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>"
-                    "<div style='overflow-x:auto; margin-bottom:20px;'>"
+                    "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
+                    "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>"
                     + preds_html
                     + "</div>"
                 )
             except Exception as e:
                 logger.warning(f"Could not build Predictions vs GT table: {e}")
-        # Test tab = Metrics + Preds table + Visualizations
-
-        tab3_content = (
-            test_metrics_html
-            + preds_section
-            + render_img_section("Test Visualizations", test_viz_dir, output_type)
-        )
-
+        tab3_content = test_metrics_html + preds_section
+        if output_type in ("binary", "category"):
+            training_stats_path = exp_dir / "training_statistics.json"
+            interactive_plots = build_classification_plots(
+                str(test_stats_path),
+                str(training_stats_path),
+            )
+            for plot in interactive_plots:
+                # 2) inject the static "roc_curves_from_prediction_statistics.png"
+                if plot["title"] == "ROC-AUC":
+                    static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    if static_img.exists():
+                        b64 = encode_image_to_base64(str(static_img))
+                        tab3_content += (
+                            "<h2 style='text-align: center;'>"
+                            "Roc Curves From Prediction Statistics"
+                            "</h2>"
+                            f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
+                            f'<img src="data:image/png;base64,{b64}" '
+                            f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
+                            "</div>"
+                        )
+                # always render the plotly panels exactly as before
+                tab3_content += (
+                    f"<h2 style='text-align: center;'>{plot['title']}</h2>"
+                    + plot["html"]
+                )
+            tab3_content += render_img_section(
+                "Test Visualizations",
+                test_viz_dir,
+                output_type
+            )
         # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
-
         try:
             with open(report_path, "w") as f:
                 f.write(html)
@@ -1121,13 +1069,11 @@
         except Exception as e:
             logger.error(f"Failed to write HTML report: {e}")
             raise
-
         return report_path
 
 
 class WorkflowOrchestrator:
     """Manages the image-classification workflow."""
-
     def __init__(self, args: argparse.Namespace, backend: Backend):
         self.args = args
         self.backend = backend
@@ -1167,19 +1113,16 @@
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
-
         try:
             df = pd.read_csv(self.args.csv_file)
             logger.info(f"Loaded CSV: {self.args.csv_file}")
         except Exception:
             logger.error("Error loading CSV file", exc_info=True)
             raise
-
         required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
         missing = required - set(df.columns)
         if missing:
             raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
-
         try:
             df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                 lambda p: str((self.image_extract_dir / p).resolve())
@@ -1187,7 +1130,6 @@
         except Exception:
             logger.error("Error updating image paths", exc_info=True)
             raise
-
         if SPLIT_COLUMN_NAME in df.columns:
             df, split_config, split_info = self._process_fixed_split(df)
         else:
@@ -1208,16 +1150,13 @@
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                 f"for train/val/test with balanced label distribution."
             )
-
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
         try:
-
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
-
         return final_csv, split_config, split_info
 
     def _process_fixed_split(
@@ -1232,10 +1171,8 @@
             )
             if df[SPLIT_COLUMN_NAME].isna().any():
                 logger.warning("Split column contains non-numeric/missing values.")
-
             unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
             logger.info(f"Unique split values: {unique}")
-
             if unique == {0, 2}:
                 df = split_data_0_2(
                     df,
@@ -1256,9 +1193,7 @@
                 logger.info("Using fixed split as-is.")
             else:
                 raise ValueError(f"Unexpected split values: {unique}")
-
             return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
-
         except Exception:
             logger.error("Error processing fixed split", exc_info=True)
             raise
@@ -1274,14 +1209,11 @@
         """Execute the full workflow end-to-end."""
         logger.info("Starting workflow...")
         self.args.output_dir.mkdir(parents=True, exist_ok=True)
-
         try:
             self._create_temp_dirs()
             self._extract_images()
             csv_path, split_cfg, split_info = self._prepare_data()
-
             use_pretrained = self.args.use_pretrained or self.args.fine_tune
-
             backend_args = {
                 "model_name": self.args.model_name,
                 "fine_tune": self.args.fine_tune,
@@ -1295,13 +1227,12 @@
                 "early_stop": self.args.early_stop,
                 "label_column_data_path": csv_path,
                 "augmentation": self.args.augmentation,
+                "threshold": self.args.threshold,
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
-
             config_file = self.temp_dir / TEMP_CONFIG_FILENAME
             config_file.write_text(yaml_str)
             logger.info(f"Wrote backend config: {config_file}")
-
             self.backend.run_experiment(
                 csv_path,
                 config_file,
@@ -1349,8 +1280,6 @@
     aug_list = []
     for tok in aug_string.split(","):
         key = tok.strip()
-        if not key:
-            continue
         if key not in mapping:
             valid = ", ".join(mapping.keys())
             raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
@@ -1428,7 +1357,7 @@
     parser.add_argument(
         "--validation-size",
         type=float,
-        default=0.1,
+        default=0.15,
         help="Fraction for validation (0.0–1.0)",
     )
     parser.add_argument(
@@ -1472,9 +1401,16 @@
             "E.g. --augmentation random_horizontal_flip,random_rotate"
         ),
     )
-
+    parser.add_argument(
+        "--threshold",
+        type=float,
+        default=None,
+        help=(
+            "Decision threshold for binary classification (0.0–1.0)."
+            "Overrides default 0.5."
+        )
+    )
     args = parser.parse_args()
-
     if not 0.0 <= args.validation_size <= 1.0:
         parser.error("validation-size must be between 0.0 and 1.0")
     if not args.csv_file.is_file():
@@ -1487,10 +1423,8 @@
             setattr(args, "augmentation", augmentation_setup)
         except ValueError as e:
             parser.error(str(e))
-
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)
-
     exit_code = 0
     try:
         orchestrator.run()
@@ -1505,7 +1439,6 @@
 if __name__ == "__main__":
     try:
         import ludwig
-
         logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
     except ImportError:
         logger.error(
@@ -1513,5 +1446,4 @@
             "('pip install ludwig[image]')"
         )
         sys.exit(1)
-
     main()
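To recap the new --threshold plumbing in this file: the flag is parsed in main, passed through backend_args, and prepare_config copies it onto the output feature only when the label column has exactly two unique values (a binary task); the config table then reports it, appending "(default)" when it is 0.5. A hedged sketch of the resulting output-feature fragment (the 'label' feature name and the 0.35 value are illustrative):

# Sketch of how the binary output feature picks up the threshold (mirrors prepare_config above).
config_params = {"threshold": 0.35}
output_type = "binary"  # chosen because the label column has exactly two unique values

output_feat = {"name": "label", "type": output_type}
if output_type == "binary" and config_params.get("threshold") is not None:
    output_feat["threshold"] = float(config_params["threshold"])

print(output_feat)  # {'name': 'label', 'type': 'binary', 'threshold': 0.35}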
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/plotly_plots.py	Thu Aug 14 14:53:10 2025 +0000
@@ -0,0 +1,148 @@
+import json
+from typing import Dict, List, Optional
+
+import numpy as np
+import plotly.graph_objects as go
+import plotly.io as pio
+
+
+def build_classification_plots(
+    test_stats_path: str,
+    training_stats_path: Optional[str] = None,
+) -> List[Dict[str, str]]:
+    """
+    Read Ludwig's test_statistics.json and build two interactive Plotly panels:
+      - Confusion Matrix
+      - Classification Report Heatmap
+    (no ROC panel is built here)
+
+    Returns a list of dicts, each with:
+      {
+        "title": <plot title>,
+        "html":  <HTML fragment for embedding>
+      }
+    """
+    # --- Load test stats ---
+    with open(test_stats_path, "r") as f:
+        test_stats = json.load(f)
+    label_stats = test_stats["label"]
+
+    # common sizing
+    cell = 40
+    n_classes = len(label_stats["confusion_matrix"])
+    side_px = max(cell * n_classes + 200, 600)
+    common_cfg = {"displayModeBar": True, "scrollZoom": True}
+
+    plots: List[Dict[str, str]] = []
+
+    # 1) Confusion Matrix
+    cm = np.array(label_stats["confusion_matrix"], dtype=int)
+    labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
+    total = cm.sum()
+
+    fig_cm = go.Figure(
+        go.Heatmap(
+            z=cm,
+            x=labels,
+            y=labels,
+            colorscale="Blues",
+            showscale=True,
+            colorbar=dict(title="Count"),
+        )
+    )
+    fig_cm.update_traces(xgap=2, ygap=2)
+    fig_cm.update_layout(
+        title=dict(text="Confusion Matrix", x=0.5),
+        xaxis_title="Predicted",
+        yaxis_title="Observed",
+        yaxis_autorange="reversed",
+        width=side_px,
+        height=side_px,
+        margin=dict(t=100, l=80, r=80, b=80),
+    )
+
+    # annotate counts and percentages
+    mval = cm.max() if cm.size else 0
+    thresh = mval / 2
+    for i in range(cm.shape[0]):
+        for j in range(cm.shape[1]):
+            v = cm[i, j]
+            pct = (v / total * 100) if total > 0 else 0
+            color = "white" if v > thresh else "black"
+            fig_cm.add_annotation(
+                x=labels[j],
+                y=labels[i],
+                text=f"<b>{v}</b>",
+                showarrow=False,
+                font=dict(color=color, size=14),
+                xanchor="center",
+                yanchor="bottom",
+                yshift=2,
+            )
+            fig_cm.add_annotation(
+                x=labels[j],
+                y=labels[i],
+                text=f"{pct:.1f}%",
+                showarrow=False,
+                font=dict(color=color, size=13),
+                xanchor="center",
+                yanchor="top",
+                yshift=-2,
+            )
+
+    plots.append({
+        "title": "Confusion Matrix",
+        "html": pio.to_html(
+            fig_cm,
+            full_html=False,
+            include_plotlyjs="cdn",
+            config=common_cfg
+        )
+    })
+
+    # 2) Classification Report Heatmap
+    pcs = label_stats.get("per_class_stats", {})
+    if pcs:
+        classes = list(pcs.keys())
+        metrics = ["precision", "recall", "f1_score"]
+        z, txt = [], []
+        for c in classes:
+            row, trow = [], []
+            for m in metrics:
+                val = pcs[c].get(m, 0)
+                row.append(val)
+                trow.append(f"{val:.2f}")
+            z.append(row)
+            txt.append(trow)
+
+        fig_cr = go.Figure(
+            go.Heatmap(
+                z=z,
+                x=metrics,
+                y=[str(c) for c in classes],
+                text=txt,
+                texttemplate="%{text}",
+                colorscale="Reds",
+                showscale=True,
+                colorbar=dict(title="Value"),
+            )
+        )
+        fig_cr.update_layout(
+            title="Classification Report",
+            xaxis_title="",
+            yaxis_title="Class",
+            width=side_px,
+            height=side_px,
+            margin=dict(t=80, l=80, r=80, b=80),
+        )
+        plots.append({
+            "title": "Classification Report",
+            "html": pio.to_html(
+                fig_cr,
+                full_html=False,
+                include_plotlyjs=False,
+                config=common_cfg
+            )
+        })
+
+    return plots
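build_classification_plots returns a list of {"title", "html"} fragments that generate_html_report concatenates into the Test tab. A minimal, standalone sketch of embedding those fragments in a page (the paths are placeholders for an actual experiment run directory):

# Standalone sketch: embed the interactive panels in a simple HTML page (paths are placeholders).
from plotly_plots import build_classification_plots

panels = build_classification_plots("experiment_run/test_statistics.json")
body = ""
for plot in panels:
    body += f"<h2 style='text-align: center;'>{plot['title']}</h2>" + plot["html"]

with open("classification_panels.html", "w", encoding="utf-8") as f:
    f.write("<html><body>" + body + "</body></html>")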
--- a/utils.py	Fri Aug 08 13:06:28 2025 +0000
+++ b/utils.py	Thu Aug 14 14:53:10 2025 +0000
@@ -8,6 +8,8 @@
     <head>
         <meta charset="UTF-8">
         <title>Galaxy-Ludwig Report</title>
+
+        <!-- base report styles -->
         <style>
           body {
               font-family: Arial, sans-serif;
@@ -32,29 +34,21 @@
               color: #4CAF50;
               padding-bottom: 5px;
           }
+          /* baseline table setup */
           table {
               border-collapse: collapse;
               margin: 20px 0;
               width: 100%;
-              table-layout: fixed; /* Enforces consistent column widths */
+              table-layout: fixed;
           }
           table, th, td {
               border: 1px solid #ddd;
           }
           th, td {
               padding: 8px;
-              text-align: center; /* Center-align text */
-              vertical-align: middle; /* Center-align content vertically */
-              word-wrap: break-word; /* Break long words to avoid overflow */
-          }
-          th:first-child, td:first-child {
-              width: 5%; /* Smaller width for the first column */
-          }
-          th:nth-child(2), td:nth-child(2) {
-              width: 50%; /* Wider for the metric/description column */
-          }
-          th:last-child, td:last-child {
-              width: 25%; /* Value column gets remaining space */
+              text-align: center;
+              vertical-align: middle;
+              word-wrap: break-word;
           }
           th {
               background-color: #4CAF50;
@@ -68,7 +62,105 @@
               max-width: 100%;
               height: auto;
           }
+
+          /* -------------------
+             SORTABLE COLUMNS
+             ------------------- */
+          table.performance-summary th.sortable {
+            cursor: pointer;
+            position: relative;
+            user-select: none;
+          }
+          /* hide arrows by default */
+          table.performance-summary th.sortable::after {
+            content: '';
+            position: absolute;
+            right: 12px;
+            top: 50%;
+            transform: translateY(-50%);
+            font-size: 0.8em;
+            color: #666;
+          }
+          /* three states */
+          table.performance-summary th.sortable.sorted-none::after {
+            content: '⇅';
+          }
+          table.performance-summary th.sortable.sorted-asc::after {
+            content: '↑';
+          }
+          table.performance-summary th.sortable.sorted-desc::after {
+            content: '↓';
+          }
         </style>
+
+        <!-- sorting script -->
+        <script>
+        document.addEventListener('DOMContentLoaded', () => {
+          // 1) record each row's original position
+          document.querySelectorAll('table.performance-summary tbody').forEach(tbody => {
+            Array.from(tbody.rows).forEach((row, i) => {
+              row.dataset.originalOrder = i;
+            });
+          });
+
+          const getText = cell => cell.innerText.trim();
+          const comparer = (idx, asc) => (a, b) => {
+            const v1 = getText(a.children[idx]);
+            const v2 = getText(b.children[idx]);
+            const n1 = parseFloat(v1), n2 = parseFloat(v2);
+            if (!isNaN(n1) && !isNaN(n2)) {
+              return asc ? n1 - n2 : n2 - n1;
+            }
+            return asc
+              ? v1.localeCompare(v2)
+              : v2.localeCompare(v1);
+          };
+
+          document
+            .querySelectorAll('table.performance-summary th.sortable')
+            .forEach(th => {
+              // initialize to "none" state
+              th.classList.add('sorted-none');
+              th.addEventListener('click', () => {
+                const table = th.closest('table');
+                const allTh = table.querySelectorAll('th.sortable');
+
+                // 1) determine current state BEFORE clearing classes
+                let curr = th.classList.contains('sorted-asc')
+                  ? 'asc'
+                  : th.classList.contains('sorted-desc')
+                    ? 'desc'
+                    : 'none';
+                // 2) cycle to next state
+                let next = curr === 'none'
+                  ? 'asc'
+                  : curr === 'asc'
+                    ? 'desc'
+                    : 'none';
+
+                // 3) clear all sort markers
+                allTh.forEach(h =>
+                  h.classList.remove('sorted-none','sorted-asc','sorted-desc')
+                );
+                // 4) apply the new marker
+                th.classList.add(`sorted-${next}`);
+
+                // 5) sort or restore original order
+                const tbody = table.querySelector('tbody');
+                let rows = Array.from(tbody.rows);
+                if (next === 'none') {
+                  rows.sort((a, b) =>
+                    a.dataset.originalOrder - b.dataset.originalOrder
+                  );
+                } else {
+                  const idx = Array.from(th.parentNode.children).indexOf(th);
+                  rows.sort(comparer(idx, next === 'asc'));
+                }
+                rows.forEach(r => tbody.appendChild(r));
+              });
+            });
+        });
+        </script>
     </head>
     <body>
     <div class="container">
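
The code that renders the report tables is not part of this hunk; for orientation, this is the rough shape of markup the sorting script binds to. The class names performance-summary and sortable come from the CSS/JS above; render_summary_table and its columns are purely illustrative:

def render_summary_table(rows):
    # Hypothetical helper: emit a table the sorter above can pick up.
    header = "".join(
        f'<th class="sortable">{name}</th>' for name in ("Metric", "Value")
    )
    body = "".join(
        f"<tr><td>{metric}</td><td>{value}</td></tr>" for metric, value in rows
    )
    return (
        '<table class="performance-summary">'
        f"<thead><tr>{header}</tr></thead>"
        f"<tbody>{body}</tbody>"
        "</table>"
    )

print(render_summary_table([("accuracy", 0.91), ("loss", 0.27)]))
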
@@ -203,7 +295,7 @@
 </style>
 
 <div class="tabs">
-  <div class="tab active" onclick="showTab('metrics')">Config &amp; Results Summary</div>
+  <div class="tab active" onclick="showTab('metrics')">Config and Results Summary</div>
   <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div>
   <div class="tab" onclick="showTab('test')">Test Results</div>
   <!-- always-visible help button -->
@@ -232,122 +324,193 @@
 
 
 def get_metrics_help_modal() -> str:
-    modal_html = """
-<div id="metricsHelpModal" class="modal">
-  <div class="modal-content">
-    <span class="close">×</span>
-    <h2>Model Evaluation Metrics — Help Guide</h2>
-    <div class="metrics-guide">
-      <h3>1) General Metrics</h3>
-      <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p>
-      <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p>
-      <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p>
-      <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p>
-      <h3>2) Precision, Recall & Specificity</h3>
-      <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p>
-      <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p>
-      <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p>
-      <h3>3) Macro, Micro, and Weighted Averages</h3>
-      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p>
-      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p>
-      <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p>
-      <h3>4) Average Precision (PR-AUC Variants)</h3>
-      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p>
-      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. Best for imbalanced data or multi-label classification.</p>
-      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p>
-      <h3>5) ROC-AUC Variants</h3>
-      <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>
-      <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p>
-      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p>
-      <h3>6) Ranking Metrics</h3>
-      <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p>
-      <h3>7) Confusion Matrix Stats (Per Class)</h3>
-      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p>
-      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p>
-      <h3>8) Other Useful Metrics</h3>
-      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p>
-      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p>
-      <h3>9) Metric Recommendations</h3>
-      <ul>
-        <li>Use <strong>Accuracy + F1</strong> for balanced data.</li>
-        <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li>
-        <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li>
-        <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li>
-        <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li>
-        <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li>
-        <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li>
-      </ul>
-    </div>
-  </div>
-</div>
-"""
-    modal_css = """
-<style>
-.modal {
-  display: none;
-  position: fixed;
-  z-index: 1;
-  left: 0;
-  top: 0;
-  width: 100%;
-  height: 100%;
-  overflow: auto;
-  background-color: rgba(0,0,0,0.4);
-}
-.modal-content {
-  background-color: #fefefe;
-  margin: 15% auto;
-  padding: 20px;
-  border: 1px solid #888;
-  width: 80%;
-  max-width: 800px;
-}
-.close {
-  color: #aaa;
-  float: right;
-  font-size: 28px;
-  font-weight: bold;
-}
-.close:hover,
-.close:focus {
-  color: black;
-  text-decoration: none;
-  cursor: pointer;
-}
-.metrics-guide h3 {
-  margin-top: 20px;
-}
-.metrics-guide p {
-  margin: 5px 0;
-}
-.metrics-guide ul {
-  margin: 10px 0;
-  padding-left: 20px;
-}
-</style>
-"""
-    modal_js = """
-<script>
-document.addEventListener("DOMContentLoaded", function() {
-  var modal = document.getElementById("metricsHelpModal");
-  var openBtn = document.getElementById("openMetricsHelp");
-  var span = document.getElementsByClassName("close")[0];
-  if (openBtn && modal) {
-    openBtn.onclick = function() {
-      modal.style.display = "block";
-    };
-  }
-  if (span && modal) {
-    span.onclick = function() {
-      modal.style.display = "none";
-    };
-  }
-  window.onclick = function(event) {
-    if (event.target == modal) {
-      modal.style.display = "none";
-    }
-  }
-});
-</script>
-"""
+    modal_html = (
+        '<div id="metricsHelpModal" class="modal">'
+        '  <div class="modal-content">'
+        '    <span class="close">×</span>'
+        '    <h2>Model Evaluation Metrics — Help Guide</h2>'
+        '    <div class="metrics-guide">'
+        '      <h3>1) General Metrics (Regression and Classification)</h3>'
+        '      <p><strong>Loss (Regression &amp; Classification):</strong> '
+        'Measures the difference between predicted and actual values, '
+        'optimized during training. Lower is better. '
+        'For regression, this is often Mean Squared Error (MSE) or '
+        'Mean Absolute Error (MAE). For classification, it’s typically '
+        'cross-entropy or log loss.</p>'
+        '      <h3>2) Regression Metrics</h3>'
+        '      <p><strong>Mean Absolute Error (MAE):</strong> '
+        'Average of absolute differences between predicted and actual values, '
+        'in the same units as the target. Use for interpretable error measurement '
+        'when all errors are equally important. Less sensitive to outliers than MSE.</p>'
+        '      <p><strong>Mean Squared Error (MSE):</strong> '
+        'Average of squared differences between predicted and actual values. '
+        'Penalizes larger errors more heavily, useful when large deviations are critical. '
+        'Often used as the loss function in regression.</p>'
+        '      <p><strong>Root Mean Squared Error (RMSE):</strong> '
+        'Square root of MSE, in the same units as the target. '
+        'Balances interpretability and sensitivity to large errors. '
+        'Widely used for regression evaluation.</p>'
+        '      <p><strong>Mean Absolute Percentage Error (MAPE):</strong> '
+        'Average absolute error as a percentage of actual values. '
+        'Scale-independent, ideal for comparing relative errors across datasets. '
+        'Avoid when actual values are near zero.</p>'
+        '      <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> '
+        'Square root of mean squared percentage error. Scale-independent, '
+        'penalizes larger relative errors more than MAPE. Use for forecasting '
+        'or when relative accuracy matters.</p>'
+        '      <p><strong>R² Score:</strong> Proportion of variance in the target '
+        'explained by the model. Ranges from negative infinity to 1 (perfect prediction). '
+        'Use to assess model fit; negative values indicate poor performance '
+        'compared to predicting the mean.</p>'
+        '      <h3>3) Classification Metrics</h3>'
+        '      <p><strong>Accuracy:</strong> Proportion of correct predictions '
+        'among all predictions. Simple but misleading for imbalanced datasets, '
+        'where high accuracy may hide poor performance on minority classes.</p>'
+        '      <p><strong>Micro Accuracy:</strong> Pools predictions from every class and '
+        'computes the overall fraction that are correct. Suitable for multiclass or '
+        'multilabel problems with imbalanced data.</p>'
+        '      <p><strong>Token Accuracy:</strong> Measures how often predicted tokens '
+        '(e.g., in sequences) match true tokens. Common in NLP tasks like text generation '
+        'or token classification.</p>'
+        '      <p><strong>Precision:</strong> Proportion of positive predictions that are '
+        'correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>'
+        '      <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives '
+        'correctly predicted (TP / (TP + FN)). Use when missing positives is risky, '
+        'e.g., disease detection.</p>'
+        '      <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). '
+        'Measures ability to identify negatives. Useful in medical testing to avoid '
+        'false alarms.</p>'
+        '      <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>'
+        '      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric '
+        'across all classes, treating each equally. Best for balanced datasets where '
+        'all classes are equally important.</p>'
+        '      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, '
+        'false positives, and false negatives across all classes before computing. '
+        'Ideal for imbalanced or multilabel classification.</p>'
+        '      <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics '
+        'across classes, weighted by the number of true instances per class. Balances '
+        'class importance based on frequency.</p>'
+        '      <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>'
+        '      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged '
+        'equally across classes. Use for balanced multiclass problems.</p>'
+        '      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC '
+        'using all instances. Best for imbalanced or multilabel classification.</p>'
+        '      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged '
+        'across individual samples. Ideal for multilabel tasks where samples have multiple '
+        'labels.</p>'
+        '      <h3>6) Classification: ROC-AUC Variants</h3>'
+        '      <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. '
+        'AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>'
+        '      <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. '
+        'Suitable for balanced multiclass problems.</p>'
+        '      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions '
+        'across all classes. Useful for imbalanced or multilabel settings.</p>'
+        '      <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>'
+        '      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions '
+        'for positives and negatives, respectively.</p>'
+        '      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions '
+        '— false alarms and missed detections.</p>'
+        '      <h3>8) Classification: Ranking Metrics</h3>'
+        '      <p><strong>Hits at K:</strong> Measures whether the true label is among the '
+        'top-K predictions. Common in recommendation systems and retrieval tasks.</p>'
+        '      <h3>9) Other Metrics (Classification)</h3>'
+        '      <p><strong>Cohen\'s Kappa:</strong> Measures agreement between predicted and '
+        'actual labels, adjusted for chance. Useful for multiclass classification with '
+        'imbalanced data.</p>'
+        '      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure '
+        'using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>'
+        '      <h3>10) Metric Recommendations</h3>'
+        '      <ul>'
+        '        <li><strong>Regression:</strong> Use <strong>RMSE</strong> or '
+        '<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative '
+        'errors, and <strong>R²</strong> to assess model fit. Use <strong>MSE</strong> or '
+        '<strong>RMSPE</strong> when large errors are critical.</li>'
+        '        <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> '
+        'and <strong>F1</strong> for overall performance.</li>'
+        '        <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, '
+        '<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class '
+        'performance.</li>'
+        '        <li><strong>Multilabel or Imbalanced Classification:</strong> Use '
+        '<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>'
+        '        <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> '
+        'or <strong>Macro ROC-AUC</strong>.</li>'
+        '        <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> '
+        'to account for class imbalance.</li>'
+        '        <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>'
+        '        <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> '
+        'for class-wise performance in classification.</li>'
+        '      </ul>'
+        '    </div>'
+        '  </div>'
+        '</div>'
+    )
+    modal_css = (
+        "<style>"
+        ".modal {"
+        "  display: none;"
+        "  position: fixed;"
+        "  z-index: 1;"
+        "  left: 0;"
+        "  top: 0;"
+        "  width: 100%;"
+        "  height: 100%;"
+        "  overflow: auto;"
+        "  background-color: rgba(0,0,0,0.4);"
+        "}"
+        ".modal-content {"
+        "  background-color: #fefefe;"
+        "  margin: 15% auto;"
+        "  padding: 20px;"
+        "  border: 1px solid #888;"
+        "  width: 80%;"
+        "  max-width: 800px;"
+        "}"
+        ".close {"
+        "  color: #aaa;"
+        "  float: right;"
+        "  font-size: 28px;"
+        "  font-weight: bold;"
+        "}"
+        ".close:hover,"
+        ".close:focus {"
+        "  color: black;"
+        "  text-decoration: none;"
+        "  cursor: pointer;"
+        "}"
+        ".metrics-guide h3 {"
+        "  margin-top: 20px;"
+        "}"
+        ".metrics-guide p {"
+        "  margin: 5px 0;"
+        "}"
+        ".metrics-guide ul {"
+        "  margin: 10px 0;"
+        "  padding-left: 20px;"
+        "}"
+        "</style>"
+    )
+    modal_js = (
+        "<script>"
+        'document.addEventListener("DOMContentLoaded", function() {'
+        '  var modal = document.getElementById("metricsHelpModal");'
+        '  var openBtn = document.getElementById("openMetricsHelp");'
+        '  var span = document.getElementsByClassName("close")[0];'
+        "  if (openBtn && modal) {"
+        "    openBtn.onclick = function() {"
+        "      modal.style.display = \"block\";"
+        "    };"
+        "  }"
+        "  if (span && modal) {"
+        "    span.onclick = function() {"
+        "      modal.style.display = \"none\";"
+        "    };"
+        "  }"
+        "  window.onclick = function(event) {"
+        "    if (event.target == modal) {"
+        "      modal.style.display = \"none\";"
+        "    }"
+        "  }"
+        "});"
+        "</script>"
+    )
     return modal_css + modal_html + modal_js
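
The per-class definitions quoted in the guide (Precision = TP / (TP + FP), Recall = TP / (TP + FN), Specificity = TN / (TN + FP)) can be sanity-checked in a few lines. A standalone sketch with made-up counts, not part of the tool:

# Worked check of the formulas quoted in the metrics guide (toy counts).
tp, fp, fn, tn = 40, 10, 5, 45

accuracy = (tp + tn) / (tp + fp + fn + tn)   # correct predictions / all predictions
precision = tp / (tp + fp)                   # of predicted positives, how many are right
recall = tp / (tp + fn)                      # of actual positives, how many are found
specificity = tn / (tn + fp)                 # of actual negatives, how many are found
f1 = 2 * precision * recall / (precision + recall)

print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f} "
      f"specificity={specificity:.3f} f1={f1:.3f}")
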