Repository: goeckslab / image_learner
diff image_learner_cli.py @ changeset 8:85e6f4b2ad18 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
author:   goeckslab
date:     Thu, 14 Aug 2025 14:53:10 +0000
parents:  801a8b6973fb
children: (none)
--- a/image_learner_cli.py Fri Aug 08 13:06:28 2025 +0000 +++ b/image_learner_cli.py Thu Aug 14 14:53:10 2025 +0000 @@ -31,6 +31,7 @@ ) from ludwig.utils.data_utils import get_split_path from ludwig.visualize import get_visualizations_registry +from plotly_plots import build_classification_plots from sklearn.model_selection import train_test_split from utils import ( build_tabbed_html, @@ -52,6 +53,7 @@ config: dict, split_info: Optional[str] = None, training_progress: dict = None, + output_type: Optional[str] = None, ) -> str: display_keys = [ "task_type", @@ -63,114 +65,119 @@ "learning_rate", "random_seed", "early_stop", + "threshold", ] - rows = [] - for key in display_keys: - val = config.get(key, "N/A") - if key == "task_type": - val = val.title() if isinstance(val, str) else val - if key == "batch_size": - if val is not None: - val = int(val) - else: - if training_progress: - val = "Auto-selected batch size by Ludwig:<br>" - resolved_val = training_progress.get("batch_size") - val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" + val = config.get(key, None) + if key == "threshold": + if output_type != "binary": + continue + val = val if val is not None else 0.5 + val_str = f"{val:.2f}" + if val == 0.5: + val_str += " (default)" + else: + if key == "task_type": + val_str = val.title() if isinstance(val, str) else "N/A" + elif key == "batch_size": + if val is not None: + val_str = int(val) + else: + if training_progress: + resolved_val = training_progress.get("batch_size") + val_str = ( + "Auto-selected batch size by Ludwig:<br>" + f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" + ) + else: + val_str = "auto" + elif key == "learning_rate": + if val is not None and val != "auto": + val_str = f"{val:.6f}" else: - val = "auto" - if key == "learning_rate": - resolved_val = None - if val is None or val == "auto": - if training_progress: - resolved_val = training_progress.get("learning_rate") - val = ( - "Auto-selected learning rate by Ludwig:<br>" - f"<span style='font-size: 0.85em;'>" - f"{resolved_val if resolved_val else val}</span><br>" - "<span style='font-size: 0.85em;'>" - "Based on model architecture and training setup " - "(e.g., fine-tuning).<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." - "</span>" - ) + if training_progress: + resolved_val = training_progress.get("learning_rate") + val_str = ( + "Auto-selected learning rate by Ludwig:<br>" + f"<span style='font-size: 0.85em;'>" + f"{resolved_val if resolved_val else 'auto'}</span><br>" + "<span style='font-size: 0.85em;'>" + "Based on model architecture and training setup " + "(e.g., fine-tuning).<br>" + "</span>" + ) + else: + val_str = ( + "Auto-selected by Ludwig<br>" + "<span style='font-size: 0.85em;'>" + "Automatically tuned based on architecture and dataset.<br>" + "See <a href='https://ludwig.ai/latest/configuration/trainer/" + "#trainer-parameters' target='_blank'>" + "Ludwig Trainer Parameters</a> for details." + "</span>" + ) + elif key == "epochs": + if val is None: + val_str = "N/A" else: - val = ( - "Auto-selected by Ludwig<br>" - "<span style='font-size: 0.85em;'>" - "Automatically tuned based on architecture and dataset.<br>" - "See <a href='https://ludwig.ai/latest/configuration/trainer/" - "#trainer-parameters' target='_blank'>" - "Ludwig Trainer Parameters</a> for details." 
- "</span>" - ) + if ( + training_progress + and "epoch" in training_progress + and val > training_progress["epoch"] + ): + val_str = ( + f"Because of early stopping: the training " + f"stopped at epoch {training_progress['epoch']}" + ) + else: + val_str = val else: - val = f"{val:.6f}" - if key == "epochs": - if ( - training_progress - and "epoch" in training_progress - and val > training_progress["epoch"] - ): - val = ( - f"Because of early stopping: the training " - f"stopped at epoch {training_progress['epoch']}" - ) - - if val is None: - continue + val_str = val if val is not None else "N/A" + if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential + continue rows.append( f"<tr>" f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" f"{key.replace('_', ' ').title()}</td>" f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" - f"{val}</td>" + f"{val_str}</td>" f"</tr>" ) - aug_cfg = config.get("augmentation") if aug_cfg: types = [str(a.get("type", "")) for a in aug_cfg] aug_val = ", ".join(types) rows.append( - "<tr>" - "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" - "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" - f"{aug_val}</td>" - "</tr>" + f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" ) - if split_info: rows.append( - f"<tr>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" - f"Data Split</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" - f"{split_info}</td>" - f"</tr>" + f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" ) - - return ( - "<h2 style='text-align: center;'>Training Setup</h2>" - "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" - "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>" - "Parameter</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>" - "Value</th>" - "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" - "<p style='text-align: center; font-size: 0.9em;'>" - "Model trained using Ludwig.<br>" - "If want to learn more about Ludwig default settings," - "please check their <a href='https://ludwig.ai' target='_blank'>" - "website(ludwig.ai)</a>." - "</p><hr>" - ) + html = f""" + <h2 style="text-align: center;">Model and Training Summary</h2> + <div style="display: flex; justify-content: center;"> + <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> + <thead><tr> + <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> + <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> + </tr></thead> + <tbody> + {''.join(rows)} + </tbody> + </table> + </div><br> + <p style="text-align: center; font-size: 0.9em;"> + Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. 
+ <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer"> + Ludwig documentation provides detailed information about default model and training parameters + </a> + </p><hr> + """ + return html def detect_output_type(test_stats): @@ -244,7 +251,6 @@ "roc_auc": get_last_value(label_stats, "roc_auc"), "hits_at_k": get_last_value(label_stats, "hits_at_k"), } - # Test metrics: dynamic extraction according to exclusions test_label_stats = test_stats.get("label", {}) if not test_label_stats: @@ -252,13 +258,11 @@ else: combined_stats = test_stats.get("combined", {}) overall_stats = test_label_stats.get("overall_stats", {}) - # Define exclusions if output_type == "binary": exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} else: exclude = {"per_class_stats", "confusion_matrix"} - # 1. Get all scalar test_label_stats not excluded test_metrics = {} for k, v in test_label_stats.items(): @@ -268,17 +272,13 @@ continue if isinstance(v, (int, float, str, bool)): test_metrics[k] = v - # 2. Add overall_stats (flattened) for k, v in overall_stats.items(): test_metrics[k] = v - # 3. Optionally include combined/loss if present and not already if "loss" in combined_stats and "loss" not in test_metrics: test_metrics["loss"] = combined_stats["loss"] - metrics["test"] = test_metrics - return metrics @@ -291,6 +291,11 @@ ) +# ----------------------------------------- +# 2) MODEL PERFORMANCE (Train/Val/Test) TABLE +# ----------------------------------------- + + def format_stats_table_html(train_stats: dict, test_stats: dict) -> str: """Formats a combined HTML table for training, validation, and test metrics.""" output_type = detect_output_type(test_stats) @@ -310,35 +315,33 @@ te = all_metrics["test"].get(metric_key) if all(x is not None for x in [t, v, te]): rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) - if not rows: return "<table><tr><td>No metric values found.</td></tr></table>" - html = ( "<h2 style='text-align: center;'>Model Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; table-layout: auto;'>" + "<table class='performance-summary' style='border-collapse: collapse;'>" "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " - "white-space: nowrap;'>Metric</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Train</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Validation</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Test</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" "</tr></thead><tbody>" ) for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;", + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" ) html += "</tbody></table></div><br>" return html +# ------------------------------------------- +# 
3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE +# ------------------------------------------- + + def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: """Formats an HTML table for training and validation metrics.""" output_type = detect_output_type(test_stats) @@ -354,33 +357,32 @@ v = all_metrics["validation"].get(metric_key) if t is not None and v is not None: rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) - if not rows: return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" - html = ( "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; table-layout: auto;'>" + "<table class='performance-summary' style='border-collapse: collapse;'>" "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " - "white-space: nowrap;'>Metric</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Train</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Validation</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" "</tr></thead><tbody>" ) for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;", + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" ) html += "</tbody></table></div><br>" return html +# ----------------------------------------- +# 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE +# ----------------------------------------- + + def format_test_merged_stats_table_html( test_metrics: Dict[str, Optional[float]], ) -> str: @@ -391,26 +393,21 @@ value = test_metrics[key] if value is not None: rows.append([display_name, f"{value:.4f}"]) - if not rows: return "<table><tr><td>No test metric values found.</td></tr></table>" - html = ( "<h2 style='text-align: center;'>Test Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" - "<table style='border-collapse: collapse; table-layout: auto;'>" + "<table class='performance-summary' style='border-collapse: collapse;'>" "<thead><tr>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " - "white-space: nowrap;'>Metric</th>" - "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;'>Test</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" + "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" "</tr></thead><tbody>" ) for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; " - "white-space: nowrap;", + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" ) html += "</tbody></table></div><br>" return html @@ -426,13 +423,10 @@ """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" out = df.copy() out[split_column] = 
pd.to_numeric(out[split_column], errors="coerce").astype(int) - idx_train = out.index[out[split_column] == 0].tolist() - if not idx_train: logger.info("No rows with split=0; nothing to do.") return out - # Always use stratify if possible stratify_arr = None if label_column and label_column in out.columns: @@ -450,7 +444,6 @@ logger.info("Using stratified split for validation set") else: logger.warning("Only one label class found; cannot stratify") - if validation_size <= 0: logger.info("validation_size <= 0; keeping all as train.") return out @@ -458,7 +451,6 @@ logger.info("validation_size >= 1; moving all train → validation.") out.loc[idx_train, split_column] = 1 return out - # Always try stratified split first try: train_idx, val_idx = train_test_split( @@ -476,7 +468,6 @@ random_state=random_state, stratify=None, ) - out.loc[train_idx, split_column] = 0 out.loc[val_idx, split_column] = 1 out[split_column] = out[split_column].astype(int) @@ -492,31 +483,24 @@ ) -> pd.DataFrame: """Create a stratified random split when no split column exists.""" out = df.copy() - # initialize split column out[split_column] = 0 - if not label_column or label_column not in out.columns: logger.warning("No label column found; using random split without stratification") # fall back to simple random assignment indices = out.index.tolist() np.random.seed(random_state) np.random.shuffle(indices) - n_total = len(indices) n_train = int(n_total * split_probabilities[0]) n_val = int(n_total * split_probabilities[1]) - out.loc[indices[:n_train], split_column] = 0 out.loc[indices[n_train:n_train + n_val], split_column] = 1 out.loc[indices[n_train + n_val:], split_column] = 2 - return out.astype({split_column: int}) - # check if stratification is possible label_counts = out[label_column].value_counts() min_samples_per_class = label_counts.min() - # ensure we have enough samples for stratification: # Each class must have at least as many samples as the number of splits, # so that each split can receive at least one sample per class. 
@@ -529,19 +513,14 @@ indices = out.index.tolist() np.random.seed(random_state) np.random.shuffle(indices) - n_total = len(indices) n_train = int(n_total * split_probabilities[0]) n_val = int(n_total * split_probabilities[1]) - out.loc[indices[:n_train], split_column] = 0 out.loc[indices[n_train:n_train + n_val], split_column] = 1 out.loc[indices[n_train + n_val:], split_column] = 2 - return out.astype({split_column: int}) - logger.info("Using stratified random split for train/validation/test sets") - # first split: separate test set train_val_idx, test_idx = train_test_split( out.index.tolist(), @@ -549,7 +528,6 @@ random_state=random_state, stratify=out[label_column], ) - # second split: separate training and validation from remaining data val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) train_idx, val_idx = train_test_split( @@ -558,21 +536,17 @@ random_state=random_state, stratify=out.loc[train_val_idx, label_column], ) - # assign split values out.loc[train_idx, split_column] = 0 out.loc[val_idx, split_column] = 1 out.loc[test_idx, split_column] = 2 - logger.info("Successfully applied stratified random split") logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") - return out.astype({split_column: int}) class Backend(Protocol): """Interface for a machine learning backend.""" - def prepare_config( self, config_params: Dict[str, Any], @@ -604,14 +578,12 @@ class LudwigDirectBackend: """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" - def prepare_config( self, config_params: Dict[str, Any], split_config: Dict[str, Any], ) -> str: logger.info("LudwigDirectBackend: Preparing YAML configuration.") - model_name = config_params.get("model_name", "resnet18") use_pretrained = config_params.get("use_pretrained", False) fine_tune = config_params.get("fine_tune", False) @@ -634,9 +606,7 @@ } else: encoder_config = {"type": raw_encoder} - batch_size_cfg = batch_size or "auto" - label_column_path = config_params.get("label_column_data_path") label_series = None if label_column_path is not None and Path(label_column_path).exists(): @@ -644,7 +614,6 @@ label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] except Exception as e: logger.warning(f"Could not read label column for task detection: {e}") - if ( label_series is not None and ptypes.is_numeric_dtype(label_series.dtype) @@ -653,9 +622,7 @@ task_type = "regression" else: task_type = "classification" - config_params["task_type"] = task_type - image_feat: Dict[str, Any] = { "name": IMAGE_PATH_COLUMN_NAME, "type": "image", @@ -663,7 +630,6 @@ } if config_params.get("augmentation") is not None: image_feat["augmentation"] = config_params["augmentation"] - if task_type == "regression": output_feat = { "name": LABEL_COLUMN_NAME, @@ -679,15 +645,15 @@ }, } val_metric = config_params.get("validation_metric", "mean_squared_error") - else: num_unique_labels = ( label_series.nunique() if label_series is not None else 2 ) output_type = "binary" if num_unique_labels == 2 else "category" output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} + if output_type == "binary" and config_params.get("threshold") is not None: + output_feat["threshold"] = float(config_params["threshold"]) val_metric = None - conf: Dict[str, Any] = { "model_type": "ecd", "input_features": [image_feat], @@ -707,7 +673,6 @@ "in_memory": False, }, } - logger.debug("LudwigDirectBackend: Config dict built.") try: yaml_str = yaml.dump(conf, 
sort_keys=False, indent=2) @@ -729,7 +694,6 @@ ) -> None: """Invoke Ludwig's internal experiment_cli function to run the experiment.""" logger.info("LudwigDirectBackend: Starting experiment execution.") - try: from ludwig.experiment import experiment_cli except ImportError as e: @@ -738,9 +702,7 @@ exc_info=True, ) raise RuntimeError("Ludwig import failed.") from e - output_dir.mkdir(parents=True, exist_ok=True) - try: experiment_cli( dataset=str(dataset_path), @@ -771,16 +733,13 @@ output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, ) - if not exp_dirs: logger.warning(f"No experiment run directories found in {output_dir}") return None - progress_file = exp_dirs[-1] / "model" / "training_progress.json" if not progress_file.exists(): logger.warning(f"No training_progress.json found in {progress_file}") return None - try: with progress_file.open("r", encoding="utf-8") as f: data = json.load(f) @@ -816,7 +775,6 @@ def generate_plots(self, output_dir: Path) -> None: """Generate all registered Ludwig visualizations for the latest experiment run.""" logger.info("Generating all Ludwig visualizations…") - test_plots = { "compare_performance", "compare_classifiers_performance_from_prob", @@ -840,7 +798,6 @@ "learning_curves", "compare_classifiers_performance_subset", } - output_dir = Path(output_dir) exp_dirs = sorted( output_dir.glob("experiment_run*"), @@ -850,7 +807,6 @@ logger.warning(f"No experiment run dirs found in {output_dir}") return exp_dir = exp_dirs[-1] - viz_dir = exp_dir / "visualizations" viz_dir.mkdir(exist_ok=True) train_viz = viz_dir / "train" @@ -865,7 +821,6 @@ test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) - dataset_path = None split_file = None desc = exp_dir / DESCRIPTION_FILE_NAME @@ -874,7 +829,6 @@ cfg = json.load(f) dataset_path = _check(Path(cfg.get("dataset", ""))) split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) - output_feature = "" if desc.exists(): try: @@ -885,7 +839,6 @@ with open(test_stats, "r") as f: stats = json.load(f) output_feature = next(iter(stats.keys()), "") - viz_registry = get_visualizations_registry() for viz_name, viz_func in viz_registry.items(): if viz_name in train_plots: @@ -894,7 +847,6 @@ viz_dir_plot = test_viz else: continue - try: viz_func( training_statistics=[training_stats] if training_stats else [], @@ -914,7 +866,6 @@ logger.info(f"✔ Generated {viz_name}") except Exception as e: logger.warning(f"✘ Skipped {viz_name}: {e}") - logger.info(f"All visualizations written to {viz_dir}") def generate_html_report( @@ -930,7 +881,6 @@ report_path = cwd / report_name output_dir = Path(output_dir) output_type = None - exp_dirs = sorted( output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, @@ -938,14 +888,11 @@ if not exp_dirs: raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") exp_dir = exp_dirs[-1] - base_viz_dir = exp_dir / "visualizations" train_viz_dir = base_viz_dir / "train" test_viz_dir = base_viz_dir / "test" - html = get_html_template() html += f"<h1>{title}</h1>" - metrics_html = "" train_val_metrics_html = "" test_metrics_html = "" @@ -971,7 +918,6 @@ logger.warning( f"Could not load stats for HTML report: {type(e).__name__}: {e}" ) - config_html = "" training_progress = self.get_training_process(output_dir) try: @@ -986,93 +932,77 @@ ) -> str: if not dir_path.exists(): return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" 
- + # collect every PNG imgs = list(dir_path.glob("*.png")) + # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- + imgs = [ + img for img in imgs + if not ( + img.name == "confusion_matrix.png" + or img.name.startswith("confusion_matrix__label_top") + or img.name == "roc_curves.png" + ) + ] if not imgs: return f"<h2>{title}</h2><p><em>No plots found.</em></p>" - - if title == "Test Visualizations" and output_type == "binary": + if output_type == "binary": order = [ - "confusion_matrix__label_top2.png", "roc_curves_from_prediction_statistics.png", "compare_performance_label.png", "confusion_matrix_entropy__label_top2.png", + # ...you can tweak ordering as needed ] img_names = {img.name: img for img in imgs} - ordered_imgs = [ - img_names[fname] for fname in order if fname in img_names - ] - remaining = sorted( - [ - img - for img in imgs - if img.name not in order and img.name != "roc_curves.png" - ] - ) - imgs = ordered_imgs + remaining - - elif title == "Test Visualizations" and output_type == "category": + ordered = [img_names[n] for n in order if n in img_names] + others = sorted(img for img in imgs if img.name not in order) + imgs = ordered + others + elif output_type == "category": unwanted = { "compare_classifiers_multiclass_multimetric__label_best10.png", "compare_classifiers_multiclass_multimetric__label_top10.png", "compare_classifiers_multiclass_multimetric__label_worst10.png", } display_order = [ - "confusion_matrix__label_top10.png", "roc_curves.png", "compare_performance_label.png", "compare_classifiers_performance_from_prob.png", - "compare_classifiers_multiclass_multimetric__label_sorted.png", "confusion_matrix_entropy__label_top10.png", ] - img_names = {img.name: img for img in imgs if img.name not in unwanted} - ordered_imgs = [ - img_names[fname] for fname in display_order if fname in img_names - ] - remaining = sorted( - [img for img in img_names.values() if img.name not in display_order] - ) - imgs = ordered_imgs + remaining - + # filter and order + valid_imgs = [img for img in imgs if img.name not in unwanted] + img_map = {img.name: img for img in valid_imgs} + ordered = [img_map[n] for n in display_order if n in img_map] + others = sorted(img for img in valid_imgs if img.name not in display_order) + imgs = ordered + others else: - if output_type == "category": - unwanted = { - "compare_classifiers_multiclass_multimetric__label_best10.png", - "compare_classifiers_multiclass_multimetric__label_top10.png", - "compare_classifiers_multiclass_multimetric__label_worst10.png", - } - imgs = sorted([img for img in imgs if img.name not in unwanted]) - else: - imgs = sorted(imgs) - - section_html = f"<h2 style='text-align: center;'>{title}</h2><div>" + # regression: just sort whatever's left + imgs = sorted(imgs) + # render each remaining PNG + html = "" for img in imgs: b64 = encode_image_to_base64(str(img)) - section_html += ( + img_title = img.stem.replace("_", " ").title() + html += ( + f"<h2 style='text-align: center;'>{img_title}</h2>" f'<div class="plot" style="margin-bottom:20px;text-align:center;">' - f"<h3>{img.stem.replace('_', ' ').title()}</h3>" f'<img src="data:image/png;base64,{b64}" ' f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' f"</div>" ) - section_html += "</div>" - return section_html + return html tab1_content = config_html + metrics_html - tab2_content = train_val_metrics_html + render_img_section( - "Training & Validation Visualizations", train_viz_dir + "Training and Validation Visualizations", 
train_viz_dir ) - # --- Predictions vs Ground Truth table --- preds_section = "" parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME - if parquet_path.exists(): + if output_type == "regression" and parquet_path.exists(): try: # 1) load predictions from Parquet df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) # assume the column containing your model's prediction is named "prediction" - # or contains that substring: pred_col = next( (c for c in df_preds.columns if "prediction" in c.lower()), None, @@ -1080,40 +1010,58 @@ if pred_col is None: raise ValueError("No prediction column found in Parquet output") df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) - # 2) load ground truth for the test split from prepared CSV df_all = pd.read_csv(config["label_column_data_path"]) - df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ - LABEL_COLUMN_NAME - ].reset_index(drop=True) - - # 3) concatenate side‐by‐side + df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) + # 3) concatenate side-by-side df_table = pd.concat([df_gt, df_pred], axis=1) df_table.columns = [LABEL_COLUMN_NAME, "prediction"] - # 4) render as HTML preds_html = df_table.to_html(index=False, classes="predictions-table") preds_section = ( - "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>" - "<div style='overflow-x:auto; margin-bottom:20px;'>" + "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" + "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" + preds_html + "</div>" ) except Exception as e: logger.warning(f"Could not build Predictions vs GT table: {e}") - # Test tab = Metrics + Preds table + Visualizations - - tab3_content = ( - test_metrics_html - + preds_section - + render_img_section("Test Visualizations", test_viz_dir, output_type) - ) - + tab3_content = test_metrics_html + preds_section + if output_type in ("binary", "category"): + training_stats_path = exp_dir / "training_statistics.json" + interactive_plots = build_classification_plots( + str(test_stats_path), + str(training_stats_path), + ) + for plot in interactive_plots: + # 2) inject the static "roc_curves_from_prediction_statistics.png" + if plot["title"] == "ROC-AUC": + static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" + if static_img.exists(): + b64 = encode_image_to_base64(str(static_img)) + tab3_content += ( + "<h2 style='text-align: center;'>" + "Roc Curves From Prediction Statistics" + "</h2>" + f'<div class="plot" style="margin-bottom:20px;text-align:center;">' + f'<img src="data:image/png;base64,{b64}" ' + f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' + "</div>" + ) + # always render the plotly panels exactly as before + tab3_content += ( + f"<h2 style='text-align: center;'>{plot['title']}</h2>" + + plot["html"] + ) + tab3_content += render_img_section( + "Test Visualizations", + test_viz_dir, + output_type + ) # assemble the tabs and help modal tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) modal_html = get_metrics_help_modal() html += tabbed_html + modal_html + get_html_closing() - try: with open(report_path, "w") as f: f.write(html) @@ -1121,13 +1069,11 @@ except Exception as e: logger.error(f"Failed to write HTML report: {e}") raise - return report_path class WorkflowOrchestrator: """Manages the image-classification workflow.""" - def __init__(self, args: argparse.Namespace, backend: Backend): self.args = args self.backend = backend @@ -1167,19 
+1113,16 @@ """Load CSV, update image paths, handle splits, and write prepared CSV.""" if not self.temp_dir or not self.image_extract_dir: raise RuntimeError("Temp dirs not initialized before data prep.") - try: df = pd.read_csv(self.args.csv_file) logger.info(f"Loaded CSV: {self.args.csv_file}") except Exception: logger.error("Error loading CSV file", exc_info=True) raise - required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} missing = required - set(df.columns) if missing: raise ValueError(f"Missing CSV columns: {', '.join(missing)}") - try: df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( lambda p: str((self.image_extract_dir / p).resolve()) @@ -1187,7 +1130,6 @@ except Exception: logger.error("Error updating image paths", exc_info=True) raise - if SPLIT_COLUMN_NAME in df.columns: df, split_config, split_info = self._process_fixed_split(df) else: @@ -1208,16 +1150,13 @@ f"{[int(p * 100) for p in self.args.split_probabilities]}% " f"for train/val/test with balanced label distribution." ) - final_csv = self.temp_dir / TEMP_CSV_FILENAME try: - df.to_csv(final_csv, index=False) logger.info(f"Saved prepared data to {final_csv}") except Exception: logger.error("Error saving prepared CSV", exc_info=True) raise - return final_csv, split_config, split_info def _process_fixed_split( @@ -1232,10 +1171,8 @@ ) if df[SPLIT_COLUMN_NAME].isna().any(): logger.warning("Split column contains non-numeric/missing values.") - unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) logger.info(f"Unique split values: {unique}") - if unique == {0, 2}: df = split_data_0_2( df, @@ -1256,9 +1193,7 @@ logger.info("Using fixed split as-is.") else: raise ValueError(f"Unexpected split values: {unique}") - return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info - except Exception: logger.error("Error processing fixed split", exc_info=True) raise @@ -1274,14 +1209,11 @@ """Execute the full workflow end-to-end.""" logger.info("Starting workflow...") self.args.output_dir.mkdir(parents=True, exist_ok=True) - try: self._create_temp_dirs() self._extract_images() csv_path, split_cfg, split_info = self._prepare_data() - use_pretrained = self.args.use_pretrained or self.args.fine_tune - backend_args = { "model_name": self.args.model_name, "fine_tune": self.args.fine_tune, @@ -1295,13 +1227,12 @@ "early_stop": self.args.early_stop, "label_column_data_path": csv_path, "augmentation": self.args.augmentation, + "threshold": self.args.threshold, } yaml_str = self.backend.prepare_config(backend_args, split_cfg) - config_file = self.temp_dir / TEMP_CONFIG_FILENAME config_file.write_text(yaml_str) logger.info(f"Wrote backend config: {config_file}") - self.backend.run_experiment( csv_path, config_file, @@ -1349,8 +1280,6 @@ aug_list = [] for tok in aug_string.split(","): key = tok.strip() - if not key: - continue if key not in mapping: valid = ", ".join(mapping.keys()) raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") @@ -1428,7 +1357,7 @@ parser.add_argument( "--validation-size", type=float, - default=0.1, + default=0.15, help="Fraction for validation (0.0–1.0)", ) parser.add_argument( @@ -1472,9 +1401,16 @@ "E.g. --augmentation random_horizontal_flip,random_rotate" ), ) - + parser.add_argument( + "--threshold", + type=float, + default=None, + help=( + "Decision threshold for binary classification (0.0–1.0)." + "Overrides default 0.5." 
+ ) + ) args = parser.parse_args() - if not 0.0 <= args.validation_size <= 1.0: parser.error("validation-size must be between 0.0 and 1.0") if not args.csv_file.is_file(): @@ -1487,10 +1423,8 @@ setattr(args, "augmentation", augmentation_setup) except ValueError as e: parser.error(str(e)) - backend_instance = LudwigDirectBackend() orchestrator = WorkflowOrchestrator(args, backend_instance) - exit_code = 0 try: orchestrator.run() @@ -1505,7 +1439,6 @@ if __name__ == "__main__": try: import ludwig - logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") except ImportError: logger.error( @@ -1513,5 +1446,4 @@ "('pip install ludwig[image]')" ) sys.exit(1) - main()
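A few condensed sketches of the behaviour this changeset adds or reworks follow. First, the refactored summary-table code only emits a "Threshold" row for binary outputs and falls back to Ludwig's 0.5 default when no value was supplied. A minimal sketch of that logic, with the inline HTML styling stripped for brevity:

    from typing import Optional

    def threshold_row(config: dict, output_type: Optional[str]) -> Optional[str]:
        # Only binary classifiers have a decision threshold worth reporting.
        if output_type != "binary":
            return None
        val = config.get("threshold")
        val = val if val is not None else 0.5                        # Ludwig's default
        label = f"{val:.2f}" + (" (default)" if val == 0.5 else "")
        return f"<tr><td>Threshold</td><td>{label}</td></tr>"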
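split_data_0_2, lightly reworked in this changeset, promotes a fraction of the split == 0 rows to validation (split == 1) when a fixed split column only contains {0, 2}. A condensed sketch of the stratified path; the logging, bounds checks, and the non-stratified fallback are omitted, and the 0.15 default mirrors the new --validation-size default:

    import pandas as pd
    from sklearn.model_selection import train_test_split

    def add_validation_from_train(df: pd.DataFrame, split_col: str, label_col: str,
                                  validation_size: float = 0.15, seed: int = 42) -> pd.DataFrame:
        out = df.copy()
        idx_train = out.index[out[split_col] == 0].tolist()
        labels = out.loc[idx_train, label_col]
        stratify_arr = labels if labels.nunique() > 1 else None     # stratify only with >1 class
        train_idx, val_idx = train_test_split(
            idx_train, test_size=validation_size,
            random_state=seed, stratify=stratify_arr,
        )
        out.loc[train_idx, split_col] = 0
        out.loc[val_idx, split_col] = 1
        return out.astype({split_col: int})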
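When the CSV has no split column, create_stratified_random_split builds the three-way split in two stages so that every split keeps the label proportions. A sketch under assumed column names and split probabilities:

    import pandas as pd
    from sklearn.model_selection import train_test_split

    def stratified_three_way_split(df: pd.DataFrame, label_col: str, split_col: str = "split",
                                   probs=(0.7, 0.1, 0.2), seed: int = 42) -> pd.DataFrame:
        out = df.copy()
        out[split_col] = 0
        # 1) carve off the test set first, preserving class proportions
        train_val_idx, test_idx = train_test_split(
            out.index.tolist(), test_size=probs[2],
            random_state=seed, stratify=out[label_col],
        )
        # 2) split the remainder into train/validation; rescale the validation fraction
        val_size_adjusted = probs[1] / (probs[0] + probs[1])
        train_idx, val_idx = train_test_split(
            train_val_idx, test_size=val_size_adjusted,
            random_state=seed, stratify=out.loc[train_val_idx, label_col],
        )
        out.loc[train_idx, split_col] = 0
        out.loc[val_idx, split_col] = 1
        out.loc[test_idx, split_col] = 2
        return out.astype({split_col: int})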
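prepare_config infers the task type from the label column: numeric labels with many distinct values are treated as regression, anything else as classification. The exact uniqueness cutoff is not visible in the hunk above, so max_classes below is an assumed illustration:

    import pandas as pd
    import pandas.api.types as ptypes

    def detect_task_type(labels: pd.Series, max_classes: int = 10) -> str:
        # Hypothetical cutoff: a numeric column with more than max_classes distinct
        # values is taken to be a continuous regression target.
        if ptypes.is_numeric_dtype(labels.dtype) and labels.nunique() > max_classes:
            return "regression"
        return "classification"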
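For classification runs, the new --threshold option only reaches the Ludwig config when the output feature is binary; category outputs ignore it. A minimal sketch of that wiring, simplified from prepare_config:

    from typing import Any, Dict, Optional

    def build_output_feature(label_name: str, num_unique_labels: int,
                             threshold: Optional[float]) -> Dict[str, Any]:
        output_type = "binary" if num_unique_labels == 2 else "category"
        feat: Dict[str, Any] = {"name": label_name, "type": output_type}
        if output_type == "binary" and threshold is not None:
            feat["threshold"] = float(threshold)   # overrides Ludwig's default of 0.5
        return feat

    # build_output_feature("label", 2, 0.7)
    #   -> {"name": "label", "type": "binary", "threshold": 0.7}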
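The "Ground Truth vs. Predictions" table is now built only for regression runs and rendered inside a scrollable container. A sketch of how the test-split labels are aligned with the Parquet predictions; the "label" and "split" column names stand in for the script's LABEL_COLUMN_NAME and SPLIT_COLUMN_NAME constants:

    import pandas as pd

    def ground_truth_vs_predictions(parquet_path: str, prepared_csv: str,
                                    label_col: str = "label", split_col: str = "split") -> str:
        df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
        pred_col = next((c for c in df_preds.columns if "prediction" in c.lower()), None)
        if pred_col is None:
            raise ValueError("No prediction column found in Parquet output")
        # split == 2 marks the test rows in the prepared CSV
        df_all = pd.read_csv(prepared_csv)
        df_gt = df_all[df_all[split_col] == 2][label_col].reset_index(drop=True)
        table = pd.concat([df_gt, df_preds[pred_col].rename("prediction")], axis=1)
        return table.to_html(index=False, classes="predictions-table")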
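For binary and category outputs, the Test tab now leads with interactive panels produced by the repository's plotly_plots module; the call signature and the list of {"title", "html"} dicts it returns are taken from the hunk above, while the statistics filenames here are assumptions:

    from pathlib import Path
    from plotly_plots import build_classification_plots   # local module added by this changeset

    def interactive_panels_html(exp_dir: Path) -> str:
        test_stats = exp_dir / "test_statistics.json"          # assumed filename
        train_stats = exp_dir / "training_statistics.json"
        html = ""
        for plot in build_classification_plots(str(test_stats), str(train_stats)):
            html += f"<h2 style='text-align: center;'>{plot['title']}</h2>" + plot["html"]
        return html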