diff ludwig_backend.py @ 17:db9be962dc13 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
| field | value |
|---|---|
| author | goeckslab |
| date | Wed, 10 Dec 2025 00:24:13 +0000 |
| parents | 8729f69e9207 |
| children | |
--- a/ludwig_backend.py Wed Dec 03 01:28:52 2025 +0000
+++ b/ludwig_backend.py Wed Dec 10 00:24:13 2025 +0000
@@ -1,8 +1,9 @@
+import inspect
 import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Optional, Protocol, Tuple
 
 import pandas as pd
 import pandas.api.types as ptypes
@@ -17,6 +18,7 @@
     build_tabbed_html,
     encode_image_to_base64,
     format_config_table_html,
+    format_dataset_overview_table,
     format_stats_table_html,
     format_test_merged_stats_table_html,
     format_train_val_stats_table_html,
@@ -33,7 +35,9 @@
 from ludwig.utils.data_utils import get_split_path
 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS
 from plotly_plots import (
+    build_binary_threshold_plot,
     build_classification_plots,
+    build_multiclass_metric_plots,
     build_prediction_diagnostics,
     build_regression_test_plots,
     build_regression_train_val_plots,
@@ -267,6 +271,23 @@
         else:
             encoder_config = {"type": raw_encoder}
 
+        # Set a human-friendly architecture string for reporting
+        arch_display = None
+        if is_metaformer and custom_model:
+            arch_display = str(custom_model)
+        elif isinstance(raw_encoder, dict):
+            enc_type = raw_encoder.get("type")
+            enc_variant = raw_encoder.get("model_variant")
+            if enc_type:
+                base = str(enc_type).replace("_", " ").title()
+                arch_display = f"{base} {enc_variant}" if enc_variant is not None else base
+        else:
+            arch_display = str(raw_encoder).replace("_", " ").title()
+
+        if not arch_display:
+            arch_display = str(model_name)
+        config_params["architecture"] = arch_display
+
         batch_size_cfg = batch_size or "auto"
 
         label_column_path = config_params.get("label_column_data_path")
@@ -343,6 +364,7 @@
             # Force Ludwig to respect our dimensions by setting additional parameters
             image_feat["preprocessing"]["requires_equal_dimensions"] = False
             logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)")
+            config_params["image_size"] = f"{height}x{width}"
 
             # Now set the encoder configuration
             image_feat["encoder"] = encoder_config
@@ -374,8 +396,12 @@
                 image_feat["preprocessing"]["infer_image_max_height"] = height
                 image_feat["preprocessing"]["infer_image_max_width"] = width
                 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions")
+                config_params["image_size"] = f"{height}x{width}"
             except (ValueError, IndexError):
                 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
+        elif not is_metaformer:
+            # No explicit resize provided; keep for reporting purposes
+            config_params.setdefault("image_size", "original")
 
         def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
             """Pick a validation metric that Ludwig will accept for the resolved task."""
@@ -471,6 +497,9 @@
             config_params.get("validation_metric"),
         )
 
+        # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
+        config_params["validation_metric"] = val_metric
+
         conf: Dict[str, Any] = {
            "model_type": "ecd",
            "input_features": [image_feat],
@@ -641,18 +670,62 @@
        except Exception as e:
            logger.error(f"Error converting Parquet to CSV: {e}")
 
-    def generate_plots(self, output_dir: Path) -> None:
-        """Generate all registered Ludwig visualizations for the latest experiment run."""
-        logger.info("Generating all Ludwig visualizations…")
+    @staticmethod
+    def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
+        """Pull the first numeric metric list we can find for the requested split."""
+        if not isinstance(stats, dict):
+            return None, None
+
+        split_stats = stats.get(split, {})
+        ordered_metrics: List[Tuple[str, List[float]]] = []
+
+        def _append_metrics(metric_map: Dict[str, Any]) -> None:
+            for metric_name, values in metric_map.items():
+                if isinstance(values, list) and any(isinstance(v, (int, float)) for v in values):
+                    ordered_metrics.append((metric_name, values))
 
-        # Keep only lightweight plots (drop compare_performance/roc_curves)
-        test_plots = {
-            "roc_curves_from_test_statistics",
-            "confusion_matrix",
-        }
+        if isinstance(split_stats, dict):
+            combined = split_stats.get("combined")
+            if isinstance(combined, dict):
+                _append_metrics(combined)
+
+            for feature_name, feature_metrics in split_stats.items():
+                if feature_name == "combined" or not isinstance(feature_metrics, dict):
+                    continue
+                _append_metrics(feature_metrics)
+
+        if prefer:
+            for metric_name, values in ordered_metrics:
+                if metric_name == prefer:
+                    return metric_name, values
+
+        return ordered_metrics[0] if ordered_metrics else (None, None)
+
+    def generate_plots(self, output_dir: Path) -> None:
+        """Generate Ludwig visualizations (train/val + test) for the latest experiment run."""
+        logger.info("Generating Ludwig visualizations (train/val + test)…")
+
+        # Train/validation visualizations
        train_plots = {
            "learning_curves",
-            "compare_classifiers_performance_subset",
        }
+
+        # Test visualizations (multi-class transparency)
+        test_plots = {
+            "confusion_matrix",
+            "compare_performance",
+            "compare_classifiers_multiclass_multimetric",
+            "frequency_vs_f1",
+            "confidence_thresholding",
+            "confidence_thresholding_data_vs_acc",
+            "confidence_thresholding_data_vs_acc_subset",
+            "confidence_thresholding_data_vs_acc_subset_per_class",
+            # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere
+            "binary_threshold_vs_metric",
+            "roc_curves",
+            "precision_recall_curves",
+            "calibration_1_vs_all",
+            "calibration_multiclass",
+        }
 
        output_dir = Path(output_dir)
@@ -677,7 +750,6 @@
 
        training_stats = _check(exp_dir / "training_statistics.json")
        test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
-        probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
        gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
 
        dataset_path = None
@@ -688,6 +760,9 @@
                cfg = json.load(f)
            dataset_path = _check(Path(cfg.get("dataset", "")))
            split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+            model_name = cfg.get("model_name", "model")
+        else:
+            model_name = "model"
 
        output_feature = ""
        if desc.exists():
@@ -700,7 +775,44 @@
                stats = json.load(f)
            output_feature = next(iter(stats.keys()), "")
 
+        probs_path = None
+        prob_candidates = [
+            exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv",
+            exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None,
+            exp_dir / "probabilities.csv",
+            exp_dir / "predictions.csv",
+            exp_dir / PREDICTIONS_PARQUET_FILE_NAME,
+        ]
+        for cand in prob_candidates:
+            if cand and Path(cand).exists():
+                probs_path = str(cand)
+                break
+
        viz_registry = get_visualizations_registry()
+        if not viz_registry:
+            logger.warning(
+                "Ludwig visualizations registry not available; train/test PNGs will be skipped."
+            )
+            return
+
+        base_kwargs = {
+            "training_statistics": [training_stats] if training_stats else [],
+            "test_statistics": [test_stats] if test_stats else [],
+            "probabilities": [probs_path] if probs_path else [],
+            "output_feature_name": output_feature,
+            "ground_truth_split": 2,
+            "top_n_classes": [20],
+            "top_k": 3,
+            "metrics": ["f1", "precision", "recall", "accuracy"],
+            "positive_label": 0,
+            "ground_truth_metadata": gt_metadata,
+            "ground_truth": dataset_path,
+            "split_file": split_file,
+            "output_directory": None,  # set per plot below
+            "normalize": False,
+            "file_format": "png",
+            "model_names": [model_name],
+        }
        for viz_name, viz_func in viz_registry.items():
            if viz_name in train_plots:
                viz_dir_plot = train_viz
@@ -710,25 +822,22 @@
                continue
 
            try:
+                # Build per-viz kwargs based on the function signature to avoid unexpected args
+                sig_params = set(inspect.signature(viz_func).parameters.keys())
+                call_kwargs = {
+                    k: v
+                    for k, v in base_kwargs.items()
+                    if k in sig_params and v is not None
+                }
+                if "output_directory" in sig_params:
+                    call_kwargs["output_directory"] = str(viz_dir_plot)
+
                viz_func(
-                    training_statistics=[training_stats] if training_stats else [],
-                    test_statistics=[test_stats] if test_stats else [],
-                    probabilities=[probs_path] if probs_path else [],
-                    output_feature_name=output_feature,
-                    ground_truth_split=2,
-                    top_n_classes=[0],
-                    top_k=3,
-                    ground_truth_metadata=gt_metadata,
-                    ground_truth=dataset_path,
-                    split_file=split_file,
-                    output_directory=str(viz_dir_plot),
-                    normalize=False,
-                    file_format="png",
+                    **call_kwargs,
                )
                logger.info(f"✔ Generated {viz_name}")
            except Exception as e:
                logger.warning(f"✘ Skipped {viz_name}: {e}")
 
-        logger.info(f"All visualizations written to {viz_dir}")
 
    def generate_html_report(
@@ -756,6 +865,7 @@
        label_metadata_path = config.get("label_column_data_path")
        if label_metadata_path:
            label_metadata_path = Path(label_metadata_path)
+        dataset_path_from_desc: Optional[Path] = None
 
        # Pull additional config details from description.json if available
        config_for_summary = dict(config)
@@ -765,7 +875,8 @@
        if desc_path.exists():
            try:
                with open(desc_path, "r") as f:
-                    desc_cfg = json.load(f).get("config", {})
+                    desc_json = json.load(f)
+                    desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {}
                encoder_cfg = (
                    desc_cfg.get("input_features", [{}])[0].get("encoder", {})
                    if isinstance(desc_cfg.get("input_features", [{}]), list)
@@ -783,10 +894,20 @@
 
                arch_type = encoder_cfg.get("type")
                arch_variant = encoder_cfg.get("model_variant")
+                arch_custom = encoder_cfg.get("custom_model")
                arch_name = None
+                if arch_custom:
+                    arch_name = str(arch_custom)
                if arch_type:
                    arch_base = str(arch_type).replace("_", " ").title()
-                    arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+                    arch_type_name = (
+                        f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base
+                    )
+                    # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type
+                    arch_name = arch_name or arch_type_name
+                if not arch_name and config.get("model_name"):
+                    # As a last resort, show the user-selected model name (handles custom/MetaFormer cases)
+                    arch_name = str(config.get("model_name"))
 
                summary_fields = {
                    "architecture": arch_name,
@@ -814,12 +935,22 @@
                    if k in {"target_column", "image_column"} and config_for_summary.get(k):
                        continue
                    config_for_summary.setdefault(k, v)
+
+                dataset_field = None
+                if isinstance(desc_json, dict):
+                    dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset")
+                if dataset_field:
+                    try:
+                        dataset_path_from_desc = Path(dataset_field)
+                    except TypeError:
+                        dataset_path_from_desc = None
+                if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()):
+                    label_metadata_path = dataset_path_from_desc
            except Exception as e:  # pragma: no cover - defensive
                logger.warning(f"Could not merge description.json into config summary: {e}")
 
        base_viz_dir = exp_dir / "visualizations"
        train_viz_dir = base_viz_dir / "train"
-        test_viz_dir = base_viz_dir / "test"
 
        html = get_html_template()
@@ -880,10 +1011,164 @@
              });
            }
          });
-</script>
+        </script>
        """
        html += f"<h1>{title}</h1>"
 
+        def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str:
+            """Append Plotly blocks to a tab with consistent markup."""
+            if not plots:
+                return tab_html
+            suffix = title_suffix or ""
+            for plot in plots:
+                tab_html += (
+                    f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>"
+                    f"<div class='plotly-center'>{plot['html']}</div>"
+                )
+            return tab_html
+
+        def build_dataset_overview(
+            label_metadata: Optional[Path],
+            output_type: Optional[str],
+            split_probabilities: Optional[List[float]],
+            label_split_counts: Optional[List[Dict[str, int]]] = None,
+            split_counts: Optional[Dict[int, int]] = None,
+            fallback_dataset: Optional[Path] = None,
+        ) -> str:
+            """Summarize dataset distribution across splits using the actual split config."""
+            if label_split_counts:
+                # Use the actual counts captured during data prep instead of heuristics.
+                return format_dataset_overview_table(label_split_counts, regression_mode=False)
+
+            if output_type == "regression" and split_counts:
+                rows = [
+                    {"split": "train", "count": int(split_counts.get(0, 0))},
+                    {"split": "validation", "count": int(split_counts.get(1, 0))},
+                    {"split": "test", "count": int(split_counts.get(2, 0))},
+                ]
+                return format_dataset_overview_table(rows, regression_mode=True)
+
+            candidate_paths: List[Path] = []
+            if label_metadata and label_metadata.exists():
+                candidate_paths.append(label_metadata)
+            if fallback_dataset and fallback_dataset.exists():
+                candidate_paths.append(fallback_dataset)
+            if not candidate_paths:
+                return format_dataset_overview_table([])
+
+            def _normalize_split_probabilities(
+                probs: Optional[List[float]],
+            ) -> Optional[List[float]]:
+                if not probs or len(probs) != 3:
+                    return None
+                try:
+                    probs = [float(p) for p in probs]
+                except (TypeError, ValueError):
+                    return None
+                total = sum(probs)
+                if total <= 0:
+                    return None
+                return [p / total for p in probs]
+
+            def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]:
+                if SPLIT_COLUMN_NAME not in df.columns:
+                    return {}
+                split_series = pd.to_numeric(
+                    df[SPLIT_COLUMN_NAME], errors="coerce"
+                ).dropna()
+                if split_series.empty:
+                    return {}
+                split_series = split_series.astype(int)
+                return split_series.value_counts().to_dict()
+
+            def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
+                train_n = int(total * probs[0])
+                val_n = int(total * probs[1])
+                test_n = max(0, total - train_n - val_n)
+                return {0: train_n, 1: val_n, 2: test_n}
+
+            fallback_rows: Optional[List[Dict[str, int]]] = None
+            for meta_path in candidate_paths:
+                try:
+                    df_labels = pd.read_csv(meta_path)
+                    probs = _normalize_split_probabilities(split_probabilities)
+
+                    # Regression (or missing label column): only need split counts
+                    if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns:
+                        split_counts_found = _split_counts_from_column(df_labels)
+                        if split_counts_found:
+                            rows = [
+                                {"split": "train", "count": int(split_counts_found.get(0, 0))},
+                                {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+                                {"split": "test", "count": int(split_counts_found.get(2, 0))},
+                            ]
+                            return format_dataset_overview_table(rows, regression_mode=True)
+                        if probs and fallback_rows is None:
+                            split_counts_found = _split_counts_from_probs(len(df_labels), probs)
+                            fallback_rows = [
+                                {"split": "train", "count": int(split_counts_found.get(0, 0))},
+                                {"split": "validation", "count": int(split_counts_found.get(1, 0))},
+                                {"split": "test", "count": int(split_counts_found.get(2, 0))},
+                            ]
+                        continue
+
+                    # Classification: prefer actual split assignments; fall back to configured probabilities
+                    if SPLIT_COLUMN_NAME in df_labels.columns:
+                        df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy()
+                        df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric(
+                            df_counts[SPLIT_COLUMN_NAME], errors="coerce"
+                        )
+                        df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME])
+                        if df_counts.empty:
+                            continue
+
+                        df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int)
+                        df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME])
+                        counts = (
+                            df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
+                            .size()
+                            .unstack(fill_value=0)
+                            .sort_index()
+                        )
+                        rows = []
+                        for lbl, row in counts.iterrows():
+                            rows.append(
+                                {
+                                    "label": str(lbl),
+                                    "train": int(row.get(0, 0)),
+                                    "validation": int(row.get(1, 0)),
+                                    "test": int(row.get(2, 0)),
+                                }
+                            )
+                        return format_dataset_overview_table(rows)
+
+                    if probs:
+                        label_series = df_labels[LABEL_COLUMN_NAME].dropna()
+                        label_counts = label_series.value_counts().sort_index()
+                        if label_counts.empty:
+                            continue
+                        rows = []
+                        for lbl, count in label_counts.items():
+                            train_n = int(count * probs[0])
+                            val_n = int(count * probs[1])
+                            test_n = max(0, count - train_n - val_n)
+                            rows.append(
+                                {
+                                    "label": str(lbl),
+                                    "train": train_n,
+                                    "validation": val_n,
+                                    "test": test_n,
+                                }
+                            )
+                        fallback_rows = fallback_rows or rows
+                except Exception as exc:
+                    logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc)
+                    continue
+
+            if fallback_rows:
+                return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression")
+            return format_dataset_overview_table([])
+
        metrics_html = ""
        train_val_metrics_html = ""
        test_metrics_html = ""
@@ -911,6 +1196,23 @@
                f"Could not load stats for HTML report: {type(e).__name__}: {e}"
            )
 
+        if not output_type:
+            # Fallback to configured task type when stats are unavailable (e.g., failed run).
+            output_type = (
+                str(config_for_summary.get("task_type")).lower()
+                if config_for_summary.get("task_type")
+                else None
+            )
+
+        dataset_overview_html = build_dataset_overview(
+            label_metadata_path,
+            output_type,
+            config.get("split_probabilities"),
+            config.get("label_split_counts"),
+            config.get("split_counts"),
+            dataset_path_from_desc,
+        )
+
        config_html = ""
        training_progress = self.get_training_process(output_dir)
        try:
@@ -937,11 +1239,12 @@
            exclude_names: Optional[set] = None,
        ) -> str:
            if not dir_path.exists():
-                return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
+                return ""
 
            exclude_names = exclude_names or set()
 
-            imgs = list(dir_path.glob("*.png"))
+            # Search recursively because Ludwig can nest figures under per-feature folders
+            imgs = list(dir_path.rglob("*.png"))
 
            # Exclude ROC curves and standard confusion matrices (keep only entropy version)
            default_exclude = {
@@ -983,7 +1286,7 @@
            ]
 
            if not imgs:
-                return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
+                return ""
 
            # Sort images by name for consistent ordering (works with string and numeric labels)
            imgs = sorted(imgs, key=lambda x: x.name)
@@ -1006,36 +1309,86 @@
            )
            return html_section
 
-        # Show performance first, then config
-        tab1_content = metrics_html + config_html
+        # Show dataset overview, performance first, then config
+        predictions_csv_path = exp_dir / "predictions.csv"
+
+        tab1_content = dataset_overview_html + metrics_html + config_html
 
-        tab2_content = train_val_metrics_html + render_img_section(
-            "Training and Validation Visualizations",
-            train_viz_dir,
-            output_type,
-            exclude_names={
-                "compare_classifiers_performance_from_prob.png",
-                "roc_curves_from_prediction_statistics.png",
-                "precision_recall_curves_from_prediction_statistics.png",
-                "precision_recall_curve.png",
-            },
+        tab2_content = train_val_metrics_html
+
+        # Preload binary threshold plot so it appears first in Train/Val tab
+        threshold_plot = None
+        threshold_value = (
+            config_for_summary.get("threshold")
+            if config_for_summary.get("threshold") is not None
+            else config.get("threshold")
        )
+        if threshold_value is None and output_type == "binary":
+            threshold_value = 0.5
+        if output_type == "binary" and predictions_csv_path.exists():
+            try:
+                threshold_plot = build_binary_threshold_plot(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=1,
+                )
+            except Exception as e:
+                logger.warning(f"Could not generate validation threshold plot: {e}")
 
        if train_stats_path.exists():
            try:
                if output_type == "regression":
                    tv_plots = build_regression_train_val_plots(str(train_stats_path))
+                    tab2_content = append_plot_blocks(tab2_content, tv_plots)
                else:
                    tv_plots = build_train_validation_plots(str(train_stats_path))
-                for plot in tv_plots:
-                    tab2_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
-                if tv_plots:
-                    logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots")
+                    # Add threshold plot first, then other train/val plots
+                    if threshold_plot:
+                        tab2_content = append_plot_blocks(tab2_content, [threshold_plot])
+                        # Only append once; avoid duplicates if added elsewhere
+                        threshold_plot = None
+                    tab2_content = append_plot_blocks(tab2_content, tv_plots)
+                    if threshold_plot or tv_plots:
+                        logger.info(
+                            f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots"
+                        )
            except Exception as e:
                logger.warning(f"Could not generate train/val plots: {e}")
 
+        # Only include training PNGs for regression; classification is handled by filtered Plotly plots
+        if output_type == "regression":
+            tab2_content += render_img_section(
+                "Training and Validation Visualizations",
+                train_viz_dir,
+                output_type,
+                exclude_names={
+                    "compare_classifiers_performance_from_prob.png",
+                    "roc_curves_from_prediction_statistics.png",
+                    "precision_recall_curves_from_prediction_statistics.png",
+                    "precision_recall_curve.png",
+                },
+            )
+
+        # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1
+        if output_type in ("binary", "category") and predictions_csv_path.exists():
+            try:
+                val_diag_plots = build_prediction_diagnostics(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=1,
+                )
+                val_conf_plots = [
+                    p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+                ]
+                tab2_content = append_plot_blocks(
+                    tab2_content, val_conf_plots, " (Validation)"
+                )
+            except Exception as e:
+                logger.warning(f"Could not generate validation diagnostics: {e}")
+
        # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
        preds_section = ""
        parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
@@ -1077,18 +1430,12 @@
                logger.warning(f"Could not build Predictions vs GT table: {e}")
 
        tab3_content = test_metrics_html + preds_section
-        test_plotly_added = False
 
        if output_type == "regression" and train_stats_path.exists():
            try:
                test_plots = build_regression_test_plots(str(train_stats_path))
-                for plot in test_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
+                tab3_content = append_plot_blocks(tab3_content, test_plots)
                if test_plots:
-                    test_plotly_added = True
                    logger.info(f"Generated {len(test_plots)} regression test plots")
            except Exception as e:
                logger.warning(f"Could not generate regression test plots: {e}")
@@ -1104,46 +1451,42 @@
                    train_set_metadata_path=str(train_set_metadata_path)
                    if train_set_metadata_path.exists()
                    else None,
+                    threshold=threshold_value,
                )
-                for plot in interactive_plots:
-                    tab3_content += (
-                        f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                        f"<div class='plotly-center'>{plot['html']}</div>"
-                    )
+                tab3_content = append_plot_blocks(tab3_content, interactive_plots)
                if interactive_plots:
-                    test_plotly_added = True
-                    logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
+                    logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots")
            except Exception as e:
                logger.warning(f"Could not generate Plotly plots: {e}")
 
-        # Add prediction diagnostics from predictions.csv
-        predictions_csv_path = exp_dir / "predictions.csv"
-        try:
-            diag_plots = build_prediction_diagnostics(
-                str(predictions_csv_path),
-                label_data_path=str(config.get("label_column_data_path"))
-                if config.get("label_column_data_path")
-                else None,
-                threshold=config.get("threshold"),
-            )
-            for plot in diag_plots:
-                tab3_content += (
-                    f"<h2 style='text-align: center;'>{plot['title']}</h2>"
-                    f"<div class='plotly-center'>{plot['html']}</div>"
+        # Multi-class transparency plots from test stats (replace ROC/PR for multi-class)
+        if output_type == "category" and test_stats_path.exists():
+            try:
+                multi_curves = build_multiclass_metric_plots(str(test_stats_path))
+                tab3_content = append_plot_blocks(tab3_content, multi_curves)
+                if multi_curves:
+                    logger.info("Added multi-class per-class metric plots to test tab")
+            except Exception as e:
+                logger.warning(f"Could not generate multi-class metric plots: {e}")
+
+        # Test diagnostics (confidence histogram) from predictions.csv, using split=2
+        if predictions_csv_path.exists():
+            try:
+                test_diag_plots = build_prediction_diagnostics(
+                    str(predictions_csv_path),
+                    label_data_path=str(config.get("label_column_data_path"))
+                    if config.get("label_column_data_path")
+                    else None,
+                    split_value=2,
                )
-            if diag_plots:
-                test_plotly_added = True
-                logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots")
-        except Exception as e:
-            logger.warning(f"Could not generate prediction diagnostics: {e}")
-
-        # Fallback: include static PNGs if no interactive plots were added
-        if not test_plotly_added:
-            tab3_content += render_img_section(
-                "Test Visualizations (PNG fallback)",
-                test_viz_dir,
-                output_type,
-            )
+                test_conf_plots = [
+                    p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "")
+                ]
+                if test_conf_plots:
+                    tab3_content = append_plot_blocks(tab3_content, test_conf_plots)
+                    logger.info("Added test prediction confidence plot")
+            except Exception as e:
+                logger.warning(f"Could not generate test diagnostics: {e}")
 
        # Add static TEST PNGs (with default dedupe/exclusions)
        tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)

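The dataset overview added to tab 1 boils down to the two counting paths sketched below: count an integer split column when one exists (0 = train, 1 = validation, 2 = test), otherwise apportion label totals by the normalized split probabilities. A minimal sketch with made-up column names and data; the real `build_dataset_overview` reads the split/label constants and the label metadata CSV and adds fallbacks for non-numeric split values.

```python
import pandas as pd

SPLIT_COLUMN_NAME = "split"   # assumed stand-ins for the constants used in ludwig_backend.py
LABEL_COLUMN_NAME = "label"

df = pd.DataFrame({
    LABEL_COLUMN_NAME: ["cat", "dog", "cat", "dog", "cat", "dog"],
    SPLIT_COLUMN_NAME: [0, 0, 1, 1, 2, 2],   # 0 = train, 1 = validation, 2 = test
})

# Preferred path: count the actual split assignments per label.
counts = (
    df.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME])
    .size()
    .unstack(fill_value=0)
)
rows = [
    {
        "label": str(lbl),
        "train": int(row.get(0, 0)),
        "validation": int(row.get(1, 0)),
        "test": int(row.get(2, 0)),
    }
    for lbl, row in counts.iterrows()
]

# Fallback path: apportion the total by normalized split probabilities.
probs = [0.7, 0.1, 0.2]
probs = [p / sum(probs) for p in probs]
total = len(df)
train_n = int(total * probs[0])
val_n = int(total * probs[1])
test_n = max(0, total - train_n - val_n)

print(rows)
print({"train": train_n, "validation": val_n, "test": test_n})
```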