goeckslab / image_learner — changeset 8:85e6f4b2ad18 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
author:    goeckslab
date:      Thu, 14 Aug 2025 14:53:10 +0000
parents:   801a8b6973fb
children:  (none)
files:     constants.py image_learner.xml image_learner_cli.py plotly_plots.py utils.py
diffstat:  5 files changed, 670 insertions(+), 429 deletions(-)
--- a/constants.py	Fri Aug 08 13:06:28 2025 +0000
+++ b/constants.py	Thu Aug 14 14:53:10 2025 +0000
@@ -87,28 +87,28 @@
 }
 METRIC_DISPLAY_NAMES = {
     "accuracy": "Accuracy",
-    "accuracy_micro": "Accuracy-Micro",
+    "accuracy_micro": "Micro Accuracy",
     "loss": "Loss",
     "roc_auc": "ROC-AUC",
-    "roc_auc_macro": "ROC-AUC-Macro",
-    "roc_auc_micro": "ROC-AUC-Micro",
+    "roc_auc_macro": "Macro ROC-AUC",
+    "roc_auc_micro": "Micro ROC-AUC",
     "hits_at_k": "Hits at K",
     "precision": "Precision",
     "recall": "Recall",
     "specificity": "Specificity",
     "kappa_score": "Cohen's Kappa",
     "token_accuracy": "Token Accuracy",
-    "avg_precision_macro": "Precision-Macro",
-    "avg_recall_macro": "Recall-Macro",
-    "avg_f1_score_macro": "F1-score-Macro",
-    "avg_precision_micro": "Precision-Micro",
-    "avg_recall_micro": "Recall-Micro",
-    "avg_f1_score_micro": "F1-score-Micro",
-    "avg_precision_weighted": "Precision-Weighted",
-    "avg_recall_weighted": "Recall-Weighted",
-    "avg_f1_score_weighted": "F1-score-Weighted",
-    "average_precision_macro": "Precision-Average-Macro",
-    "average_precision_micro": "Precision-Average-Micro",
+    "avg_precision_macro": "Macro Precision",
+    "avg_recall_macro": "Macro Recall",
+    "avg_f1_score_macro": "Macro F1-score",
+    "avg_precision_micro": "Micro Precision",
+    "avg_recall_micro": "Micro Recall",
+    "avg_f1_score_micro": "Micro F1-score",
+    "avg_precision_weighted": "Weighted Precision",
+    "avg_recall_weighted": "Weighted Recall",
+    "avg_f1_score_weighted": "Weighted F1-score",
+    "average_precision_macro": "Macro Precision-Average",
+    "average_precision_micro": "Micro Precision-Average",
     "average_precision_samples": "Precision-Average-Samples",
     "mean_squared_error": "Mean Squared Error",
     "mean_absolute_error": "Mean Absolute Error",
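The renames above switch the report labels from a suffix style ("Accuracy-Micro") to a prefix style ("Micro Accuracy"). For orientation, a minimal sketch of how such a mapping is typically consumed when rendering metric tables; the `display_name` helper is illustrative and not part of the diff:

```python
# Illustrative lookup (not in the diff): map a metric key to its display
# label, falling back to a title-cased version of the key.
METRIC_DISPLAY_NAMES = {
    "accuracy_micro": "Micro Accuracy",
    "roc_auc_macro": "Macro ROC-AUC",
    "avg_f1_score_weighted": "Weighted F1-score",
}

def display_name(metric_key: str) -> str:
    return METRIC_DISPLAY_NAMES.get(metric_key, metric_key.replace("_", " ").title())

print(display_name("roc_auc_macro"))   # Macro ROC-AUC
print(display_name("jaccard_index"))   # Jaccard Index (fallback)
```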
--- a/image_learner.xml	Fri Aug 08 13:06:28 2025 +0000
+++ b/image_learner.xml	Thu Aug 14 14:53:10 2025 +0000
@@ -1,7 +1,7 @@
-<tool id="image_learner" name="Image Learner" version="0.1.1" profile="22.05">
-    <description>trains and evaluates an image classification/regression model</description>
+<tool id="image_learner" name="Image Learner for Classification" version="0.1.2" profile="22.05">
+    <description>trains and evaluates a image classification model</description>
     <requirements>
-        <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1</container>
+        <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:latest</container>
     </requirements>
     <required_files>
         <include path="utils.py" />
@@ -144,13 +144,14 @@
     <conditional name="scratch_fine_tune">
         <param name="use_pretrained" type="select" label="Use pretrained weights?"
-            help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch. (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
+            help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch.
+            (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
             <option value="false">No</option>
             <option value="true" selected="true">Yes</option>
         </param>
         <when value="true">
             <param name="fine_tune" type="select" label="Fine tune the encoder?"
-                help="Whether to fine tune the encoder(combiner and decoder will be fine-tued anyway)" >
+                help="Whether to fine tune the encoder(combiner and decoder will be fine-tuned anyway)" >
                 <option value="false" >No</option>
                 <option value="true" selected="true">Yes</option>
             </param>
@@ -218,6 +219,7 @@
            label="Test split proportion (only works if no split column in the metadata csv)"
            value="0.2"
            help="Fraction of data for testing (e.g., 0.2) train split + val split + test split should = 1."/>
+        <param name="threshold" type="float" value="0.5" min="0.0" max="1.0" optional="true" label="Decision Threshold (binary only)" help="Set the decision threshold for binary classification (0.0–1.0). Only applies when task is binary; default is 0.5." />
     </when>
     <when value="false">
         <!-- No additional parameters to show if the user selects 'No' -->
     </when>
@@ -307,8 +309,6 @@
                 <has_text text="Test Results" />
             </assert_contents>
         </output>
-        <output name="output_report" file="expected_regression.html" compare="sim_size"/>
-
        <output_collection name="output_pred_csv" type="list" >
            <element name="predictions.csv" >
                <assert_contents>
@@ -317,18 +317,16 @@
            </element>
        </output_collection>
    </test>
-    </tests>
+    </tests>
    <help>
        <![CDATA[
**What it does**

-Image Learner for Classification/regression: trains and evaluates a image classification/regression model.
+Image Learner for Classification: trains and evaluates a image classification model.
It uses the metadata csv to find the image paths and labels.
The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test).
If you do not provide a split column, the tool will automatically split the data into train, val, and test sets
based on the proportions you specify or [0.7, 0.1, 0.2] by default.

-**If the selected label column has more than 10 unique values, the tool will automatically treat the task as a regression problem and apply appropriate metrics (e.g., MSE, RMSE, R²).**
-
**Outputs**

The tool will output a trained model in the form of a ludwig_model file,
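The new `threshold` parameter added above only applies to binary tasks. As a rough illustration of what it controls (this is not Ludwig's internal code, which applies the threshold itself), a decision threshold simply converts predicted positive-class probabilities into labels:

```python
import numpy as np

# Minimal sketch: how a binary decision threshold turns predicted
# positive-class probabilities into 0/1 labels.
def apply_threshold(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    return (probs >= threshold).astype(int)

probs = np.array([0.12, 0.48, 0.51, 0.93])
print(apply_threshold(probs, 0.5))  # [0 0 1 1]
print(apply_threshold(probs, 0.6))  # [0 0 0 1]
```

Raising the threshold trades recall for precision on the positive class, which is why the tool exposes it as a tunable value rather than hard-coding 0.5.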
--- a/image_learner_cli.py Fri Aug 08 13:06:28 2025 +0000 +++ b/image_learner_cli.py Thu Aug 14 14:53:10 2025 +0000 @@ -31,6 +31,7 @@ ) from ludwig.utils.data_utils import get_split_path from ludwig.visualize import get_visualizations_registry +from plotly_plots import build_classification_plots from sklearn.model_selection import train_test_split from utils import ( build_tabbed_html, @@ -52,6 +53,7 @@ config: dict, split_info: Optional[str] = None, training_progress: dict = None, + output_type: Optional[str] = None, ) -> str: display_keys = [ "task_type", @@ -63,114 +65,119 @@ "learning_rate", "random_seed", "early_stop", + "threshold", ] - rows = [] - for key in display_keys: - val = config.get(key, "N/A") - if key == "task_type": - val = val.title() if isinstance(val, str) else val - if key == "batch_size": - if val is not None: - val = int(val) - else: - if training_progress: - val = "Auto-selected batch size by Ludwig:<br>" - resolved_val = training_progress.get("batch_size") - val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" + val = config.get(key, None) + if key == "threshold": + if output_type != "binary": + continue + val = val if val is not None else 0.5 + val_str = f"{val:.2f}" + if val == 0.5: + val_str += " (default)" + else: + if key == "task_type": + val_str = val.title() if isinstance(val, str) else "N/A" + elif key == "batch_size": + if val is not None: + val_str = int(val) + else: + if training_progress: + resolved_val = training_progress.get("batch_size") + val_str = ( + "Auto-selected batch size by Ludwig:<br>" + f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" + ) + else: + val_str = "auto" + elif key == "learning_rate": + if val is not None and val != "auto": + val_str = f"{val:.6f}" else: - val = "auto" - if key == "learning_rate": - resolved_val = None - if val is None or val == "auto": - if training_progress: - resolved_val = training_progress.get("learning_rate") - val = ( - "Auto-selected learning rate by Ludwig:<br>" - f"<span style='font-size: 0.85em;'>" - f"{resolved_val if resolved_val else val}</span><br>" - "<span style='font-size: 0.85em;'>" - "Based on model architecture and training setup " - "(e.g., fine-tuning).<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." - "</span>" - ) + if training_progress: + resolved_val = training_progress.get("learning_rate") + val_str = ( + "Auto-selected learning rate by Ludwig:<br>" + f"<span style='font-size: 0.85em;'>" + f"{resolved_val if resolved_val else 'auto'}</span><br>" + "<span style='font-size: 0.85em;'>" + "Based on model architecture and training setup " + "(e.g., fine-tuning).<br>" + "</span>" + ) + else: + val_str = ( + "Auto-selected by Ludwig<br>" + "<span style='font-size: 0.85em;'>" + "Automatically tuned based on architecture and dataset.<br>" + "See <a href='https://ludwig.ai/latest/configuration/trainer/" + "#trainer-parameters' target='_blank'>" + "Ludwig Trainer Parameters</a> for details." + "</span>" + ) + elif key == "epochs": + if val is None: + val_str = "N/A" else: - val = ( - "Auto-selected by Ludwig<br>" - "<span style='font-size: 0.85em;'>" - "Automatically tuned based on architecture and dataset.<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." 
- "</span>" - ) + if ( + training_progress + and "epoch" in training_progress + and val > training_progress["epoch"] + ): + val_str = ( + f"Because of early stopping: the training " + f"stopped at epoch {training_progress['epoch']}" + ) + else: + val_str = val else: - val = f"{val:.6f}" - if key == "epochs": - if ( - training_progress - and "epoch" in training_progress - and val > training_progress["epoch"] - ): - val = ( - f"Because of early stopping: the training " - f"stopped at epoch {training_progress['epoch']}" - ) - - if val is None: - continue + val_str = val if val is not None else "N/A" + if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential + continue rows.append( f"<tr>" f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" f"{key.replace('_', ' ').title()}</td>" f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" - f"{val}</td>" + f"{val_str}</td>" f"</tr>" ) - aug_cfg = config.get("augmentation") if aug_cfg: types = [str(a.get("type", "")) for a in aug_cfg] aug_val = ", ".join(types) rows.append( - "<tr>" - "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" - "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" - f"{aug_val}</td>" - "</tr>" + f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" ) - if split_info: rows.append( - f"<tr>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" - f"Data Split</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" - f"{split_info}</td>" - f"</tr>" + f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" ) - - return ( - "<h2 style='text-align: center;'>Training Setup</h2>" - "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" - "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>" - "Parameter</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>" - "Value</th>" - "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" - "<p style='text-align: center; font-size: 0.9em;'>" - "Model trained using Ludwig.<br>" - "If want to learn more about Ludwig default settings," - "please check their <a href='https://ludwig.ai' target='_blank'>" - "website(ludwig.ai)</a>." - "</p><hr>" - ) + html = f""" + <h2 style="text-align: center;">Model and Training Summary</h2> + <div style="display: flex; justify-content: center;"> + <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> + <thead><tr> + <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> + <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> + </tr></thead> + <tbody> + {''.join(rows)} + </tbody> + </table> + </div><br> + <p style="text-align: center; font-size: 0.9em;"> + Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. 
+ <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer"> + Ludwig documentation provides detailed information about default model and training parameters + </a> + </p><hr> + """ + return html def detect_output_type(test_stats): @@ -244,7 +251,6 @@ "roc_auc": get_last_value(label_stats, "roc_auc"), "hits_at_k": get_last_value(label_stats, "hits_at_k"), } - # Test metrics: dynamic extraction according to exclusions test_label_stats = test_stats.get("label", {}) if not test_label_stats: @@ -252,13 +258,11 @@ else: combined_stats = test_stats.get("combined", {}) overall_stats = test_label_stats.get("overall_stats", {}) - # Define exclusions if output_type == "binary": exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} else: exclude = {"per_class_stats", "confusion_matrix"} - # 1. Get all scalar test_label_stats not excluded test_metrics = {} for k, v in test_label_stats.items(): @@ -268,17 +272,13 @@ continue if isinstance(v, (int, float, str, bool)): test_metrics[k] = v - # 2. Add overall_stats (flattened) for k, v in overall_stats.items(): test_metrics[k] = v - # 3. Optionally include combined/loss if present and not already if "loss" in combined_stats and "loss" not in test_metrics: test_metrics["loss"] = combined_stats["loss"] - metrics["test"] = test_metrics - return metrics @@ -291,6 +291,11 @@ ) +# ----------------------------------------- +# 2) MODEL PERFORMANCE (Train/Val/Test) TABLE +# ----------------------------------------- + + def format_stats_table_html(train_stats: dict, test_stats: dict) -> str: """Formats a combined HTML table for training, validation, and test metrics.""" output_type = detect_output_type(test_stats) @@ -310,35 +315,33 @@ te = all_metrics["test"].get(metric_key) if all(x is not None for x in [t, v, te]): rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) - if not rows: return "<table><tr><td>No metric values found.</td></tr></table>" - html = ( "<h2 style='text-align: center;'>Model Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; table-layout: auto;'>" + "<table class='performance-summary' style='border-collapse: collapse;'>" "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " - "white-space: nowrap;'>Metric</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Train</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Validation</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Test</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" "</tr></thead><tbody>" ) for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;", + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" ) html += "</tbody></table></div><br>" return html +# ------------------------------------------- +# 
3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE +# ------------------------------------------- + + def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: """Formats an HTML table for training and validation metrics.""" output_type = detect_output_type(test_stats) @@ -354,33 +357,32 @@ v = all_metrics["validation"].get(metric_key) if t is not None and v is not None: rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) - if not rows: return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" - html = ( "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; table-layout: auto;'>" + "<table class='performance-summary' style='border-collapse: collapse;'>" "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " - "white-space: nowrap;'>Metric</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Train</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Validation</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" "</tr></thead><tbody>" ) for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;", + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" ) html += "</tbody></table></div><br>" return html +# ----------------------------------------- +# 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE +# ----------------------------------------- + + def format_test_merged_stats_table_html( test_metrics: Dict[str, Optional[float]], ) -> str: @@ -391,26 +393,21 @@ value = test_metrics[key] if value is not None: rows.append([display_name, f"{value:.4f}"]) - if not rows: return "<table><tr><td>No test metric values found.</td></tr></table>" - html = ( "<h2 style='text-align: center;'>Test Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; table-layout: auto;'>" + "<table class='performance-summary' style='border-collapse: collapse;'>" "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " - "white-space: nowrap;'>Metric</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Test</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" "</tr></thead><tbody>" ) for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;", + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" ) html += "</tbody></table></div><br>" return html @@ -426,13 +423,10 @@ """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" out = df.copy() out[split_column] = 
pd.to_numeric(out[split_column], errors="coerce").astype(int) - idx_train = out.index[out[split_column] == 0].tolist() - if not idx_train: logger.info("No rows with split=0; nothing to do.") return out - # Always use stratify if possible stratify_arr = None if label_column and label_column in out.columns: @@ -450,7 +444,6 @@ logger.info("Using stratified split for validation set") else: logger.warning("Only one label class found; cannot stratify") - if validation_size <= 0: logger.info("validation_size <= 0; keeping all as train.") return out @@ -458,7 +451,6 @@ logger.info("validation_size >= 1; moving all train → validation.") out.loc[idx_train, split_column] = 1 return out - # Always try stratified split first try: train_idx, val_idx = train_test_split( @@ -476,7 +468,6 @@ random_state=random_state, stratify=None, ) - out.loc[train_idx, split_column] = 0 out.loc[val_idx, split_column] = 1 out[split_column] = out[split_column].astype(int) @@ -492,31 +483,24 @@ ) -> pd.DataFrame: """Create a stratified random split when no split column exists.""" out = df.copy() - # initialize split column out[split_column] = 0 - if not label_column or label_column not in out.columns: logger.warning("No label column found; using random split without stratification") # fall back to simple random assignment indices = out.index.tolist() np.random.seed(random_state) np.random.shuffle(indices) - n_total = len(indices) n_train = int(n_total * split_probabilities[0]) n_val = int(n_total * split_probabilities[1]) - out.loc[indices[:n_train], split_column] = 0 out.loc[indices[n_train:n_train + n_val], split_column] = 1 out.loc[indices[n_train + n_val:], split_column] = 2 - return out.astype({split_column: int}) - # check if stratification is possible label_counts = out[label_column].value_counts() min_samples_per_class = label_counts.min() - # ensure we have enough samples for stratification: # Each class must have at least as many samples as the number of splits, # so that each split can receive at least one sample per class. 
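Aside: the split helpers in this part of the diff fall back to a simple shuffled assignment when stratification is impossible, and otherwise perform the two-stage stratified split that the following hunks implement (test set first, then train/validation). A condensed sketch of that stratified branch, assuming a `label` column, a `split` column coded 0/1/2, and the default 0.7/0.1/0.2 proportions; very small classes will raise, mirroring the guard in the original code:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

def stratified_split(df: pd.DataFrame, label_col: str = "label",
                     probs=(0.7, 0.1, 0.2), seed: int = 42) -> pd.DataFrame:
    out = df.copy()
    out["split"] = 0
    # Stage 1: carve out the test set, stratified on the label.
    train_val_idx, test_idx = train_test_split(
        out.index.tolist(), test_size=probs[2],
        random_state=seed, stratify=out[label_col])
    # Stage 2: split the remainder into train (0) and validation (1).
    val_frac = probs[1] / (probs[0] + probs[1])
    _, val_idx = train_test_split(
        train_val_idx, test_size=val_frac,
        random_state=seed, stratify=out.loc[train_val_idx, label_col])
    out.loc[val_idx, "split"] = 1
    out.loc[test_idx, "split"] = 2
    return out.astype({"split": int})
```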
@@ -529,19 +513,14 @@ indices = out.index.tolist() np.random.seed(random_state) np.random.shuffle(indices) - n_total = len(indices) n_train = int(n_total * split_probabilities[0]) n_val = int(n_total * split_probabilities[1]) - out.loc[indices[:n_train], split_column] = 0 out.loc[indices[n_train:n_train + n_val], split_column] = 1 out.loc[indices[n_train + n_val:], split_column] = 2 - return out.astype({split_column: int}) - logger.info("Using stratified random split for train/validation/test sets") - # first split: separate test set train_val_idx, test_idx = train_test_split( out.index.tolist(), @@ -549,7 +528,6 @@ random_state=random_state, stratify=out[label_column], ) - # second split: separate training and validation from remaining data val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) train_idx, val_idx = train_test_split( @@ -558,21 +536,17 @@ random_state=random_state, stratify=out.loc[train_val_idx, label_column], ) - # assign split values out.loc[train_idx, split_column] = 0 out.loc[val_idx, split_column] = 1 out.loc[test_idx, split_column] = 2 - logger.info("Successfully applied stratified random split") logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") - return out.astype({split_column: int}) class Backend(Protocol): """Interface for a machine learning backend.""" - def prepare_config( self, config_params: Dict[str, Any], @@ -604,14 +578,12 @@ class LudwigDirectBackend: """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" - def prepare_config( self, config_params: Dict[str, Any], split_config: Dict[str, Any], ) -> str: logger.info("LudwigDirectBackend: Preparing YAML configuration.") - model_name = config_params.get("model_name", "resnet18") use_pretrained = config_params.get("use_pretrained", False) fine_tune = config_params.get("fine_tune", False) @@ -634,9 +606,7 @@ } else: encoder_config = {"type": raw_encoder} - batch_size_cfg = batch_size or "auto" - label_column_path = config_params.get("label_column_data_path") label_series = None if label_column_path is not None and Path(label_column_path).exists(): @@ -644,7 +614,6 @@ label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] except Exception as e: logger.warning(f"Could not read label column for task detection: {e}") - if ( label_series is not None and ptypes.is_numeric_dtype(label_series.dtype) @@ -653,9 +622,7 @@ task_type = "regression" else: task_type = "classification" - config_params["task_type"] = task_type - image_feat: Dict[str, Any] = { "name": IMAGE_PATH_COLUMN_NAME, "type": "image", @@ -663,7 +630,6 @@ } if config_params.get("augmentation") is not None: image_feat["augmentation"] = config_params["augmentation"] - if task_type == "regression": output_feat = { "name": LABEL_COLUMN_NAME, @@ -679,15 +645,15 @@ }, } val_metric = config_params.get("validation_metric", "mean_squared_error") - else: num_unique_labels = ( label_series.nunique() if label_series is not None else 2 ) output_type = "binary" if num_unique_labels == 2 else "category" output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} + if output_type == "binary" and config_params.get("threshold") is not None: + output_feat["threshold"] = float(config_params["threshold"]) val_metric = None - conf: Dict[str, Any] = { "model_type": "ecd", "input_features": [image_feat], @@ -707,7 +673,6 @@ "in_memory": False, }, } - logger.debug("LudwigDirectBackend: Config dict built.") try: yaml_str = yaml.dump(conf, 
sort_keys=False, indent=2) @@ -729,7 +694,6 @@ ) -> None: """Invoke Ludwig's internal experiment_cli function to run the experiment.""" logger.info("LudwigDirectBackend: Starting experiment execution.") - try: from ludwig.experiment import experiment_cli except ImportError as e: @@ -738,9 +702,7 @@ exc_info=True, ) raise RuntimeError("Ludwig import failed.") from e - output_dir.mkdir(parents=True, exist_ok=True) - try: experiment_cli( dataset=str(dataset_path), @@ -771,16 +733,13 @@ output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, ) - if not exp_dirs: logger.warning(f"No experiment run directories found in {output_dir}") return None - progress_file = exp_dirs[-1] / "model" / "training_progress.json" if not progress_file.exists(): logger.warning(f"No training_progress.json found in {progress_file}") return None - try: with progress_file.open("r", encoding="utf-8") as f: data = json.load(f) @@ -816,7 +775,6 @@ def generate_plots(self, output_dir: Path) -> None: """Generate all registered Ludwig visualizations for the latest experiment run.""" logger.info("Generating all Ludwig visualizations…") - test_plots = { "compare_performance", "compare_classifiers_performance_from_prob", @@ -840,7 +798,6 @@ "learning_curves", "compare_classifiers_performance_subset", } - output_dir = Path(output_dir) exp_dirs = sorted( output_dir.glob("experiment_run*"), @@ -850,7 +807,6 @@ logger.warning(f"No experiment run dirs found in {output_dir}") return exp_dir = exp_dirs[-1] - viz_dir = exp_dir / "visualizations" viz_dir.mkdir(exist_ok=True) train_viz = viz_dir / "train" @@ -865,7 +821,6 @@ test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) - dataset_path = None split_file = None desc = exp_dir / DESCRIPTION_FILE_NAME @@ -874,7 +829,6 @@ cfg = json.load(f) dataset_path = _check(Path(cfg.get("dataset", ""))) split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) - output_feature = "" if desc.exists(): try: @@ -885,7 +839,6 @@ with open(test_stats, "r") as f: stats = json.load(f) output_feature = next(iter(stats.keys()), "") - viz_registry = get_visualizations_registry() for viz_name, viz_func in viz_registry.items(): if viz_name in train_plots: @@ -894,7 +847,6 @@ viz_dir_plot = test_viz else: continue - try: viz_func( training_statistics=[training_stats] if training_stats else [], @@ -914,7 +866,6 @@ logger.info(f"✔ Generated {viz_name}") except Exception as e: logger.warning(f"✘ Skipped {viz_name}: {e}") - logger.info(f"All visualizations written to {viz_dir}") def generate_html_report( @@ -930,7 +881,6 @@ report_path = cwd / report_name output_dir = Path(output_dir) output_type = None - exp_dirs = sorted( output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, @@ -938,14 +888,11 @@ if not exp_dirs: raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") exp_dir = exp_dirs[-1] - base_viz_dir = exp_dir / "visualizations" train_viz_dir = base_viz_dir / "train" test_viz_dir = base_viz_dir / "test" - html = get_html_template() html += f"<h1>{title}</h1>" - metrics_html = "" train_val_metrics_html = "" test_metrics_html = "" @@ -971,7 +918,6 @@ logger.warning( f"Could not load stats for HTML report: {type(e).__name__}: {e}" ) - config_html = "" training_progress = self.get_training_process(output_dir) try: @@ -986,93 +932,77 @@ ) -> str: if not dir_path.exists(): return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" 
- + # collect every PNG imgs = list(dir_path.glob("*.png")) + # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- + imgs = [ + img for img in imgs + if not ( + img.name == "confusion_matrix.png" + or img.name.startswith("confusion_matrix__label_top") + or img.name == "roc_curves.png" + ) + ] if not imgs: return f"<h2>{title}</h2><p><em>No plots found.</em></p>" - - if title == "Test Visualizations" and output_type == "binary": + if output_type == "binary": order = [ - "confusion_matrix__label_top2.png", "roc_curves_from_prediction_statistics.png", "compare_performance_label.png", "confusion_matrix_entropy__label_top2.png", + # ...you can tweak ordering as needed ] img_names = {img.name: img for img in imgs} - ordered_imgs = [ - img_names[fname] for fname in order if fname in img_names - ] - remaining = sorted( - [ - img - for img in imgs - if img.name not in order and img.name != "roc_curves.png" - ] - ) - imgs = ordered_imgs + remaining - - elif title == "Test Visualizations" and output_type == "category": + ordered = [img_names[n] for n in order if n in img_names] + others = sorted(img for img in imgs if img.name not in order) + imgs = ordered + others + elif output_type == "category": unwanted = { "compare_classifiers_multiclass_multimetric__label_best10.png", "compare_classifiers_multiclass_multimetric__label_top10.png", "compare_classifiers_multiclass_multimetric__label_worst10.png", } display_order = [ - "confusion_matrix__label_top10.png", "roc_curves.png", "compare_performance_label.png", "compare_classifiers_performance_from_prob.png", - "compare_classifiers_multiclass_multimetric__label_sorted.png", "confusion_matrix_entropy__label_top10.png", ] - img_names = {img.name: img for img in imgs if img.name not in unwanted} - ordered_imgs = [ - img_names[fname] for fname in display_order if fname in img_names - ] - remaining = sorted( - [img for img in img_names.values() if img.name not in display_order] - ) - imgs = ordered_imgs + remaining - + # filter and order + valid_imgs = [img for img in imgs if img.name not in unwanted] + img_map = {img.name: img for img in valid_imgs} + ordered = [img_map[n] for n in display_order if n in img_map] + others = sorted(img for img in valid_imgs if img.name not in display_order) + imgs = ordered + others else: - if output_type == "category": - unwanted = { - "compare_classifiers_multiclass_multimetric__label_best10.png", - "compare_classifiers_multiclass_multimetric__label_top10.png", - "compare_classifiers_multiclass_multimetric__label_worst10.png", - } - imgs = sorted([img for img in imgs if img.name not in unwanted]) - else: - imgs = sorted(imgs) - - section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" + # regression: just sort whatever's left + imgs = sorted(imgs) + # render each remaining PNG + html = "" for img in imgs: b64 = encode_image_to_base64(str(img)) - section_html += ( + img_title = img.stem.replace("_", " ").title() + html += ( + f"<h2 style='text-align: center;'>{img_title}</h2>" f'<div class="plot" style="margin-bottom:20px;text-align:center;">' - f"<h3>{img.stem.replace('_', ' ').title()}</h3>" f'<img src="data:image/png;base64,{b64}" ' f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' f"</div>" ) - section_html += "</div>" - return section_html + return html tab1_content = config_html + metrics_html - tab2_content = train_val_metrics_html + render_img_section( - "Training & Validation Visualizations", train_viz_dir + "Training and Validation Visualizations", 
train_viz_dir ) - # --- Predictions vs Ground Truth table --- preds_section = "" parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME - if parquet_path.exists(): + if output_type == "regression" and parquet_path.exists(): try: # 1) load predictions from Parquet df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) # assume the column containing your model's prediction is named "prediction" - # or contains that substring: pred_col = next( (c for c in df_preds.columns if "prediction" in c.lower()), None, @@ -1080,40 +1010,58 @@ if pred_col is None: raise ValueError("No prediction column found in Parquet output") df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) - # 2) load ground truth for the test split from prepared CSV df_all = pd.read_csv(config["label_column_data_path"]) - df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ - LABEL_COLUMN_NAME - ].reset_index(drop=True) - - # 3) concatenate side‐by‐side + df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) + # 3) concatenate side-by-side df_table = pd.concat([df_gt, df_pred], axis=1) df_table.columns = [LABEL_COLUMN_NAME, "prediction"] - # 4) render as HTML preds_html = df_table.to_html(index=False, classes="predictions-table") preds_section = ( - "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>" - "<div style='overflow-x:auto; margin-bottom:20px;'>" + "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" + "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" + preds_html + "</div>" ) except Exception as e: logger.warning(f"Could not build Predictions vs GT table: {e}") - # Test tab = Metrics + Preds table + Visualizations - - tab3_content = ( - test_metrics_html - + preds_section - + render_img_section("Test Visualizations", test_viz_dir, output_type) - ) - + tab3_content = test_metrics_html + preds_section + if output_type in ("binary", "category"): + training_stats_path = exp_dir / "training_statistics.json" + interactive_plots = build_classification_plots( + str(test_stats_path), + str(training_stats_path), + ) + for plot in interactive_plots: + # 2) inject the static "roc_curves_from_prediction_statistics.png" + if plot["title"] == "ROC-AUC": + static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" + if static_img.exists(): + b64 = encode_image_to_base64(str(static_img)) + tab3_content += ( + "<h2 style='text-align: center;'>" + "Roc Curves From Prediction Statistics" + "</h2>" + f'<div class="plot" style="margin-bottom:20px;text-align:center;">' + f'<img src="data:image/png;base64,{b64}" ' + f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' + "</div>" + ) + # always render the plotly panels exactly as before + tab3_content += ( + f"<h2 style='text-align: center;'>{plot['title']}</h2>" + + plot["html"] + ) + tab3_content += render_img_section( + "Test Visualizations", + test_viz_dir, + output_type + ) # assemble the tabs and help modal tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) modal_html = get_metrics_help_modal() html += tabbed_html + modal_html + get_html_closing() - try: with open(report_path, "w") as f: f.write(html) @@ -1121,13 +1069,11 @@ except Exception as e: logger.error(f"Failed to write HTML report: {e}") raise - return report_path class WorkflowOrchestrator: """Manages the image-classification workflow.""" - def __init__(self, args: argparse.Namespace, backend: Backend): self.args = args self.backend = backend @@ -1167,19 
+1113,16 @@ """Load CSV, update image paths, handle splits, and write prepared CSV.""" if not self.temp_dir or not self.image_extract_dir: raise RuntimeError("Temp dirs not initialized before data prep.") - try: df = pd.read_csv(self.args.csv_file) logger.info(f"Loaded CSV: {self.args.csv_file}") except Exception: logger.error("Error loading CSV file", exc_info=True) raise - required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} missing = required - set(df.columns) if missing: raise ValueError(f"Missing CSV columns: {', '.join(missing)}") - try: df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( lambda p: str((self.image_extract_dir / p).resolve()) @@ -1187,7 +1130,6 @@ except Exception: logger.error("Error updating image paths", exc_info=True) raise - if SPLIT_COLUMN_NAME in df.columns: df, split_config, split_info = self._process_fixed_split(df) else: @@ -1208,16 +1150,13 @@ f"{[int(p * 100) for p in self.args.split_probabilities]}% " f"for train/val/test with balanced label distribution." ) - final_csv = self.temp_dir / TEMP_CSV_FILENAME try: - df.to_csv(final_csv, index=False) logger.info(f"Saved prepared data to {final_csv}") except Exception: logger.error("Error saving prepared CSV", exc_info=True) raise - return final_csv, split_config, split_info def _process_fixed_split( @@ -1232,10 +1171,8 @@ ) if df[SPLIT_COLUMN_NAME].isna().any(): logger.warning("Split column contains non-numeric/missing values.") - unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) logger.info(f"Unique split values: {unique}") - if unique == {0, 2}: df = split_data_0_2( df, @@ -1256,9 +1193,7 @@ logger.info("Using fixed split as-is.") else: raise ValueError(f"Unexpected split values: {unique}") - return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info - except Exception: logger.error("Error processing fixed split", exc_info=True) raise @@ -1274,14 +1209,11 @@ """Execute the full workflow end-to-end.""" logger.info("Starting workflow...") self.args.output_dir.mkdir(parents=True, exist_ok=True) - try: self._create_temp_dirs() self._extract_images() csv_path, split_cfg, split_info = self._prepare_data() - use_pretrained = self.args.use_pretrained or self.args.fine_tune - backend_args = { "model_name": self.args.model_name, "fine_tune": self.args.fine_tune, @@ -1295,13 +1227,12 @@ "early_stop": self.args.early_stop, "label_column_data_path": csv_path, "augmentation": self.args.augmentation, + "threshold": self.args.threshold, } yaml_str = self.backend.prepare_config(backend_args, split_cfg) - config_file = self.temp_dir / TEMP_CONFIG_FILENAME config_file.write_text(yaml_str) logger.info(f"Wrote backend config: {config_file}") - self.backend.run_experiment( csv_path, config_file, @@ -1349,8 +1280,6 @@ aug_list = [] for tok in aug_string.split(","): key = tok.strip() - if not key: - continue if key not in mapping: valid = ", ".join(mapping.keys()) raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") @@ -1428,7 +1357,7 @@ parser.add_argument( "--validation-size", type=float, - default=0.1, + default=0.15, help="Fraction for validation (0.0–1.0)", ) parser.add_argument( @@ -1472,9 +1401,16 @@ "E.g. --augmentation random_horizontal_flip,random_rotate" ), ) - + parser.add_argument( + "--threshold", + type=float, + default=None, + help=( + "Decision threshold for binary classification (0.0–1.0)." + "Overrides default 0.5." 
+ ) + ) args = parser.parse_args() - if not 0.0 <= args.validation_size <= 1.0: parser.error("validation-size must be between 0.0 and 1.0") if not args.csv_file.is_file(): @@ -1487,10 +1423,8 @@ setattr(args, "augmentation", augmentation_setup) except ValueError as e: parser.error(str(e)) - backend_instance = LudwigDirectBackend() orchestrator = WorkflowOrchestrator(args, backend_instance) - exit_code = 0 try: orchestrator.run() @@ -1505,7 +1439,6 @@ if __name__ == "__main__": try: import ludwig - logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") except ImportError: logger.error( @@ -1513,5 +1446,4 @@ "('pip install ludwig[image]')" ) sys.exit(1) - main()
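The new `--threshold` flag above is parsed but, unlike `--validation-size`, not range-checked in this diff (the XML wrapper enforces `min`/`max` instead). A hedged sketch of what a parse-time check could look like if one were added; the check itself is an assumption, not part of the code above:

```python
import argparse

# Hypothetical range check for --threshold, mirroring the existing
# validation-size check in the CLI.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--threshold", type=float, default=None,
    help="Decision threshold for binary classification (0.0-1.0).",
)
args = parser.parse_args(["--threshold", "0.35"])
if args.threshold is not None and not 0.0 <= args.threshold <= 1.0:
    parser.error("threshold must be between 0.0 and 1.0")
print(args.threshold)  # 0.35
```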
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/plotly_plots.py Thu Aug 14 14:53:10 2025 +0000 @@ -0,0 +1,148 @@ +import json +from typing import Dict, List, Optional + +import numpy as np +import plotly.graph_objects as go +import plotly.io as pio + + +def build_classification_plots( + test_stats_path: str, + training_stats_path: Optional[str] = None, +) -> List[Dict[str, str]]: + """ + Read Ludwig’s test_statistics.json and build three interactive Plotly panels: + - Confusion Matrix + - ROC-AUC + - Classification Report Heatmap + + Returns a list of dicts, each with: + { + "title": <plot title>, + "html": <HTML fragment for embedding> + } + """ + # --- Load test stats --- + with open(test_stats_path, "r") as f: + test_stats = json.load(f) + label_stats = test_stats["label"] + + # common sizing + cell = 40 + n_classes = len(label_stats["confusion_matrix"]) + side_px = max(cell * n_classes + 200, 600) + common_cfg = {"displayModeBar": True, "scrollZoom": True} + + plots: List[Dict[str, str]] = [] + + # 0) Confusion Matrix + cm = np.array(label_stats["confusion_matrix"], dtype=int) + labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) + total = cm.sum() + + fig_cm = go.Figure( + go.Heatmap( + z=cm, + x=labels, + y=labels, + colorscale="Blues", + showscale=True, + colorbar=dict(title="Count"), + ) + ) + fig_cm.update_traces(xgap=2, ygap=2) + fig_cm.update_layout( + title=dict(text="Confusion Matrix", x=0.5), + xaxis_title="Predicted", + yaxis_title="Observed", + yaxis_autorange="reversed", + width=side_px, + height=side_px, + margin=dict(t=100, l=80, r=80, b=80), + ) + + # annotate counts and percentages + mval = cm.max() if cm.size else 0 + thresh = mval / 2 + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + v = cm[i, j] + pct = (v / total * 100) if total > 0 else 0 + color = "white" if v > thresh else "black" + fig_cm.add_annotation( + x=labels[j], + y=labels[i], + text=f"<b>{v}</b>", + showarrow=False, + font=dict(color=color, size=14), + xanchor="center", + yanchor="bottom", + yshift=2, + ) + fig_cm.add_annotation( + x=labels[j], + y=labels[i], + text=f"{pct:.1f}%", + showarrow=False, + font=dict(color=color, size=13), + xanchor="center", + yanchor="top", + yshift=-2, + ) + + plots.append({ + "title": "Confusion Matrix", + "html": pio.to_html( + fig_cm, + full_html=False, + include_plotlyjs="cdn", + config=common_cfg + ) + }) + + # 2) Classification Report Heatmap + pcs = label_stats.get("per_class_stats", {}) + if pcs: + classes = list(pcs.keys()) + metrics = ["precision", "recall", "f1_score"] + z, txt = [], [] + for c in classes: + row, trow = [], [] + for m in metrics: + val = pcs[c].get(m, 0) + row.append(val) + trow.append(f"{val:.2f}") + z.append(row) + txt.append(trow) + + fig_cr = go.Figure( + go.Heatmap( + z=z, + x=metrics, + y=[str(c) for c in classes], + text=txt, + texttemplate="%{text}", + colorscale="Reds", + showscale=True, + colorbar=dict(title="Value"), + ) + ) + fig_cr.update_layout( + title="Classification Report", + xaxis_title="", + yaxis_title="Class", + width=side_px, + height=side_px, + margin=dict(t=80, l=80, r=80, b=80), + ) + plots.append({ + "title": "Classification Report", + "html": pio.to_html( + fig_cr, + full_html=False, + include_plotlyjs=False, + config=common_cfg + ) + }) + + return plots
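`build_classification_plots` returns a list of `{"title", "html"}` fragments: the first panel loads plotly.js from a CDN and the later ones reuse it. A small usage sketch with hypothetical experiment paths, showing how the fragments could be stitched into a standalone page outside the Galaxy report (note that `training_statistics.json` is accepted but not used by the code shown above):

```python
from plotly_plots import build_classification_plots

# Hypothetical paths to a Ludwig experiment run's statistics files.
panels = build_classification_plots(
    "experiment_run/test_statistics.json",
    "experiment_run/training_statistics.json",
)
with open("classification_plots.html", "w") as f:
    f.write("<html><body>")
    for panel in panels:
        f.write(f"<h2>{panel['title']}</h2>" + panel["html"])
    f.write("</body></html>")
```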
--- a/utils.py Fri Aug 08 13:06:28 2025 +0000 +++ b/utils.py Thu Aug 14 14:53:10 2025 +0000 @@ -8,6 +8,8 @@ <head> <meta charset="UTF-8"> <title>Galaxy-Ludwig Report</title> + + <!-- your existing styles --> <style> body { font-family: Arial, sans-serif; @@ -32,29 +34,21 @@ color: #4CAF50; padding-bottom: 5px; } + /* baseline table setup */ table { border-collapse: collapse; margin: 20px 0; width: 100%; - table-layout: fixed; /* Enforces consistent column widths */ + table-layout: fixed; } table, th, td { border: 1px solid #ddd; } th, td { padding: 8px; - text-align: center; /* Center-align text */ - vertical-align: middle; /* Center-align content vertically */ - word-wrap: break-word; /* Break long words to avoid overflow */ - } - th:first-child, td:first-child { - width: 5%; /* Smaller width for the first column */ - } - th:nth-child(2), td:nth-child(2) { - width: 50%; /* Wider for the metric/description column */ - } - th:last-child, td:last-child { - width: 25%; /* Value column gets remaining space */ + text-align: center; + vertical-align: middle; + word-wrap: break-word; } th { background-color: #4CAF50; @@ -68,7 +62,105 @@ max-width: 100%; height: auto; } + + /* ------------------- + SORTABLE COLUMNS + ------------------- */ + table.performance-summary th.sortable { + cursor: pointer; + position: relative; + user-select: none; + } + /* hide arrows by default */ + table.performance-summary th.sortable::after { + content: ''; + position: absolute; + right: 12px; + top: 50%; + transform: translateY(-50%); + font-size: 0.8em; + color: #666; + } + /* three states */ + table.performance-summary th.sortable.sorted-none::after { + content: '⇅'; + } + table.performance-summary th.sortable.sorted-asc::after { + content: '↑'; + } + table.performance-summary th.sortable.sorted-desc::after { + content: '↓'; + } </style> + + <!-- sorting script --> + <script> + document.addEventListener('DOMContentLoaded', () => { + // 1) record each row's original position + document.querySelectorAll('table.performance-summary tbody').forEach(tbody => { + Array.from(tbody.rows).forEach((row, i) => { + row.dataset.originalOrder = i; + }); + }); + + const getText = cell => cell.innerText.trim(); + const comparer = (idx, asc) => (a, b) => { + const v1 = getText(a.children[idx]); + const v2 = getText(b.children[idx]); + const n1 = parseFloat(v1), n2 = parseFloat(v2); + if (!isNaN(n1) && !isNaN(n2)) { + return asc ? n1 - n2 : n2 - n1; + } + return asc + ? v1.localeCompare(v2) + : v2.localeCompare(v1); + }; + + document + .querySelectorAll('table.performance-summary th.sortable') + .forEach(th => { + // initialize to "none" state + th.classList.add('sorted-none'); + th.addEventListener('click', () => { + const table = th.closest('table'); + const allTh = table.querySelectorAll('th.sortable'); + + // 1) determine current state BEFORE clearing classes + let curr = th.classList.contains('sorted-asc') + ? 'asc' + : th.classList.contains('sorted-desc') + ? 'desc' + : 'none'; + // 2) cycle to next state + let next = curr === 'none' + ? 'asc' + : curr === 'asc' + ? 
'desc' + : 'none'; + + // 3) clear all sort markers + allTh.forEach(h => + h.classList.remove('sorted-none','sorted-asc','sorted-desc') + ); + // 4) apply the new marker + th.classList.add(`sorted-${next}`); + + // 5) sort or restore original order + const tbody = table.querySelector('tbody'); + let rows = Array.from(tbody.rows); + if (next === 'none') { + rows.sort((a, b) => + a.dataset.originalOrder - b.dataset.originalOrder + ); + } else { + const idx = Array.from(th.parentNode.children).indexOf(th); + rows.sort(comparer(idx, next === 'asc')); + } + rows.forEach(r => tbody.appendChild(r)); + }); + }); + }); + </script> </head> <body> <div class="container"> @@ -203,7 +295,7 @@ </style> <div class="tabs"> - <div class="tab active" onclick="showTab('metrics')">Config & Results Summary</div> + <div class="tab active" onclick="showTab('metrics')">Config and Results Summary</div> <div class="tab" onclick="showTab('trainval')">Train/Validation Results</div> <div class="tab" onclick="showTab('test')">Test Results</div> <!-- always-visible help button --> @@ -232,122 +324,193 @@ def get_metrics_help_modal() -> str: - modal_html = """ -<div id="metricsHelpModal" class="modal"> - <div class="modal-content"> - <span class="close">×</span> - <h2>Model Evaluation Metrics — Help Guide</h2> - <div class="metrics-guide"> - <h3>1) General Metrics</h3> - <p><strong>Loss:</strong> Measures the difference between predicted and actual values. Lower is better. Often used for optimization during training.</p> - <p><strong>Accuracy:</strong> Proportion of correct predictions among all predictions. Simple but can be misleading for imbalanced datasets.</p> - <p><strong>Micro Accuracy:</strong> Calculates accuracy by summing up all individual true positives and true negatives across all classes, making it suitable for multiclass or multilabel problems.</p> - <p><strong>Token Accuracy:</strong> Measures how often the predicted tokens (e.g., in sequences) match the true tokens. Useful in sequence prediction tasks like NLP.</p> - <h3>2) Precision, Recall & Specificity</h3> - <p><strong>Precision:</strong> Out of all positive predictions, how many were correct. Precision = TP / (TP + FP). Helps when false positives are costly.</p> - <p><strong>Recall (Sensitivity):</strong> Out of all actual positives, how many were predicted correctly. Recall = TP / (TP + FN). Important when missing positives is risky.</p> - <p><strong>Specificity:</strong> True negative rate. Measures how well the model identifies negatives. Specificity = TN / (TN + FP). Useful in medical testing to avoid false alarms.</p> - <h3>3) Macro, Micro, and Weighted Averages</h3> - <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric across all classes, treating each class equally, regardless of class frequency. Best when class sizes are balanced.</p> - <p><strong>Micro Precision / Recall / F1:</strong> Aggregates TP, FP, FN across all classes before computing the metric. Gives a global view and is ideal for class-imbalanced problems.</p> - <p><strong>Weighted Precision / Recall / F1:</strong> Averages each metric across classes, weighted by the number of true instances per class. Balances importance of classes based on frequency.</p> - <h3>4) Average Precision (PR-AUC Variants)</h3> - <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged across all classes equally. Useful for balanced multi-class problems.</p> - <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC using all instances. 
Best for imbalanced data or multi-label classification.</p> - <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged across individual samples (not classes). Ideal for multi-label problems where each sample can belong to multiple classes.</p> - <h3>5) ROC-AUC Variants</h3> - <p><strong>ROC-AUC:</strong> Measures model's ability to distinguish between classes. AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p> - <p><strong>Macro ROC-AUC:</strong> Averages the AUC across all classes equally. Suitable when classes are balanced and of equal importance.</p> - <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions across all classes. Useful in multiclass or multilabel settings with imbalance.</p> - <h3>6) Ranking Metrics</h3> - <p><strong>Hits at K:</strong> Measures whether the true label is among the top-K predictions. Common in recommendation systems and retrieval tasks.</p> - <h3>7) Confusion Matrix Stats (Per Class)</h3> - <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions for positives and negatives respectively.</p> - <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions — false alarms and missed detections.</p> - <h3>8) Other Useful Metrics</h3> - <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and actual values adjusted for chance. Useful for multiclass classification with imbalanced labels.</p> - <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure of prediction quality that takes into account TP, TN, FP, and FN. Particularly effective for imbalanced datasets.</p> - <h3>9) Metric Recommendations</h3> - <ul> - <li>Use <strong>Accuracy + F1</strong> for balanced data.</li> - <li>Use <strong>Precision, Recall, ROC-AUC</strong> for imbalanced datasets.</li> - <li>Use <strong>Average Precision Micro</strong> for multilabel or class-imbalanced problems.</li> - <li>Use <strong>Macro scores</strong> when all classes should be treated equally.</li> - <li>Use <strong>Weighted scores</strong> when class imbalance should be accounted for without ignoring small classes.</li> - <li>Use <strong>Confusion Matrix stats</strong> to analyze class-wise performance.</li> - <li>Use <strong>Hits at K</strong> for recommendation or ranking-based tasks.</li> - </ul> - </div> - </div> -</div> -""" - modal_css = """ -<style> -.modal { - display: none; - position: fixed; - z-index: 1; - left: 0; - top: 0; - width: 100%; - height: 100%; - overflow: auto; - background-color: rgba(0,0,0,0.4); -} -.modal-content { - background-color: #fefefe; - margin: 15% auto; - padding: 20px; - border: 1px solid #888; - width: 80%; - max-width: 800px; -} -.close { - color: #aaa; - float: right; - font-size: 28px; - font-weight: bold; -} -.close:hover, -.close:focus { - color: black; - text-decoration: none; - cursor: pointer; -} -.metrics-guide h3 { - margin-top: 20px; -} -.metrics-guide p { - margin: 5px 0; -} -.metrics-guide ul { - margin: 10px 0; - padding-left: 20px; -} -</style> -""" - modal_js = """ -<script> -document.addEventListener("DOMContentLoaded", function() { - var modal = document.getElementById("metricsHelpModal"); - var openBtn = document.getElementById("openMetricsHelp"); - var span = document.getElementsByClassName("close")[0]; - if (openBtn && modal) { - openBtn.onclick = function() { - modal.style.display = "block"; - }; - } - if (span && modal) { - span.onclick = function() { - modal.style.display = "none"; - }; - } - window.onclick = 
function(event) { - if (event.target == modal) { - modal.style.display = "none"; - } - } -}); -</script> -""" + modal_html = ( + '<div id="metricsHelpModal" class="modal">' + ' <div class="modal-content">' + ' <span class="close">×</span>' + ' <h2>Model Evaluation Metrics — Help Guide</h2>' + ' <div class="metrics-guide">' + ' <h3>1) General Metrics (Regression and Classification)</h3>' + ' <p><strong>Loss (Regression & Classification):</strong> ' + 'Measures the difference between predicted and actual values, ' + 'optimized during training. Lower is better. ' + 'For regression, this is often Mean Squared Error (MSE) or ' + 'Mean Absolute Error (MAE). For classification, it’s typically ' + 'cross-entropy or log loss.</p>' + ' <h3>2) Regression Metrics</h3>' + ' <p><strong>Mean Absolute Error (MAE):</strong> ' + 'Average of absolute differences between predicted and actual values, ' + 'in the same units as the target. Use for interpretable error measurement ' + 'when all errors are equally important. Less sensitive to outliers than MSE.</p>' + ' <p><strong>Mean Squared Error (MSE):</strong> ' + 'Average of squared differences between predicted and actual values. ' + 'Penalizes larger errors more heavily, useful when large deviations are critical. ' + 'Often used as the loss function in regression.</p>' + ' <p><strong>Root Mean Squared Error (RMSE):</strong> ' + 'Square root of MSE, in the same units as the target. ' + 'Balances interpretability and sensitivity to large errors. ' + 'Widely used for regression evaluation.</p>' + ' <p><strong>Mean Absolute Percentage Error (MAPE):</strong> ' + 'Average absolute error as a percentage of actual values. ' + 'Scale-independent, ideal for comparing relative errors across datasets. ' + 'Avoid when actual values are near zero.</p>' + ' <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> ' + 'Square root of mean squared percentage error. Scale-independent, ' + 'penalizes larger relative errors more than MAPE. Use for forecasting ' + 'or when relative accuracy matters.</p>' + ' <p><strong>R² Score:</strong> Proportion of variance in the target ' + 'explained by the model. Ranges from negative infinity to 1 (perfect prediction). ' + 'Use to assess model fit; negative values indicate poor performance ' + 'compared to predicting the mean.</p>' + ' <h3>3) Classification Metrics</h3>' + ' <p><strong>Accuracy:</strong> Proportion of correct predictions ' + 'among all predictions. Simple but misleading for imbalanced datasets, ' + 'where high accuracy may hide poor performance on minority classes.</p>' + ' <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives ' + 'across all classes before computing accuracy. Suitable for multiclass or ' + 'multilabel problems with imbalanced data.</p>' + ' <p><strong>Token Accuracy:</strong> Measures how often predicted tokens ' + '(e.g., in sequences) match true tokens. Common in NLP tasks like text generation ' + 'or token classification.</p>' + ' <p><strong>Precision:</strong> Proportion of positive predictions that are ' + 'correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>' + ' <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives ' + 'correctly predicted (TP / (TP + FN)). Use when missing positives is risky, ' + 'e.g., disease detection.</p>' + ' <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). ' + 'Measures ability to identify negatives. 
Useful in medical testing to avoid ' + 'false alarms.</p>' + ' <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>' + ' <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric ' + 'across all classes, treating each equally. Best for balanced datasets where ' + 'all classes are equally important.</p>' + ' <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, ' + 'false positives, and false negatives across all classes before computing. ' + 'Ideal for imbalanced or multilabel classification.</p>' + ' <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics ' + 'across classes, weighted by the number of true instances per class. Balances ' + 'class importance based on frequency.</p>' + ' <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>' + ' <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged ' + 'equally across classes. Use for balanced multiclass problems.</p>' + ' <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC ' + 'using all instances. Best for imbalanced or multilabel classification.</p>' + ' <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged ' + 'across individual samples. Ideal for multilabel tasks where samples have multiple ' + 'labels.</p>' + ' <h3>6) Classification: ROC-AUC Variants</h3>' + ' <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. ' + 'AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>' + ' <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. ' + 'Suitable for balanced multiclass problems.</p>' + ' <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions ' + 'across all classes. Useful for imbalanced or multilabel settings.</p>' + ' <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>' + ' <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions ' + 'for positives and negatives, respectively.</p>' + ' <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions ' + '— false alarms and missed detections.</p>' + ' <h3>8) Classification: Ranking Metrics</h3>' + ' <p><strong>Hits at K:</strong> Measures whether the true label is among the ' + 'top-K predictions. Common in recommendation systems and retrieval tasks.</p>' + ' <h3>9) Other Metrics (Classification)</h3>' + ' <p><strong>Cohen\'s Kappa:</strong> Measures agreement between predicted and ' + 'actual labels, adjusted for chance. Useful for multiclass classification with ' + 'imbalanced data.</p>' + ' <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure ' + 'using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>' + ' <h3>10) Metric Recommendations</h3>' + ' <ul>' + ' <li><strong>Regression:</strong> Use <strong>RMSE</strong> or ' + '<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative ' + 'errors, and <strong>R²</strong> to assess model fit. 
Use <strong>MSE</strong> or ' + '<strong>RMSPE</strong> when large errors are critical.</li>' + ' <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> ' + 'and <strong>F1</strong> for overall performance.</li>' + ' <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, ' + '<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class ' + 'performance.</li>' + ' <li><strong>Multilabel or Imbalanced Classification:</strong> Use ' + '<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>' + ' <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> ' + 'or <strong>Macro ROC-AUC</strong>.</li>' + ' <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> ' + 'to account for class imbalance.</li>' + ' <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>' + ' <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> ' + 'for class-wise performance in classification.</li>' + ' </ul>' + ' </div>' + ' </div>' + '</div>' + ) + modal_css = ( + "<style>" + ".modal {" + " display: none;" + " position: fixed;" + " z-index: 1;" + " left: 0;" + " top: 0;" + " width: 100%;" + " height: 100%;" + " overflow: auto;" + " background-color: rgba(0,0,0,0.4);" + "}" + ".modal-content {" + " background-color: #fefefe;" + " margin: 15% auto;" + " padding: 20px;" + " border: 1px solid #888;" + " width: 80%;" + " max-width: 800px;" + "}" + ".close {" + " color: #aaa;" + " float: right;" + " font-size: 28px;" + " font-weight: bold;" + "}" + ".close:hover," + ".close:focus {" + " color: black;" + " text-decoration: none;" + " cursor: pointer;" + "}" + ".metrics-guide h3 {" + " margin-top: 20px;" + "}" + ".metrics-guide p {" + " margin: 5px 0;" + "}" + ".metrics-guide ul {" + " margin: 10px 0;" + " padding-left: 20px;" + "}" + "</style>" + ) + modal_js = ( + "<script>" + 'document.addEventListener("DOMContentLoaded", function() {' + ' var modal = document.getElementById("metricsHelpModal");' + ' var openBtn = document.getElementById("openMetricsHelp");' + ' var span = document.getElementsByClassName("close")[0];' + " if (openBtn && modal) {" + " openBtn.onclick = function() {" + " modal.style.display = \"block\";" + " };" + " }" + " if (span && modal) {" + " span.onclick = function() {" + " modal.style.display = \"none\";" + " };" + " }" + " window.onclick = function(event) {" + " if (event.target == modal) {" + " modal.style.display = \"none\";" + " }" + " }" + "});" + "</script>" + ) return modal_css + modal_html + modal_js
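For orientation, the helpers touched in utils.py are composed by image_learner_cli.py roughly as follows. This is a simplified sketch with placeholder tab content, assuming the helper signatures used in the CLI above (template string, three tab fragments, modal, closing markup); the real report builder fills the tabs with the config table, metric summaries, and plots:

```python
from utils import (
    build_tabbed_html,
    get_html_closing,
    get_html_template,
    get_metrics_help_modal,
)

# Assemble a minimal report: template + title + three tabs + help modal.
html = get_html_template()
html += "<h1>Image Learner Report</h1>"
html += build_tabbed_html(
    "<p>Config and results summary…</p>",
    "<p>Train/Validation results…</p>",
    "<p>Test results…</p>",
)
html += get_metrics_help_modal() + get_html_closing()
with open("report.html", "w") as f:
    f.write(html)
```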