# HG changeset patch # User goeckslab # Date 1764344749 0 # Node ID d17e3a1b865946a7aeed00b9de55ff31a1a1241d # Parent 94cd9ac4a9b149e2e4697a597b3a51c7c415ee29 planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17 diff -r 94cd9ac4a9b1 -r d17e3a1b8659 html_structure.py --- a/html_structure.py Wed Nov 26 22:00:32 2025 +0000 +++ b/html_structure.py Fri Nov 28 15:45:49 2025 +0000 @@ -22,21 +22,31 @@ output_type: Optional[str] = None, ) -> str: display_keys = [ + "architecture", + "pretrained", + "trainable", + "target_column", "task_type", - "model_name", + "validation_metric", + "loss_function", + "threshold", "epochs", + "total_epochs", "batch_size", "fine_tune", "use_pretrained", "learning_rate", + "optimizer", "random_seed", "early_stop", - "threshold", + "use_mixed_precision", + "gradient_clipping", ] rows = [] for key in display_keys: + val_str = "N/A" val = config.get(key, None) if key == "threshold": if output_type != "binary": @@ -49,7 +59,7 @@ if key == "task_type": val_str = val.title() if isinstance(val, str) else "N/A" elif key == "batch_size": - if val is not None: + if isinstance(val, (int, float)): val_str = int(val) else: val = "auto" @@ -120,6 +130,21 @@ ) else: val_str = val + elif key == "pretrained": + if isinstance(val, bool): + val_str = "Yes (ImageNet)" if val else "No" + else: + val_str = val if val is not None else "N/A" + elif key == "trainable": + if isinstance(val, bool): + val_str = "Trainable" if val else "Frozen" + else: + val_str = val if val is not None else "N/A" + elif key == "use_mixed_precision": + if isinstance(val, bool): + val_str = "Yes" if val else "No" + else: + val_str = val if val is not None else "N/A" else: val_str = val if val is not None else "N/A" if val_str == "N/A" and key not in ["task_type"]: @@ -155,7 +180,7 @@ ) html = f""" -

Model and Training Summary

+

Training Configuration (Model, Data, Metrics)

@@ -519,15 +544,15 @@ def build_tabbed_html(metrics_html: str, train_val_html: str, test_html: str) -> str: """ Build a 3-tab interface: - - Config and Results Summary + - Config and Overall Performance Summary - Train/Validation Results - Test Results Includes a persistent "Help" button that toggles the metrics modal. """ return f"""
-
Config and Results Summary
-
Train/Validation Results
+
Config and Overall Performance Summary
+
Training and Validation Results
Test Results
diff -r 94cd9ac4a9b1 -r d17e3a1b8659 image_learner.xml --- a/image_learner.xml Wed Nov 26 22:00:32 2025 +0000 +++ b/image_learner.xml Fri Nov 28 15:45:49 2025 +0000 @@ -1,4 +1,4 @@ - + trains and evaluates an image classification/regression model quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1 @@ -29,6 +29,16 @@ ln -sf '$input_csv' "./${sanitized_input_csv}"; #end if + #if $task_selection.task == "binary" + #set $selected_validation_metric = $task_selection.validation_metric_binary + #elif $task_selection.task == "classification" + #set $selected_validation_metric = $task_selection.validation_metric_multiclass + #elif $task_selection.task == "regression" + #set $selected_validation_metric = $task_selection.validation_metric_regression + #else + #set $selected_validation_metric = None + #end if + python '$__tool_directory__/image_learner_cli.py' --csv-file "./${sanitized_input_csv}" --image-zip "$image_zip" @@ -39,27 +49,38 @@ --fine-tune #end if #end if - #if $customize_defaults == "true" - #if $epochs - --epochs "$epochs" + #if $advanced_settings.customize_defaults == "true" + #if $advanced_settings.epochs + --epochs "$advanced_settings.epochs" #end if - #if $early_stop - --early-stop "$early_stop" + #if $advanced_settings.early_stop + --early-stop "$advanced_settings.early_stop" #end if - #if $learning_rate_define == "true" - --learning-rate "$learning_rate" + #if $advanced_settings.learning_rate_condition.learning_rate_define == "true" + --learning-rate "$advanced_settings.learning_rate_condition.learning_rate" #end if - #if $batch_size_define == "true" - --batch-size "$batch_size" + #if $advanced_settings.batch_size_condition.batch_size_define == "true" + --batch-size "$advanced_settings.batch_size_condition.batch_size" #end if - --split-probabilities "$train_split" "$val_split" "$test_split" - #if $threshold - --threshold "$threshold" + --split-probabilities "$advanced_settings.train_split" "$advanced_settings.val_split" "$advanced_settings.test_split" + #if $advanced_settings.threshold + --threshold "$advanced_settings.threshold" #end if #end if #if $augmentation --augmentation "$augmentation" #end if + #if $selected_validation_metric + --validation-metric "$selected_validation_metric" + #end if + #if $column_override.override_columns == "true" + #if $column_override.target_column + --target-column "$column_override.target_column" + #end if + #if $column_override.image_column + --image-column "$column_override.image_column" + #end if + #end if --image-resize "$image_resize" --random-seed "$random_seed" --output-dir "." && @@ -74,6 +95,68 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -325,10 +408,12 @@ + + - - + + @@ -347,8 +432,8 @@ - - + + @@ -366,8 +451,8 @@ - - + + @@ -388,8 +473,8 @@ - - + + @@ -410,8 +495,8 @@ - - + + @@ -433,8 +518,8 @@ - - + + @@ -462,8 +547,8 @@ - - + + @@ -487,11 +572,15 @@ + + - - - + + + + + diff -r 94cd9ac4a9b1 -r d17e3a1b8659 image_learner_cli.py --- a/image_learner_cli.py Wed Nov 26 22:00:32 2025 +0000 +++ b/image_learner_cli.py Fri Nov 28 15:45:49 2025 +0000 @@ -142,6 +142,42 @@ "Overrides default 0.5." ), ) + parser.add_argument( + "--validation-metric", + type=str, + default="roc_auc", + choices=[ + "accuracy", + "loss", + "roc_auc", + "balanced_accuracy", + "precision", + "recall", + "f1", + "specificity", + "log_loss", + "pearson_r", + "mae", + "mse", + "rmse", + "mape", + "r2", + "explained_variance", + ], + help="Metric Ludwig uses to select the best model during training/validation.", + ) + parser.add_argument( + "--target-column", + type=str, + default=None, + help="Name of the target/label column in the metadata file (defaults to 'label').", + ) + parser.add_argument( + "--image-column", + type=str, + default=None, + help="Name of the image column in the metadata file (defaults to 'image_path').", + ) args = parser.parse_args() diff -r 94cd9ac4a9b1 -r d17e3a1b8659 image_workflow.py --- a/image_workflow.py Wed Nov 26 22:00:32 2025 +0000 +++ b/image_workflow.py Fri Nov 28 15:45:49 2025 +0000 @@ -127,16 +127,31 @@ logger.error("Error loading metadata file", exc_info=True) raise - required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} - missing = required - set(df.columns) - if missing: - raise ValueError(f"Missing CSV columns: {', '.join(missing)}") + label_col = self.args.target_column or LABEL_COLUMN_NAME + image_col = self.args.image_column or IMAGE_PATH_COLUMN_NAME + + # Remember the user-specified columns for reporting + self.args.report_target_column = label_col + self.args.report_image_column = image_col + + missing_cols = [] + if label_col not in df.columns: + missing_cols.append(label_col) + if image_col not in df.columns: + missing_cols.append(image_col) + if missing_cols: + raise ValueError( + f"Missing required column(s) in metadata: {', '.join(missing_cols)}. " + "Update the XML selections or rename your columns." + ) + + if label_col != LABEL_COLUMN_NAME: + df = df.rename(columns={label_col: LABEL_COLUMN_NAME}) + if image_col != IMAGE_PATH_COLUMN_NAME: + df = df.rename(columns={image_col: IMAGE_PATH_COLUMN_NAME}) try: - # Use relative paths that Ludwig can resolve from its internal working directory - df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( - lambda p: str(Path("images") / p) - ) + df = self._map_image_paths_with_search(df) except Exception: logger.error("Error updating image paths", exc_info=True) raise @@ -205,6 +220,71 @@ self.label_metadata = metadata self.output_type_hint = "binary" if metadata.get("is_binary") else None + def _map_image_paths_with_search(self, df: pd.DataFrame) -> pd.DataFrame: + """Map image identifiers to actual files by searching the extracted directory.""" + if not self.image_extract_dir: + raise RuntimeError("Image directory is not initialized.") + + # Build lookup maps for fast resolution by stem or full name + lookup_by_stem = {} + lookup_by_name = {} + for fpath in self.image_extract_dir.rglob("*"): + if fpath.is_file(): + stem_key = fpath.stem.lower() + name_key = fpath.name.lower() + # Prefer first encounter; warn on collisions + if stem_key in lookup_by_stem and lookup_by_stem[stem_key] != fpath: + logger.warning( + "Multiple files share the same stem '%s'. Using '%s'.", + stem_key, + lookup_by_stem[stem_key], + ) + else: + lookup_by_stem[stem_key] = fpath + if name_key in lookup_by_name and lookup_by_name[name_key] != fpath: + logger.warning( + "Multiple files share the same name '%s'. Using '%s'.", + name_key, + lookup_by_name[name_key], + ) + else: + lookup_by_name[name_key] = fpath + + resolved_paths = [] + missing_count = 0 + missing_samples = [] + + for raw in df[IMAGE_PATH_COLUMN_NAME]: + raw_str = str(raw) + name_key = Path(raw_str).name.lower() + stem_key = Path(raw_str).stem.lower() + resolved = lookup_by_name.get(name_key) or lookup_by_stem.get(stem_key) + + if resolved is None: + missing_count += 1 + missing_samples.append(raw_str) + resolved_paths.append(pd.NA) + continue + + try: + rel_path = resolved.relative_to(self.image_extract_dir) + except ValueError: + rel_path = resolved + resolved_paths.append(str(Path("images") / rel_path)) + + if missing_count: + logger.warning( + "Unable to locate %d image(s) from the metadata in the extracted images directory.", + missing_count, + ) + preview = ", ".join(missing_samples[:5]) + logger.warning("Missing samples (showing up to 5): %s", preview) + + df = df.copy() + df[IMAGE_PATH_COLUMN_NAME] = resolved_paths + df = df.dropna(subset=[IMAGE_PATH_COLUMN_NAME]).reset_index(drop=True) + return df + # Removed duplicate method def _detect_image_dimensions(self) -> Tuple[int, int]: @@ -275,6 +355,9 @@ "threshold": self.args.threshold, "label_metadata": self.label_metadata, "output_type_hint": self.output_type_hint, + "validation_metric": self.args.validation_metric, + "target_column": getattr(self.args, "report_target_column", LABEL_COLUMN_NAME), + "image_column": getattr(self.args, "report_image_column", IMAGE_PATH_COLUMN_NAME), } yaml_str = self.backend.prepare_config(backend_args, split_cfg) @@ -297,6 +380,9 @@ if ran_ok: logger.info("Workflow completed successfully.") + # Convert predictions parquet → csv + self.backend.convert_parquet_to_csv(self.args.output_dir) + logger.info("Converted Parquet to CSV.") # Generate a very small set of plots to conserve disk space self.backend.generate_plots(self.args.output_dir) # Build HTML report (robust to missing metrics) @@ -307,9 +393,6 @@ split_info, ) logger.info(f"HTML report generated at: {report_file}") - # Convert predictions parquet → csv - self.backend.convert_parquet_to_csv(self.args.output_dir) - logger.info("Converted Parquet to CSV.") # Post-process cleanup to reduce disk footprint for subsequent tests try: self._postprocess_cleanup(self.args.output_dir) diff -r 94cd9ac4a9b1 -r d17e3a1b8659 ludwig_backend.py --- a/ludwig_backend.py Wed Nov 26 22:00:32 2025 +0000 +++ b/ludwig_backend.py Fri Nov 28 15:45:49 2025 +0000 @@ -31,7 +31,13 @@ ) from ludwig.utils.data_utils import get_split_path from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS -from plotly_plots import build_classification_plots +from plotly_plots import ( + build_classification_plots, + build_prediction_diagnostics, + build_regression_test_plots, + build_regression_train_val_plots, + build_train_validation_plots, +) from utils import detect_output_type, extract_metrics_from_json logger = logging.getLogger("ImageLearner") @@ -72,6 +78,8 @@ class LudwigDirectBackend: """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" + _torchvision_patched = False + def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]: """Detect image dimensions from the first image in the dataset.""" try: @@ -344,6 +352,72 @@ logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") except (ValueError, IndexError): logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") + + def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: + """Pick a validation metric that Ludwig will accept for the resolved task.""" + default_map = { + "regression": "pearson_r", + "binary": "roc_auc", + "category": "accuracy", + } + allowed_map = { + "regression": { + "pearson_r", + "mean_absolute_error", + "mean_squared_error", + "root_mean_squared_error", + "mean_absolute_percentage_error", + "r2", + "explained_variance", + "loss", + }, + # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set. + "binary": { + "roc_auc", + "accuracy", + "precision", + "recall", + "specificity", + "log_loss", + "loss", + }, + "category": { + "accuracy", + "balanced_accuracy", + "precision", + "recall", + "f1", + "specificity", + "log_loss", + "loss", + }, + } + alias_map = { + "regression": { + "mae": "mean_absolute_error", + "mse": "mean_squared_error", + "rmse": "root_mean_squared_error", + "mape": "mean_absolute_percentage_error", + }, + } + + default_metric = default_map.get(task) + allowed = allowed_map.get(task, set()) + metric = requested or default_metric + + if metric is None: + return None + + metric = alias_map.get(task, {}).get(metric, metric) + + if metric not in allowed: + if requested: + logger.warning( + f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead." + ) + metric = default_metric + return metric + if task_type == "regression": output_feat = { "name": LABEL_COLUMN_NAME, @@ -351,7 +425,7 @@ "decoder": {"type": "regressor"}, "loss": {"type": "mean_squared_error"}, } - val_metric = config_params.get("validation_metric", "mean_squared_error") + val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric")) else: if num_unique_labels == 2: @@ -368,7 +442,10 @@ "type": "category", "loss": {"type": "softmax_cross_entropy"}, } - val_metric = None + val_metric = _resolve_validation_metric( + "binary" if num_unique_labels == 2 else "category", + config_params.get("validation_metric"), + ) conf: Dict[str, Any] = { "model_type": "ecd", @@ -380,7 +457,7 @@ "early_stop": early_stop, "batch_size": batch_size_cfg, "learning_rate": learning_rate, - # only set validation_metric for regression + # set validation_metric when provided **({"validation_metric": val_metric} if val_metric else {}), }, "preprocessing": { @@ -402,6 +479,41 @@ ) raise + def _patch_torchvision_download(self) -> None: + """ + Torchvision weight downloads sometimes fail checksum validation behind + corporate proxies that rewrite binaries. Skip hash checking to allow + pre-trained weights to load in those environments. + """ + if LudwigDirectBackend._torchvision_patched: + return + try: + import torch.hub as torch_hub + + original = torch_hub.load_state_dict_from_url + original_download = torch_hub.download_url_to_file + + def _no_hash(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None): + return original( + url, + model_dir=model_dir, + map_location=map_location, + progress=progress, + check_hash=False, + file_name=file_name, + ) + + def _download_no_hash(url, dst, hash_prefix=None, progress=True): + # Torchvision's download_url_to_file signature does not accept check_hash in older versions. + return original_download(url, dst, hash_prefix=None, progress=progress) + + torch_hub.load_state_dict_from_url = _no_hash # type: ignore[assignment] + torch_hub.download_url_to_file = _download_no_hash # type: ignore[assignment] + LudwigDirectBackend._torchvision_patched = True + logger.info("Disabled torchvision weight hash verification to avoid proxy-corrupted downloads.") + except Exception as exc: + logger.warning(f"Could not patch torchvision download hash check: {exc}") + def run_experiment( self, dataset_path: Path, @@ -412,6 +524,9 @@ """Invoke Ludwig's internal experiment_cli function to run the experiment.""" logger.info("LudwigDirectBackend: Starting experiment execution.") + # Avoid strict hash validation for torchvision weights (common in proxied environments) + self._patch_torchvision_download() + try: from ludwig.experiment import experiment_cli except ImportError as e: @@ -506,24 +621,10 @@ """Generate all registered Ludwig visualizations for the latest experiment run.""" logger.info("Generating all Ludwig visualizations…") + # Keep only lightweight plots (drop compare_performance/roc_curves) test_plots = { - "compare_performance", - "compare_classifiers_performance_from_prob", - "compare_classifiers_performance_from_pred", - "compare_classifiers_performance_changing_k", - "compare_classifiers_multiclass_multimetric", - "compare_classifiers_predictions", - "confidence_thresholding_2thresholds_2d", - "confidence_thresholding_2thresholds_3d", - "confidence_thresholding", - "confidence_thresholding_data_vs_acc", - "binary_threshold_vs_metric", - "roc_curves", "roc_curves_from_test_statistics", - "calibration_1_vs_all", - "calibration_multiclass", "confusion_matrix", - "frequency_vs_f1", } train_plots = { "learning_curves", @@ -627,6 +728,70 @@ if not exp_dirs: raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") exp_dir = exp_dirs[-1] + train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME + label_metadata_path = config.get("label_column_data_path") + if label_metadata_path: + label_metadata_path = Path(label_metadata_path) + + # Pull additional config details from description.json if available + config_for_summary = dict(config) + if "target_column" not in config_for_summary or not config_for_summary.get("target_column"): + config_for_summary["target_column"] = LABEL_COLUMN_NAME + desc_path = exp_dir / DESCRIPTION_FILE_NAME + if desc_path.exists(): + try: + with open(desc_path, "r") as f: + desc_cfg = json.load(f).get("config", {}) + encoder_cfg = ( + desc_cfg.get("input_features", [{}])[0].get("encoder", {}) + if isinstance(desc_cfg.get("input_features", [{}]), list) + else {} + ) + output_cfg = ( + desc_cfg.get("output_features", [{}])[0] + if isinstance(desc_cfg.get("output_features", [{}]), list) + else {} + ) + trainer_cfg = desc_cfg.get("trainer", {}) if isinstance(desc_cfg, dict) else {} + loss_cfg = output_cfg.get("loss", {}) if isinstance(output_cfg, dict) else {} + opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {} + clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {} + + arch_type = encoder_cfg.get("type") + arch_variant = encoder_cfg.get("model_variant") + arch_name = None + if arch_type: + arch_base = str(arch_type).replace("_", " ").title() + arch_name = f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base + + summary_fields = { + "architecture": arch_name, + "model_variant": arch_variant, + "pretrained": encoder_cfg.get("use_pretrained"), + "trainable": encoder_cfg.get("trainable"), + "target_column": output_cfg.get("column"), + "task_type": output_cfg.get("type"), + "validation_metric": trainer_cfg.get("validation_metric"), + "loss_function": loss_cfg.get("type"), + "threshold": output_cfg.get("threshold"), + "total_epochs": trainer_cfg.get("epochs"), + "early_stop": trainer_cfg.get("early_stop"), + "batch_size": trainer_cfg.get("batch_size"), + "optimizer": opt_cfg.get("type"), + "learning_rate": trainer_cfg.get("learning_rate"), + "random_seed": desc_cfg.get("random_seed") or config.get("random_seed"), + "use_mixed_precision": trainer_cfg.get("use_mixed_precision"), + "gradient_clipping": clip_cfg.get("clipglobalnorm"), + } + for k, v in summary_fields.items(): + if v is None: + continue + # Do not override user-passed target/image column names in config + if k in {"target_column", "image_column"} and config_for_summary.get(k): + continue + config_for_summary.setdefault(k, v) + except Exception as e: # pragma: no cover - defensive + logger.warning(f"Could not merge description.json into config summary: {e}") base_viz_dir = exp_dir / "visualizations" train_viz_dir = base_viz_dir / "train" @@ -698,9 +863,10 @@ metrics_html = "" train_val_metrics_html = "" test_metrics_html = "" + output_type = None + train_stats_path = exp_dir / "training_statistics.json" + test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME try: - train_stats_path = exp_dir / "training_statistics.json" - test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME if train_stats_path.exists() and test_stats_path.exists(): with open(train_stats_path) as f: train_stats = json.load(f) @@ -725,10 +891,19 @@ training_progress = self.get_training_process(output_dir) try: config_html = format_config_table_html( - config, split_info, training_progress, output_type + config_for_summary, split_info, training_progress, output_type ) except Exception as e: logger.warning(f"Could not load config for HTML report: {e}") + config_html = ( + "

Model and Training Summary

" + "

Configuration details unavailable.

" + ) + if not config_html: + config_html = ( + "

Model and Training Summary

" + "

No configuration details found.

" + ) # ---------- image rendering with exclusions ---------- def render_img_section( @@ -776,6 +951,11 @@ for img in imgs if img.name not in default_exclude and img.name not in exclude_names + and not ( + "learning_curves" in img.stem + and "loss" in img.stem + and "label" in img.stem + ) ] if not imgs: @@ -802,7 +982,8 @@ ) return html_section - tab1_content = config_html + metrics_html + # Show performance first, then config + tab1_content = metrics_html + config_html tab2_content = train_val_metrics_html + render_img_section( "Training and Validation Visualizations", @@ -815,6 +996,21 @@ "precision_recall_curve.png", }, ) + if train_stats_path.exists(): + try: + if output_type == "regression": + tv_plots = build_regression_train_val_plots(str(train_stats_path)) + else: + tv_plots = build_train_validation_plots(str(train_stats_path)) + for plot in tv_plots: + tab2_content += ( + f"

{plot['title']}

" + f"
{plot['html']}
" + ) + if tv_plots: + logger.info(f"Generated {len(tv_plots)} train/val diagnostic plots") + except Exception as e: + logger.warning(f"Could not generate train/val plots: {e}") # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- preds_section = "" @@ -849,7 +1045,7 @@ "
" "" "
" - "
" + "
" + preds_html + "
" ) @@ -857,27 +1053,75 @@ logger.warning(f"Could not build Predictions vs GT table: {e}") tab3_content = test_metrics_html + preds_section + test_plotly_added = False + + if output_type == "regression" and train_stats_path.exists(): + try: + test_plots = build_regression_test_plots(str(train_stats_path)) + for plot in test_plots: + tab3_content += ( + f"

{plot['title']}

" + f"
{plot['html']}
" + ) + if test_plots: + test_plotly_added = True + logger.info(f"Generated {len(test_plots)} regression test plots") + except Exception as e: + logger.warning(f"Could not generate regression test plots: {e}") if output_type in ("binary", "category") and test_stats_path.exists(): try: interactive_plots = build_classification_plots( str(test_stats_path), str(train_stats_path) if train_stats_path.exists() else None, + metadata_csv_path=str(label_metadata_path) + if label_metadata_path and label_metadata_path.exists() + else None, + train_set_metadata_path=str(train_set_metadata_path) + if train_set_metadata_path.exists() + else None, ) for plot in interactive_plots: tab3_content += ( f"

{plot['title']}

" f"
{plot['html']}
" ) + if interactive_plots: + test_plotly_added = True logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") except Exception as e: logger.warning(f"Could not generate Plotly plots: {e}") + # Add prediction diagnostics from predictions.csv + predictions_csv_path = exp_dir / "predictions.csv" + try: + diag_plots = build_prediction_diagnostics( + str(predictions_csv_path), + label_data_path=str(config.get("label_column_data_path")) + if config.get("label_column_data_path") + else None, + threshold=config.get("threshold"), + ) + for plot in diag_plots: + tab3_content += ( + f"

{plot['title']}

" + f"
{plot['html']}
" + ) + if diag_plots: + test_plotly_added = True + logger.info(f"Generated {len(diag_plots)} prediction diagnostic plots") + except Exception as e: + logger.warning(f"Could not generate prediction diagnostics: {e}") + + # Fallback: include static PNGs if no interactive plots were added + if not test_plotly_added: + tab3_content += render_img_section( + "Test Visualizations (PNG fallback)", + test_viz_dir, + output_type, + ) + # Add static TEST PNGs (with default dedupe/exclusions) - tab3_content += render_img_section( - "Test Visualizations", test_viz_dir, output_type - ) - tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) modal_html = get_metrics_help_modal() html += tabbed_html + modal_html + get_html_closing() diff -r 94cd9ac4a9b1 -r d17e3a1b8659 plotly_plots.py --- a/plotly_plots.py Wed Nov 26 22:00:32 2025 +0000 +++ b/plotly_plots.py Fri Nov 28 15:45:49 2025 +0000 @@ -7,13 +7,105 @@ import plotly.graph_objects as go import plotly.io as pio from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME -from sklearn.metrics import auc, roc_curve -from sklearn.preprocessing import label_binarize + + +def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: + """Apply consistent styling across Plotly figures.""" + fig.update_layout( + font=dict(size=font_size), + plot_bgcolor="#ffffff", + paper_bgcolor="#ffffff", + ) + fig.update_xaxes(gridcolor="#e8e8e8") + fig.update_yaxes(gridcolor="#e8e8e8") + return fig + + +def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: + """Extract ordered label names from Ludwig train_set_metadata.""" + if not isinstance(meta_dict, dict): + return [] + + for key in ("idx2str", "idx2label", "vocab"): + seq = meta_dict.get(key) + if isinstance(seq, list) and seq: + return [str(v) for v in seq] + + str2idx = meta_dict.get("str2idx") + if isinstance(str2idx, dict) and str2idx: + int_indices = [v for v in str2idx.values() if isinstance(v, int)] + if int_indices: + max_idx = max(int_indices) + ordered = [None] * (max_idx + 1) + for name, idx in str2idx.items(): + if isinstance(idx, int) and 0 <= idx < len(ordered): + ordered[idx] = name + return [str(v) for v in ordered if v is not None] + + return [] + + +def _resolve_confusion_labels( + label_stats: dict, + n_classes: int, + metadata_csv_path: Optional[str], + train_set_metadata_path: Optional[str], +) -> List[str]: + """Prefer original labels from metadata; fall back to stats if unavailable.""" + if train_set_metadata_path: + try: + meta_path = Path(train_set_metadata_path) + if meta_path.exists(): + with open(meta_path, "r") as f: + meta_json = json.load(f) + label_meta = meta_json.get(LABEL_COLUMN_NAME) + if not isinstance(label_meta, dict): + label_meta = next( + ( + v + for v in meta_json.values() + if isinstance(v, dict) + and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab")) + ), + None, + ) + labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else [] + if labels_from_meta and len(labels_from_meta) >= n_classes: + return [str(label) for label in labels_from_meta[:n_classes]] + except Exception as exc: + print(f"Warning: Unable to read labels from train_set_metadata: {exc}") + + if metadata_csv_path: + try: + csv_path = Path(metadata_csv_path) + if csv_path.exists(): + df_meta = pd.read_csv(csv_path) + if LABEL_COLUMN_NAME in df_meta.columns: + uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist() + if uniques and len(uniques) >= n_classes: + return [str(u) for u in uniques[:n_classes]] + except Exception as exc: + print(f"Warning: Unable to read labels from metadata CSV: {exc}") + + pcs = label_stats.get("per_class_stats", {}) + if pcs: + pcs_labels = [str(k) for k in pcs.keys()] + if len(pcs_labels) >= n_classes: + return pcs_labels[:n_classes] + + labels = label_stats.get("labels") + if not labels: + labels = [str(i) for i in range(n_classes)] + if len(labels) < n_classes: + labels = labels + [str(i) for i in range(len(labels), n_classes)] + return [str(label) for label in labels[:n_classes]] def build_classification_plots( test_stats_path: str, training_stats_path: Optional[str] = None, + metadata_csv_path: Optional[str] = None, + train_set_metadata_path: Optional[str] = None, ) -> List[Dict[str, str]]: """ Read Ludwig’s test_statistics.json and build three interactive Plotly panels: @@ -21,6 +113,9 @@ - ROC-AUC - Classification Report Heatmap + If metadata paths are provided, the confusion matrix axes will use the original + label values from the training metadata rather than integer-encoded labels. + Returns a list of dicts, each with: { "title": , @@ -42,12 +137,12 @@ # 0) Confusion Matrix cm = np.array(label_stats["confusion_matrix"], dtype=int) - # Try to get actual class names from per_class_stats keys (which contain the real labels) - pcs = label_stats.get("per_class_stats", {}) - if pcs: - labels = list(pcs.keys()) - else: - labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) + labels = _resolve_confusion_labels( + label_stats, + n_classes, + metadata_csv_path=metadata_csv_path, + train_set_metadata_path=train_set_metadata_path, + ) total = cm.sum() fig_cm = go.Figure( @@ -70,6 +165,7 @@ height=side_px, margin=dict(t=100, l=80, r=80, b=80), ) + _style_fig(fig_cm) # annotate counts and percentages mval = cm.max() if cm.size else 0 @@ -110,16 +206,28 @@ ) }) - # 1) ROC-AUC Curves (Multi-class) - roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg) + # 1) ROC Curve (from test_statistics) + roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) if roc_plot: plots.append(roc_plot) + # 2) Precision-Recall Curve (from test_statistics) + pr_plot = _build_precision_recall_plot(label_stats, common_cfg) + if pr_plot: + plots.append(pr_plot) + # 2) Classification Report Heatmap pcs = label_stats.get("per_class_stats", {}) if pcs: classes = list(pcs.keys()) - metrics = ["precision", "recall", "f1_score"] + metrics = [ + "precision", + "recall", + "f1_score", + "accuracy", + "matthews_correlation_coefficient", + "specificity", + ] z, txt = [], [] for c in classes: row, trow = [], [] @@ -133,7 +241,7 @@ fig_cr = go.Figure( go.Heatmap( z=z, - x=metrics, + x=[m.replace("_", " ") for m in metrics], y=[str(c) for c in classes], text=txt, texttemplate="%{text}", @@ -143,15 +251,16 @@ ) ) fig_cr.update_layout( - title="Classification Report", + title="Per-Class metrics", xaxis_title="", yaxis_title="Class", width=side_px, height=side_px, margin=dict(t=80, l=80, r=80, b=80), ) + _style_fig(fig_cr) plots.append({ - "title": "Classification Report", + "title": "Per-Class metrics", "html": pio.to_html( fig_cr, full_html=False, @@ -160,68 +269,667 @@ ) }) + # 3) Prediction Diagnostics (from predictions.csv) + # Note: appended separately in generate_html_report, not returned here. + + return plots + + +def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]: + """Generate Train/Validation learning curve plots from training_statistics.json.""" + if not train_stats_path or not Path(train_stats_path).exists(): + return [] + try: + with open(train_stats_path, "r") as f: + train_stats = json.load(f) + except Exception as exc: + print(f"Warning: Unable to read training statistics: {exc}") + return [] + + label_train = (train_stats.get("training") or {}).get("label", {}) + label_val = (train_stats.get("validation") or {}).get("label", {}) + if not label_train and not label_val: + return [] + plots: List[Dict[str, str]] = [] + include_js = True # Load Plotly.js once for this group + + def _get_series(stats: dict, metric: str) -> List[float]: + if metric not in stats: + return [] + vals = stats.get(metric, []) + if isinstance(vals, list): + return [float(v) for v in vals] + try: + return [float(vals)] + except Exception: + return [] + + def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]: + train_series = _get_series(label_train, metric_key) + val_series = _get_series(label_val, metric_key) + if not train_series and not val_series: + return None + epochs_train = list(range(1, len(train_series) + 1)) + epochs_val = list(range(1, len(val_series) + 1)) + fig = go.Figure() + if train_series: + fig.add_trace( + go.Scatter( + x=epochs_train, + y=train_series, + mode="lines+markers", + name="Train", + line=dict(width=4), + ) + ) + if val_series: + fig.add_trace( + go.Scatter( + x=epochs_val, + y=val_series, + mode="lines+markers", + name="Validation", + line=dict(width=4), + ) + ) + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=yaxis_title, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + return { + "title": title, + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + } + + # Core learning curves + for key, title in [ + ("roc_auc", "ROC-AUC across epochs"), + ("precision", "Precision across epochs"), + ("recall", "Recall/Sensitivity across epochs"), + ("specificity", "Specificity across epochs"), + ]: + plot = _line_plot(key, title, title.replace("Learning Curve", "").strip()) + if plot: + plots.append(plot) + include_js = False + + # Precision vs Recall evolution (validation) + val_prec = _get_series(label_val, "precision") + val_rec = _get_series(label_val, "recall") + if val_prec and val_rec: + epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1)) + fig_pr = go.Figure() + fig_pr.add_trace( + go.Scatter( + x=epochs, + y=val_prec[: len(epochs)], + mode="lines+markers", + name="Precision", + ) + ) + fig_pr.add_trace( + go.Scatter( + x=epochs, + y=val_rec[: len(epochs)], + mode="lines+markers", + name="Recall", + ) + ) + fig_pr.update_layout( + title=dict(text="Validation Precision and Recall by Epoch", x=0.5), + xaxis_title="Epoch", + yaxis_title="Value", + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig_pr) + plots.append({ + "title": "Precision vs Recall Evolution", + "html": pio.to_html( + fig_pr, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + + # F1-score derived + def _compute_f1(p: List[float], r: List[float]) -> List[float]: + f1_vals = [] + for prec, rec in zip(p, r): + if (prec + rec) == 0: + f1_vals.append(0.0) + else: + f1_vals.append(2 * prec * rec / (prec + rec)) + return f1_vals + + f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) + f1_val = _compute_f1(val_prec, val_rec) + if f1_train or f1_val: + fig = go.Figure() + if f1_train: + fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4))) + if f1_val: + fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4))) + fig.update_layout( + title=dict(text="F1-Score across epochs (derived)", x=0.5), + xaxis_title="Epoch", + yaxis_title="F1-Score", + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + plots.append({ + "title": "F1-Score across epochs (derived)", + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + + # Overfitting Gap: Train vs Val ROC-AUC (gap) + roc_train = _get_series(label_train, "roc_auc") + roc_val = _get_series(label_val, "roc_auc") + if roc_train and roc_val: + epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1)) + gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])] + fig_gap = go.Figure() + fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4))) + fig_gap.update_layout( + title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5), + xaxis_title="Epoch", + yaxis_title="Gap", + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig_gap) + plots.append({ + "title": "Overfitting gap: ROC-AUC across epochs", + "html": pio.to_html( + fig_gap, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + + # Best Epoch Dashboard (based on max val ROC-AUC) + if roc_val: + best_idx = int(np.argmax(roc_val)) + best_epoch = best_idx + 1 + spec_val = _get_series(label_val, "specificity") + metrics_at_best = { + "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None, + "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None, + "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None, + "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None, + "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None, + } + fig_best = go.Figure() + for name, value in metrics_at_best.items(): + if value is not None: + fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) + fig_best.update_layout( + title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5), + xaxis_title="Metric", + yaxis_title="Value", + width=760, + height=520, + showlegend=False, + ) + _style_fig(fig_best) + plots.append({ + "title": "Best Validation Epoch Snapshot (Metrics)", + "html": pio.to_html( + fig_best, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + return plots -def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]: - """ - Build an interactive ROC-AUC curve plot for multi-class classification. - Following sklearn's ROC example with micro-average and per-class curves. +def _get_regression_series(split_stats: dict, metric: str) -> List[float]: + if metric not in split_stats: + return [] + vals = split_stats.get(metric, []) + if isinstance(vals, list): + return [float(v) for v in vals] + try: + return [float(vals)] + except Exception: + return [] + - Args: - test_stats_path: Path to test_statistics.json - class_labels: List of class label names - config: Plotly config dict +def _regression_line_plot( + train_split: dict, + val_split: dict, + metric_key: str, + title: str, + yaxis_title: str, + include_js: bool, +) -> Optional[Dict[str, str]]: + train_series = _get_regression_series(train_split, metric_key) + val_series = _get_regression_series(val_split, metric_key) + if not train_series and not val_series: + return None + epochs_train = list(range(1, len(train_series) + 1)) + epochs_val = list(range(1, len(val_series) + 1)) + fig = go.Figure() + if train_series: + fig.add_trace( + go.Scatter( + x=epochs_train, + y=train_series, + mode="lines+markers", + name="Train", + line=dict(width=4), + ) + ) + if val_series: + fig.add_trace( + go.Scatter( + x=epochs_val, + y=val_series, + mode="lines+markers", + name="Validation", + line=dict(width=4), + ) + ) + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=yaxis_title, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + return { + "title": title, + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + } + + +def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: + """Generate regression Train/Validation learning curve plots from training_statistics.json.""" + if not train_stats_path or not Path(train_stats_path).exists(): + return [] + try: + with open(train_stats_path, "r") as f: + train_stats = json.load(f) + except Exception as exc: + print(f"Warning: Unable to read training statistics: {exc}") + return [] + + label_train = (train_stats.get("training") or {}).get("label", {}) + label_val = (train_stats.get("validation") or {}).get("label", {}) + if not label_train and not label_val: + return [] + + plots: List[Dict[str, str]] = [] + include_js = True + for metric_key, title, ytitle in [ + ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"), + ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"), + ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"), + ("r2", "R² across epochs", "R²"), + ("loss", "Loss across epochs", "Loss"), + ]: + plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js) + if plot: + plots.append(plot) + include_js = False + return plots + - Returns: - Dict with title and HTML, or None if data unavailable - """ +def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]: + """Generate regression Test learning curves from training_statistics.json.""" + if not train_stats_path or not Path(train_stats_path).exists(): + return [] try: - # Get the experiment directory from test_stats_path - exp_dir = Path(test_stats_path).parent + with open(train_stats_path, "r") as f: + train_stats = json.load(f) + except Exception as exc: + print(f"Warning: Unable to read training statistics: {exc}") + return [] + + label_test = (train_stats.get("test") or {}).get("label", {}) + if not label_test: + return [] - # Load predictions with probabilities - predictions_path = exp_dir / "predictions.csv" - if not predictions_path.exists(): - return None + plots: List[Dict[str, str]] = [] + include_js = True + metrics = [ + ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"), + ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"), + ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"), + ("r2", "R² Across Epochs", "R²"), + ("loss", "Loss Across Epochs", "Loss"), + ] + epochs = None + for metric_key, title, ytitle in metrics: + series = _get_regression_series(label_test, metric_key) + if not series: + continue + if epochs is None: + epochs = list(range(1, len(series) + 1)) + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=epochs, + y=series[: len(epochs)], + mode="lines+markers", + name="Test", + line=dict(width=4), + ) + ) + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=ytitle, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + plots.append({ + "title": title, + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + return plots - df_pred = pd.read_csv(predictions_path) + +def _build_static_roc_plot( + label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None +) -> Optional[Dict[str, str]]: + """Build ROC curve directly from test_statistics.json (single curve).""" + roc_data = label_stats.get("roc_curve") + if not isinstance(roc_data, dict): + return None + + fpr = roc_data.get("false_positive_rate") + tpr = roc_data.get("true_positive_rate") + if not fpr or not tpr or len(fpr) != len(tpr): + return None + + try: + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=fpr, + y=tpr, + mode="lines+markers", + name="ROC Curve", + line=dict(color="#1f77b4", width=4), + hovertemplate="FPR: %{x:.3f}
TPR: %{y:.3f}", + ) + ) + fig.add_trace( + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Random Classifier", + line=dict(color="gray", width=2, dash="dash"), + hovertemplate="Random Classifier", + ) + ) + + auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro") + auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else "" - if SPLIT_COLUMN_NAME in df_pred.columns: - split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower() - test_mask = split_series.isin({"2", "test", "testing"}) - if test_mask.any(): - df_pred = df_pred[test_mask].reset_index(drop=True) + # Determine which label is treated as positive for the curve + label_list: List = [] + pcs = label_stats.get("per_class_stats", {}) + if pcs: + label_list = list(pcs.keys()) + if not label_list: + labels_from_stats = label_stats.get("labels") + if isinstance(labels_from_stats, list): + label_list = labels_from_stats + + # Try to resolve index of the positive label explicitly provided by Ludwig + pos_label_raw = ( + roc_data.get("positive_label") + or roc_data.get("positive_class") + or label_stats.get("positive_label") + ) + pos_label_idx = None + if pos_label_raw is not None and isinstance(label_list, list): + try: + pos_label_idx = label_list.index(pos_label_raw) + except ValueError: + pos_label_idx = None + + # Fallback: use the second label if available, otherwise the first + if pos_label_idx is None: + if isinstance(label_list, list) and len(label_list) >= 2: + pos_label_idx = 1 + elif isinstance(label_list, list) and label_list: + pos_label_idx = 0 + + if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None: + pos_label_raw = label_list[pos_label_idx] + + # Map to friendly label if we have one from metadata/CSV + pos_label_display = pos_label_raw + if ( + friendly_labels + and isinstance(pos_label_idx, int) + and 0 <= pos_label_idx < len(friendly_labels) + ): + pos_label_display = friendly_labels[pos_label_idx] + + pos_label_txt = ( + f"Positive class: {pos_label_display}" + if pos_label_display is not None + else "Positive class: (not available)" + ) + + title_label = f"ROC Curve{auc_txt}" + if pos_label_display is not None: + title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}" - if df_pred.empty: - return None + fig.update_layout( + title=dict(text=title_label, x=0.5), + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + width=700, + height=600, + margin=dict(t=80, l=80, r=80, b=110), + hovermode="closest", + legend=dict( + x=0.6, + y=0.1, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), + ) + _style_fig(fig) + fig.update_xaxes(range=[0, 1.0]) + fig.update_yaxes(range=[0, 1.05]) - # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.) - # or label_probabilities_ for string labels - prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities'] + fig.add_annotation( + x=0.5, + y=-0.15, + xref="paper", + yref="paper", + showarrow=False, + text=f"{pos_label_txt}", + xanchor="center", + ) + + return { + "title": "ROC Curve", + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs=False, + config=config, + ), + } + except Exception as e: + print(f"Error building ROC plot: {e}") + return None + + +def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]: + """Build Precision-Recall curve directly from test_statistics.json.""" + pr_data = label_stats.get("precision_recall_curve") + if not isinstance(pr_data, dict): + return None + + precisions = pr_data.get("precisions") + recalls = pr_data.get("recalls") + if not precisions or not recalls or len(precisions) != len(recalls): + return None - # Sort by class number if numeric, otherwise keep alphabetical order - if prob_cols and prob_cols[0].split('_')[-1].isdigit(): - prob_cols.sort(key=lambda x: int(x.split('_')[-1])) - else: - prob_cols.sort() # Alphabetical sort for string class names + try: + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=recalls, + y=precisions, + mode="lines+markers", + name="Precision-Recall", + line=dict(color="#d62728", width=4), + hovertemplate="Recall: %{x:.3f}
Precision: %{y:.3f}", + ) + ) + + ap_val = ( + label_stats.get("average_precision_macro") + or label_stats.get("average_precision_micro") + or label_stats.get("average_precision_samples") + ) + ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else "" + + fig.update_layout( + title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5), + xaxis_title="Recall", + yaxis_title="Precision", + width=700, + height=600, + margin=dict(t=80, l=80, r=80, b=80), + hovermode="closest", + legend=dict( + x=0.6, + y=0.1, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), + ) + _style_fig(fig) + fig.update_xaxes(range=[0, 1.0]) + fig.update_yaxes(range=[0, 1.05]) + + return { + "title": "Precision-Recall Curve", + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs=False, + config=config, + ), + } + except Exception as e: + print(f"Error building Precision-Recall plot: {e}") + return None + - if not prob_cols: - return None +def build_prediction_diagnostics( + predictions_path: str, + label_data_path: Optional[str] = None, + split_value: int = 2, + threshold: Optional[float] = None, +) -> List[Dict[str, str]]: + """Generate diagnostic plots from predictions.csv for classification tasks.""" + preds_file = Path(predictions_path) + if not preds_file.exists(): + return [] + + try: + df_pred = pd.read_csv(predictions_path) + except Exception as exc: + print(f"Warning: Unable to read predictions CSV: {exc}") + return [] + + plots: List[Dict[str, str]] = [] + + # Identify probability columns + prob_cols = [ + c for c in df_pred.columns + if c.startswith("label_probabilities_") and c != "label_probabilities" + ] + prob_cols_sorted = sorted(prob_cols) - # Get probabilities matrix (n_samples x n_classes) - y_score = df_pred[prob_cols].values - n_classes = len(prob_cols) + def _select_positive_prob(): + if not prob_cols_sorted: + return None, None + # Prefer a column indicating positive/event/true/1 + preferred_keys = ("event", "true", "positive", "pos", "1") + for col in prob_cols_sorted: + suffix = col.replace("label_probabilities_", "").lower() + if any(k in suffix for k in preferred_keys): + return col, suffix + if len(prob_cols_sorted) == 2: + col = prob_cols_sorted[1] + return col, col.replace("label_probabilities_", "") + col = prob_cols_sorted[0] + return col, col.replace("label_probabilities_", "") - y_true = None - candidate_cols = [ + pos_prob_col, pos_label_hint = _select_positive_prob() + pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None + + # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob + confidence_series = None + if "label_probability" in df_pred.columns: + confidence_series = df_pred["label_probability"] + elif pos_prob_series is not None: + confidence_series = pos_prob_series + elif prob_cols_sorted: + confidence_series = df_pred[prob_cols_sorted].max(axis=1) + + # True labels + def _extract_labels(): + candidates = [ LABEL_COLUMN_NAME, f"{LABEL_COLUMN_NAME}_ground_truth", f"{LABEL_COLUMN_NAME}__ground_truth", f"{LABEL_COLUMN_NAME}_target", f"{LABEL_COLUMN_NAME}__target", + "label", + "label_true", ] - candidate_cols.extend( + candidates.extend( [ col for col in df_pred.columns @@ -230,174 +938,182 @@ and "predictions" not in col ] ) - for col in candidate_cols: - if col in df_pred.columns and col not in prob_cols: - y_true = df_pred[col].values - break + for col in candidates: + if col in df_pred.columns and col not in prob_cols_sorted: + return df_pred[col] + if label_data_path and Path(label_data_path).exists(): + try: + df_all = pd.read_csv(label_data_path) + if SPLIT_COLUMN_NAME in df_all.columns: + df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) + if LABEL_COLUMN_NAME in df_all.columns: + return df_all[LABEL_COLUMN_NAME].reset_index(drop=True) + except Exception as exc: + print(f"Warning: Unable to load labels from dataset: {exc}") + return None - if y_true is None: - desc_path = exp_dir / "description.json" - if desc_path.exists(): - try: - with open(desc_path, 'r') as f: - desc = json.load(f) - dataset_path = desc.get('dataset', '') - if dataset_path and Path(dataset_path).exists(): - df_orig = pd.read_csv(dataset_path) - if SPLIT_COLUMN_NAME in df_orig.columns: - df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True) - if LABEL_COLUMN_NAME in df_orig.columns: - y_true = df_orig[LABEL_COLUMN_NAME].values - if len(y_true) != len(df_pred): - print( - f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot." - ) - y_true = y_true[:len(df_pred)] - else: - print("Warning: Original dataset referenced in description.json is unavailable.") - except Exception as exc: # pragma: no cover - defensive - print(f"Warning: Failed to recover labels from dataset: {exc}") - - if y_true is None or len(y_true) == 0: - print("Warning: Unable to locate ground-truth labels for ROC plot.") - return None - - if len(y_true) != len(y_score): - limit = min(len(y_true), len(y_score)) - if limit == 0: - return None - print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.") - y_true = y_true[:limit] - y_score = y_score[:limit] + labels_series = _extract_labels() - # Get actual class names from probability column names - actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols] - display_classes = class_labels if len(class_labels) == n_classes else actual_classes - - # Binarize the output following sklearn example - # Use actual class names if they're strings, otherwise use range - if isinstance(y_true[0], str): - y_test = label_binarize(y_true, classes=actual_classes) - else: - y_test = label_binarize(y_true, classes=list(range(n_classes))) - - # Handle binary classification case - if y_test.ndim != 2: - y_test = np.atleast_2d(y_test) + # Plot 1: Confidence Histogram + if confidence_series is not None: + fig_conf = go.Figure() + fig_conf.add_trace( + go.Histogram( + x=confidence_series, + nbinsx=20, + marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)), + opacity=0.8, + histnorm="percent", + ) + ) + fig_conf.update_layout( + title=dict(text="Prediction Confidence Distribution", x=0.5), + xaxis_title="Predicted probability (confidence)", + yaxis_title="Percentage (%)", + bargap=0.05, + width=700, + height=500, + ) + _style_fig(fig_conf) + plots.append({ + "title": "Prediction Confidence Distribution", + "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False), + }) - if n_classes == 2: - if y_test.shape[1] == 1: - y_test = np.hstack([1 - y_test, y_test]) - elif y_test.shape[1] != 2: - print("Warning: Unexpected label binarization shape for binary ROC plot.") - return None - elif y_test.shape[1] != n_classes: - print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.") - return None + # The remaining plots require true labels and a positive-class probability + if labels_series is None or pos_prob_series is None: + return plots + + # Align lengths + min_len = min(len(labels_series), len(pos_prob_series)) + if min_len == 0: + return plots + y_true_raw = labels_series.iloc[:min_len] + y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float) - # Compute ROC curve and ROC area for each class (following sklearn example) - fpr = dict() - tpr = dict() - roc_auc = dict() + # Determine positive label + unique_labels = pd.unique(y_true_raw) + unique_labels_list = list(unique_labels) + positive_label = None + if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]: + positive_label = pos_label_hint + elif len(unique_labels_list) == 2: + positive_label = unique_labels_list[1] + else: + positive_label = unique_labels_list[0] - for i in range(n_classes): - if np.sum(y_test[:, i]) > 0: # Check if class exists in test set - fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) - roc_auc[i] = auc(fpr[i], tpr[i]) - - # Compute micro-average ROC curve and ROC area (sklearn example) - fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) - roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) - - # Create ROC curve plot - fig_roc = go.Figure() + y_true = (y_true_raw == positive_label).astype(int).values - # Colors for different classes - colors = [ - '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', - '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' - ] - - # Plot micro-average ROC curve first (most important) - fig_roc.add_trace(go.Scatter( - x=fpr["micro"], - y=tpr["micro"], - mode='lines', - name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})', - line=dict(color='deeppink', width=3, dash='dot'), - hovertemplate=('Micro-average ROC
' - 'FPR: %{x:.3f}
' - 'TPR: %{y:.3f}
' - f'AUC: {roc_auc["micro"]:.3f}') - )) - - # Plot ROC curve for each class - for i in range(n_classes): - if i in roc_auc: # Only plot if class exists in test set - class_name = display_classes[i] if i < len(display_classes) else f"Class {i}" - color = colors[i % len(colors)] - - fig_roc.add_trace(go.Scatter( - x=fpr[i], - y=tpr[i], - mode='lines', - name=f'{class_name} (AUC = {roc_auc[i]:.3f})', - line=dict(color=color, width=2), - hovertemplate=(f'{class_name}
' - 'FPR: %{x:.3f}
' - 'TPR: %{y:.3f}
' - f'AUC: {roc_auc[i]:.3f}') - )) + # Plot 2: Calibration Curve + bins = np.linspace(0.0, 1.0, 11) + bin_ids = np.digitize(y_score, bins, right=True) + bin_centers = [] + frac_positives = [] + for b in range(1, len(bins)): + mask = bin_ids == b + if not np.any(mask): + continue + bin_centers.append(y_score[mask].mean()) + frac_positives.append(y_true[mask].mean()) + if bin_centers and frac_positives: + fig_cal = go.Figure() + fig_cal.add_trace( + go.Scatter( + x=bin_centers, + y=frac_positives, + mode="lines+markers", + name="Calibration", + line=dict(color="#2ca02c", width=4), + ) + ) + fig_cal.add_trace( + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Perfect Calibration", + line=dict(color="gray", width=2, dash="dash"), + ) + ) + fig_cal.update_layout( + title=dict(text="Calibration Curve", x=0.5), + xaxis_title="Predicted probability", + yaxis_title="Observed frequency", + width=700, + height=500, + ) + _style_fig(fig_cal) + plots.append({ + "title": "Calibration Curve (Predicted Probability vs Observed Frequency)", + "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False), + }) - # Add diagonal line (random classifier) - fig_roc.add_trace(go.Scatter( - x=[0, 1], - y=[0, 1], - mode='lines', - name='Random Classifier', - line=dict(color='gray', width=1, dash='dash'), - hovertemplate='Random Classifier
AUC = 0.500' - )) - - # Calculate macro-average AUC - class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc] - if class_aucs: - macro_auc = np.mean(class_aucs) - title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})" - else: - title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})" + # Plot 3: Threshold vs Metrics + thresholds = np.linspace(0.0, 1.0, 21) + accs, f1s, sens, specs = [], [], [], [] + for t in thresholds: + y_pred = (y_score >= t).astype(int) + tp = np.sum((y_true == 1) & (y_pred == 1)) + tn = np.sum((y_true == 0) & (y_pred == 0)) + fp = np.sum((y_true == 0) & (y_pred == 1)) + fn = np.sum((y_true == 1) & (y_pred == 0)) + acc = (tp + tn) / max(len(y_true), 1) + prec = tp / max(tp + fp, 1e-9) + rec = tp / max(tp + fn, 1e-9) + f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) + sensitivity = rec + specificity = tn / max(tn + fp, 1e-9) + accs.append(acc) + f1s.append(f1) + sens.append(sensitivity) + specs.append(specificity) - fig_roc.update_layout( - title=dict(text=title_text, x=0.5), - xaxis_title="False Positive Rate", - yaxis_title="True Positive Rate", - width=700, - height=600, - margin=dict(t=80, l=80, r=80, b=80), - legend=dict( - x=0.6, - y=0.1, - bgcolor="rgba(255,255,255,0.9)", - bordercolor="rgba(0,0,0,0.2)", - borderwidth=1 - ), - hovermode='closest' - ) + fig_thresh = go.Figure() + fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4))) + fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4))) + fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4))) + fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4))) + fig_thresh.update_layout( + title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5), + xaxis_title="Decision threshold", + yaxis_title="Metric value", + width=700, + height=500, + legend=dict( + x=0.7, + y=0.2, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), + shapes=[ + dict( + type="line", + x0=threshold, + x1=threshold, + y0=0, + y1=1, + xref="x", + yref="paper", + line=dict(color="#d62728", width=2, dash="dash"), + ) + ] if isinstance(threshold, (int, float)) else [], + annotations=[ + dict( + x=threshold, + y=1.02, + xref="x", + yref="paper", + showarrow=False, + text=f"Threshold = {threshold:.2f}", + font=dict(size=11, color="#d62728"), + ) + ] if isinstance(threshold, (int, float)) else [], + ) + _style_fig(fig_thresh) + plots.append({ + "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", + "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False), + }) - # Set equal aspect ratio and proper range - fig_roc.update_xaxes(range=[0, 1.0]) - fig_roc.update_yaxes(range=[0, 1.05]) - - return { - "title": "ROC-AUC Curves", - "html": pio.to_html( - fig_roc, - full_html=False, - include_plotlyjs=False, - config=config - ) - } - - except Exception as e: - print(f"Error building ROC-AUC plot: {e}") - return None + return plots