Mercurial > repos > goeckslab > image_learner
diff plotly_plots.py @ 15:d17e3a1b8659 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 15:45:49 +0000 |
| parents | c5150cceab47 |
| children |
line wrap: on
line diff
--- a/plotly_plots.py Wed Nov 26 22:00:32 2025 +0000 +++ b/plotly_plots.py Fri Nov 28 15:45:49 2025 +0000 @@ -7,13 +7,105 @@ import plotly.graph_objects as go import plotly.io as pio from constants import LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME -from sklearn.metrics import auc, roc_curve -from sklearn.preprocessing import label_binarize + + +def _style_fig(fig: go.Figure, font_size: int = 12) -> go.Figure: + """Apply consistent styling across Plotly figures.""" + fig.update_layout( + font=dict(size=font_size), + plot_bgcolor="#ffffff", + paper_bgcolor="#ffffff", + ) + fig.update_xaxes(gridcolor="#e8e8e8") + fig.update_yaxes(gridcolor="#e8e8e8") + return fig + + +def _labels_from_metadata_dict(meta_dict: dict) -> List[str]: + """Extract ordered label names from Ludwig train_set_metadata.""" + if not isinstance(meta_dict, dict): + return [] + + for key in ("idx2str", "idx2label", "vocab"): + seq = meta_dict.get(key) + if isinstance(seq, list) and seq: + return [str(v) for v in seq] + + str2idx = meta_dict.get("str2idx") + if isinstance(str2idx, dict) and str2idx: + int_indices = [v for v in str2idx.values() if isinstance(v, int)] + if int_indices: + max_idx = max(int_indices) + ordered = [None] * (max_idx + 1) + for name, idx in str2idx.items(): + if isinstance(idx, int) and 0 <= idx < len(ordered): + ordered[idx] = name + return [str(v) for v in ordered if v is not None] + + return [] + + +def _resolve_confusion_labels( + label_stats: dict, + n_classes: int, + metadata_csv_path: Optional[str], + train_set_metadata_path: Optional[str], +) -> List[str]: + """Prefer original labels from metadata; fall back to stats if unavailable.""" + if train_set_metadata_path: + try: + meta_path = Path(train_set_metadata_path) + if meta_path.exists(): + with open(meta_path, "r") as f: + meta_json = json.load(f) + label_meta = meta_json.get(LABEL_COLUMN_NAME) + if not isinstance(label_meta, dict): + label_meta = next( + ( + v + for v in meta_json.values() + if isinstance(v, dict) + and any(k in v for k in ("idx2str", "str2idx", "idx2label", "vocab")) + ), + None, + ) + labels_from_meta = _labels_from_metadata_dict(label_meta) if label_meta else [] + if labels_from_meta and len(labels_from_meta) >= n_classes: + return [str(label) for label in labels_from_meta[:n_classes]] + except Exception as exc: + print(f"Warning: Unable to read labels from train_set_metadata: {exc}") + + if metadata_csv_path: + try: + csv_path = Path(metadata_csv_path) + if csv_path.exists(): + df_meta = pd.read_csv(csv_path) + if LABEL_COLUMN_NAME in df_meta.columns: + uniques = df_meta[LABEL_COLUMN_NAME].dropna().unique().tolist() + if uniques and len(uniques) >= n_classes: + return [str(u) for u in uniques[:n_classes]] + except Exception as exc: + print(f"Warning: Unable to read labels from metadata CSV: {exc}") + + pcs = label_stats.get("per_class_stats", {}) + if pcs: + pcs_labels = [str(k) for k in pcs.keys()] + if len(pcs_labels) >= n_classes: + return pcs_labels[:n_classes] + + labels = label_stats.get("labels") + if not labels: + labels = [str(i) for i in range(n_classes)] + if len(labels) < n_classes: + labels = labels + [str(i) for i in range(len(labels), n_classes)] + return [str(label) for label in labels[:n_classes]] def build_classification_plots( test_stats_path: str, training_stats_path: Optional[str] = None, + metadata_csv_path: Optional[str] = None, + train_set_metadata_path: Optional[str] = None, ) -> List[Dict[str, str]]: """ Read Ludwig’s test_statistics.json and build three interactive Plotly panels: @@ -21,6 +113,9 @@ - ROC-AUC - Classification Report Heatmap + If metadata paths are provided, the confusion matrix axes will use the original + label values from the training metadata rather than integer-encoded labels. + Returns a list of dicts, each with: { "title": <plot title>, @@ -42,12 +137,12 @@ # 0) Confusion Matrix cm = np.array(label_stats["confusion_matrix"], dtype=int) - # Try to get actual class names from per_class_stats keys (which contain the real labels) - pcs = label_stats.get("per_class_stats", {}) - if pcs: - labels = list(pcs.keys()) - else: - labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])]) + labels = _resolve_confusion_labels( + label_stats, + n_classes, + metadata_csv_path=metadata_csv_path, + train_set_metadata_path=train_set_metadata_path, + ) total = cm.sum() fig_cm = go.Figure( @@ -70,6 +165,7 @@ height=side_px, margin=dict(t=100, l=80, r=80, b=80), ) + _style_fig(fig_cm) # annotate counts and percentages mval = cm.max() if cm.size else 0 @@ -110,16 +206,28 @@ ) }) - # 1) ROC-AUC Curves (Multi-class) - roc_plot = _build_roc_auc_plot(test_stats_path, labels, common_cfg) + # 1) ROC Curve (from test_statistics) + roc_plot = _build_static_roc_plot(label_stats, common_cfg, friendly_labels=labels) if roc_plot: plots.append(roc_plot) + # 2) Precision-Recall Curve (from test_statistics) + pr_plot = _build_precision_recall_plot(label_stats, common_cfg) + if pr_plot: + plots.append(pr_plot) + # 2) Classification Report Heatmap pcs = label_stats.get("per_class_stats", {}) if pcs: classes = list(pcs.keys()) - metrics = ["precision", "recall", "f1_score"] + metrics = [ + "precision", + "recall", + "f1_score", + "accuracy", + "matthews_correlation_coefficient", + "specificity", + ] z, txt = [], [] for c in classes: row, trow = [], [] @@ -133,7 +241,7 @@ fig_cr = go.Figure( go.Heatmap( z=z, - x=metrics, + x=[m.replace("_", " ") for m in metrics], y=[str(c) for c in classes], text=txt, texttemplate="%{text}", @@ -143,15 +251,16 @@ ) ) fig_cr.update_layout( - title="Classification Report", + title="Per-Class metrics", xaxis_title="", yaxis_title="Class", width=side_px, height=side_px, margin=dict(t=80, l=80, r=80, b=80), ) + _style_fig(fig_cr) plots.append({ - "title": "Classification Report", + "title": "Per-Class metrics", "html": pio.to_html( fig_cr, full_html=False, @@ -160,68 +269,667 @@ ) }) + # 3) Prediction Diagnostics (from predictions.csv) + # Note: appended separately in generate_html_report, not returned here. + + return plots + + +def build_train_validation_plots(train_stats_path: str) -> List[Dict[str, str]]: + """Generate Train/Validation learning curve plots from training_statistics.json.""" + if not train_stats_path or not Path(train_stats_path).exists(): + return [] + try: + with open(train_stats_path, "r") as f: + train_stats = json.load(f) + except Exception as exc: + print(f"Warning: Unable to read training statistics: {exc}") + return [] + + label_train = (train_stats.get("training") or {}).get("label", {}) + label_val = (train_stats.get("validation") or {}).get("label", {}) + if not label_train and not label_val: + return [] + plots: List[Dict[str, str]] = [] + include_js = True # Load Plotly.js once for this group + + def _get_series(stats: dict, metric: str) -> List[float]: + if metric not in stats: + return [] + vals = stats.get(metric, []) + if isinstance(vals, list): + return [float(v) for v in vals] + try: + return [float(vals)] + except Exception: + return [] + + def _line_plot(metric_key: str, title: str, yaxis_title: str) -> Optional[Dict[str, str]]: + train_series = _get_series(label_train, metric_key) + val_series = _get_series(label_val, metric_key) + if not train_series and not val_series: + return None + epochs_train = list(range(1, len(train_series) + 1)) + epochs_val = list(range(1, len(val_series) + 1)) + fig = go.Figure() + if train_series: + fig.add_trace( + go.Scatter( + x=epochs_train, + y=train_series, + mode="lines+markers", + name="Train", + line=dict(width=4), + ) + ) + if val_series: + fig.add_trace( + go.Scatter( + x=epochs_val, + y=val_series, + mode="lines+markers", + name="Validation", + line=dict(width=4), + ) + ) + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=yaxis_title, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + return { + "title": title, + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + } + + # Core learning curves + for key, title in [ + ("roc_auc", "ROC-AUC across epochs"), + ("precision", "Precision across epochs"), + ("recall", "Recall/Sensitivity across epochs"), + ("specificity", "Specificity across epochs"), + ]: + plot = _line_plot(key, title, title.replace("Learning Curve", "").strip()) + if plot: + plots.append(plot) + include_js = False + + # Precision vs Recall evolution (validation) + val_prec = _get_series(label_val, "precision") + val_rec = _get_series(label_val, "recall") + if val_prec and val_rec: + epochs = list(range(1, min(len(val_prec), len(val_rec)) + 1)) + fig_pr = go.Figure() + fig_pr.add_trace( + go.Scatter( + x=epochs, + y=val_prec[: len(epochs)], + mode="lines+markers", + name="Precision", + ) + ) + fig_pr.add_trace( + go.Scatter( + x=epochs, + y=val_rec[: len(epochs)], + mode="lines+markers", + name="Recall", + ) + ) + fig_pr.update_layout( + title=dict(text="Validation Precision and Recall by Epoch", x=0.5), + xaxis_title="Epoch", + yaxis_title="Value", + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig_pr) + plots.append({ + "title": "Precision vs Recall Evolution", + "html": pio.to_html( + fig_pr, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + + # F1-score derived + def _compute_f1(p: List[float], r: List[float]) -> List[float]: + f1_vals = [] + for prec, rec in zip(p, r): + if (prec + rec) == 0: + f1_vals.append(0.0) + else: + f1_vals.append(2 * prec * rec / (prec + rec)) + return f1_vals + + f1_train = _compute_f1(_get_series(label_train, "precision"), _get_series(label_train, "recall")) + f1_val = _compute_f1(val_prec, val_rec) + if f1_train or f1_val: + fig = go.Figure() + if f1_train: + fig.add_trace(go.Scatter(x=list(range(1, len(f1_train) + 1)), y=f1_train, mode="lines+markers", name="Train", line=dict(width=4))) + if f1_val: + fig.add_trace(go.Scatter(x=list(range(1, len(f1_val) + 1)), y=f1_val, mode="lines+markers", name="Validation", line=dict(width=4))) + fig.update_layout( + title=dict(text="F1-Score across epochs (derived)", x=0.5), + xaxis_title="Epoch", + yaxis_title="F1-Score", + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + plots.append({ + "title": "F1-Score across epochs (derived)", + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + + # Overfitting Gap: Train vs Val ROC-AUC (gap) + roc_train = _get_series(label_train, "roc_auc") + roc_val = _get_series(label_val, "roc_auc") + if roc_train and roc_val: + epochs_gap = list(range(1, min(len(roc_train), len(roc_val)) + 1)) + gaps = [t - v for t, v in zip(roc_train[:len(epochs_gap)], roc_val[:len(epochs_gap)])] + fig_gap = go.Figure() + fig_gap.add_trace(go.Scatter(x=epochs_gap, y=gaps, mode="lines+markers", name="Train - Val ROC-AUC", line=dict(width=4))) + fig_gap.update_layout( + title=dict(text="Overfitting gap: ROC-AUC across epochs", x=0.5), + xaxis_title="Epoch", + yaxis_title="Gap", + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig_gap) + plots.append({ + "title": "Overfitting gap: ROC-AUC across epochs", + "html": pio.to_html( + fig_gap, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + + # Best Epoch Dashboard (based on max val ROC-AUC) + if roc_val: + best_idx = int(np.argmax(roc_val)) + best_epoch = best_idx + 1 + spec_val = _get_series(label_val, "specificity") + metrics_at_best = { + "ROC-AUC": roc_val[best_idx] if best_idx < len(roc_val) else None, + "Precision": val_prec[best_idx] if best_idx < len(val_prec) else None, + "Recall": val_rec[best_idx] if best_idx < len(val_rec) else None, + "Specificity": spec_val[best_idx] if best_idx < len(spec_val) else None, + "F1-Score": f1_val[best_idx] if best_idx < len(f1_val) else None, + } + fig_best = go.Figure() + for name, value in metrics_at_best.items(): + if value is not None: + fig_best.add_trace(go.Bar(name=name, x=[name], y=[value])) + fig_best.update_layout( + title=dict(text=f"Best Epoch Dashboard (Val ROC-AUC @ epoch {best_epoch})", x=0.5), + xaxis_title="Metric", + yaxis_title="Value", + width=760, + height=520, + showlegend=False, + ) + _style_fig(fig_best) + plots.append({ + "title": "Best Validation Epoch Snapshot (Metrics)", + "html": pio.to_html( + fig_best, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + return plots -def _build_roc_auc_plot(test_stats_path: str, class_labels: List[str], config: dict) -> Optional[Dict[str, str]]: - """ - Build an interactive ROC-AUC curve plot for multi-class classification. - Following sklearn's ROC example with micro-average and per-class curves. +def _get_regression_series(split_stats: dict, metric: str) -> List[float]: + if metric not in split_stats: + return [] + vals = split_stats.get(metric, []) + if isinstance(vals, list): + return [float(v) for v in vals] + try: + return [float(vals)] + except Exception: + return [] + - Args: - test_stats_path: Path to test_statistics.json - class_labels: List of class label names - config: Plotly config dict +def _regression_line_plot( + train_split: dict, + val_split: dict, + metric_key: str, + title: str, + yaxis_title: str, + include_js: bool, +) -> Optional[Dict[str, str]]: + train_series = _get_regression_series(train_split, metric_key) + val_series = _get_regression_series(val_split, metric_key) + if not train_series and not val_series: + return None + epochs_train = list(range(1, len(train_series) + 1)) + epochs_val = list(range(1, len(val_series) + 1)) + fig = go.Figure() + if train_series: + fig.add_trace( + go.Scatter( + x=epochs_train, + y=train_series, + mode="lines+markers", + name="Train", + line=dict(width=4), + ) + ) + if val_series: + fig.add_trace( + go.Scatter( + x=epochs_val, + y=val_series, + mode="lines+markers", + name="Validation", + line=dict(width=4), + ) + ) + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=yaxis_title, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + return { + "title": title, + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + } + + +def build_regression_train_val_plots(train_stats_path: str) -> List[Dict[str, str]]: + """Generate regression Train/Validation learning curve plots from training_statistics.json.""" + if not train_stats_path or not Path(train_stats_path).exists(): + return [] + try: + with open(train_stats_path, "r") as f: + train_stats = json.load(f) + except Exception as exc: + print(f"Warning: Unable to read training statistics: {exc}") + return [] + + label_train = (train_stats.get("training") or {}).get("label", {}) + label_val = (train_stats.get("validation") or {}).get("label", {}) + if not label_train and not label_val: + return [] + + plots: List[Dict[str, str]] = [] + include_js = True + for metric_key, title, ytitle in [ + ("mean_absolute_error", "Mean Absolute Error across epochs", "MAE"), + ("root_mean_squared_error", "Root Mean Squared Error across epochs", "RMSE"), + ("mean_absolute_percentage_error", "Mean Absolute Percentage Error across epochs", "MAPE"), + ("r2", "R² across epochs", "R²"), + ("loss", "Loss across epochs", "Loss"), + ]: + plot = _regression_line_plot(label_train, label_val, metric_key, title, ytitle, include_js) + if plot: + plots.append(plot) + include_js = False + return plots + - Returns: - Dict with title and HTML, or None if data unavailable - """ +def build_regression_test_plots(train_stats_path: str) -> List[Dict[str, str]]: + """Generate regression Test learning curves from training_statistics.json.""" + if not train_stats_path or not Path(train_stats_path).exists(): + return [] try: - # Get the experiment directory from test_stats_path - exp_dir = Path(test_stats_path).parent + with open(train_stats_path, "r") as f: + train_stats = json.load(f) + except Exception as exc: + print(f"Warning: Unable to read training statistics: {exc}") + return [] + + label_test = (train_stats.get("test") or {}).get("label", {}) + if not label_test: + return [] - # Load predictions with probabilities - predictions_path = exp_dir / "predictions.csv" - if not predictions_path.exists(): - return None + plots: List[Dict[str, str]] = [] + include_js = True + metrics = [ + ("mean_absolute_error", "Mean Absolute Error Across Epochs", "MAE"), + ("root_mean_squared_error", "Root Mean Squared Error Across Epochs", "RMSE"), + ("mean_absolute_percentage_error", "Mean Absolute Percentage Error Across Epochs", "MAPE"), + ("r2", "R² Across Epochs", "R²"), + ("loss", "Loss Across Epochs", "Loss"), + ] + epochs = None + for metric_key, title, ytitle in metrics: + series = _get_regression_series(label_test, metric_key) + if not series: + continue + if epochs is None: + epochs = list(range(1, len(series) + 1)) + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=epochs, + y=series[: len(epochs)], + mode="lines+markers", + name="Test", + line=dict(width=4), + ) + ) + fig.update_layout( + title=dict(text=title, x=0.5), + xaxis_title="Epoch", + yaxis_title=ytitle, + width=760, + height=520, + hovermode="x unified", + ) + _style_fig(fig) + plots.append({ + "title": title, + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs="cdn" if include_js else False, + ), + }) + include_js = False + return plots - df_pred = pd.read_csv(predictions_path) + +def _build_static_roc_plot( + label_stats: dict, config: dict, friendly_labels: Optional[List[str]] = None +) -> Optional[Dict[str, str]]: + """Build ROC curve directly from test_statistics.json (single curve).""" + roc_data = label_stats.get("roc_curve") + if not isinstance(roc_data, dict): + return None + + fpr = roc_data.get("false_positive_rate") + tpr = roc_data.get("true_positive_rate") + if not fpr or not tpr or len(fpr) != len(tpr): + return None + + try: + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=fpr, + y=tpr, + mode="lines+markers", + name="ROC Curve", + line=dict(color="#1f77b4", width=4), + hovertemplate="FPR: %{x:.3f}<br>TPR: %{y:.3f}<extra></extra>", + ) + ) + fig.add_trace( + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Random Classifier", + line=dict(color="gray", width=2, dash="dash"), + hovertemplate="Random Classifier<extra></extra>", + ) + ) + + auc_val = label_stats.get("roc_auc") or label_stats.get("roc_auc_macro") or label_stats.get("roc_auc_micro") + auc_txt = f" (AUC = {auc_val:.3f})" if isinstance(auc_val, (int, float)) else "" - if SPLIT_COLUMN_NAME in df_pred.columns: - split_series = df_pred[SPLIT_COLUMN_NAME].astype(str).str.lower() - test_mask = split_series.isin({"2", "test", "testing"}) - if test_mask.any(): - df_pred = df_pred[test_mask].reset_index(drop=True) + # Determine which label is treated as positive for the curve + label_list: List = [] + pcs = label_stats.get("per_class_stats", {}) + if pcs: + label_list = list(pcs.keys()) + if not label_list: + labels_from_stats = label_stats.get("labels") + if isinstance(labels_from_stats, list): + label_list = labels_from_stats + + # Try to resolve index of the positive label explicitly provided by Ludwig + pos_label_raw = ( + roc_data.get("positive_label") + or roc_data.get("positive_class") + or label_stats.get("positive_label") + ) + pos_label_idx = None + if pos_label_raw is not None and isinstance(label_list, list): + try: + pos_label_idx = label_list.index(pos_label_raw) + except ValueError: + pos_label_idx = None + + # Fallback: use the second label if available, otherwise the first + if pos_label_idx is None: + if isinstance(label_list, list) and len(label_list) >= 2: + pos_label_idx = 1 + elif isinstance(label_list, list) and label_list: + pos_label_idx = 0 + + if pos_label_raw is None and isinstance(label_list, list) and pos_label_idx is not None: + pos_label_raw = label_list[pos_label_idx] + + # Map to friendly label if we have one from metadata/CSV + pos_label_display = pos_label_raw + if ( + friendly_labels + and isinstance(pos_label_idx, int) + and 0 <= pos_label_idx < len(friendly_labels) + ): + pos_label_display = friendly_labels[pos_label_idx] + + pos_label_txt = ( + f"Positive class: {pos_label_display}" + if pos_label_display is not None + else "Positive class: (not available)" + ) + + title_label = f"ROC Curve{auc_txt}" + if pos_label_display is not None: + title_label = f"ROC Curve (Positive Class: {pos_label_display}){auc_txt}" - if df_pred.empty: - return None + fig.update_layout( + title=dict(text=title_label, x=0.5), + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + width=700, + height=600, + margin=dict(t=80, l=80, r=80, b=110), + hovermode="closest", + legend=dict( + x=0.6, + y=0.1, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), + ) + _style_fig(fig) + fig.update_xaxes(range=[0, 1.0]) + fig.update_yaxes(range=[0, 1.05]) - # Extract probability columns (label_probabilities_0, label_probabilities_1, etc.) - # or label_probabilities_<class_name> for string labels - prob_cols = [col for col in df_pred.columns if col.startswith('label_probabilities_') and col != 'label_probabilities'] + fig.add_annotation( + x=0.5, + y=-0.15, + xref="paper", + yref="paper", + showarrow=False, + text=f"<span style='font-size:12px;color:#555;'>{pos_label_txt}</span>", + xanchor="center", + ) + + return { + "title": "ROC Curve", + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs=False, + config=config, + ), + } + except Exception as e: + print(f"Error building ROC plot: {e}") + return None + + +def _build_precision_recall_plot(label_stats: dict, config: dict) -> Optional[Dict[str, str]]: + """Build Precision-Recall curve directly from test_statistics.json.""" + pr_data = label_stats.get("precision_recall_curve") + if not isinstance(pr_data, dict): + return None + + precisions = pr_data.get("precisions") + recalls = pr_data.get("recalls") + if not precisions or not recalls or len(precisions) != len(recalls): + return None - # Sort by class number if numeric, otherwise keep alphabetical order - if prob_cols and prob_cols[0].split('_')[-1].isdigit(): - prob_cols.sort(key=lambda x: int(x.split('_')[-1])) - else: - prob_cols.sort() # Alphabetical sort for string class names + try: + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=recalls, + y=precisions, + mode="lines+markers", + name="Precision-Recall", + line=dict(color="#d62728", width=4), + hovertemplate="Recall: %{x:.3f}<br>Precision: %{y:.3f}<extra></extra>", + ) + ) + + ap_val = ( + label_stats.get("average_precision_macro") + or label_stats.get("average_precision_micro") + or label_stats.get("average_precision_samples") + ) + ap_txt = f" (AP = {ap_val:.3f})" if isinstance(ap_val, (int, float)) else "" + + fig.update_layout( + title=dict(text=f"Precision-Recall Curve{ap_txt}", x=0.5), + xaxis_title="Recall", + yaxis_title="Precision", + width=700, + height=600, + margin=dict(t=80, l=80, r=80, b=80), + hovermode="closest", + legend=dict( + x=0.6, + y=0.1, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), + ) + _style_fig(fig) + fig.update_xaxes(range=[0, 1.0]) + fig.update_yaxes(range=[0, 1.05]) + + return { + "title": "Precision-Recall Curve", + "html": pio.to_html( + fig, + full_html=False, + include_plotlyjs=False, + config=config, + ), + } + except Exception as e: + print(f"Error building Precision-Recall plot: {e}") + return None + - if not prob_cols: - return None +def build_prediction_diagnostics( + predictions_path: str, + label_data_path: Optional[str] = None, + split_value: int = 2, + threshold: Optional[float] = None, +) -> List[Dict[str, str]]: + """Generate diagnostic plots from predictions.csv for classification tasks.""" + preds_file = Path(predictions_path) + if not preds_file.exists(): + return [] + + try: + df_pred = pd.read_csv(predictions_path) + except Exception as exc: + print(f"Warning: Unable to read predictions CSV: {exc}") + return [] + + plots: List[Dict[str, str]] = [] + + # Identify probability columns + prob_cols = [ + c for c in df_pred.columns + if c.startswith("label_probabilities_") and c != "label_probabilities" + ] + prob_cols_sorted = sorted(prob_cols) - # Get probabilities matrix (n_samples x n_classes) - y_score = df_pred[prob_cols].values - n_classes = len(prob_cols) + def _select_positive_prob(): + if not prob_cols_sorted: + return None, None + # Prefer a column indicating positive/event/true/1 + preferred_keys = ("event", "true", "positive", "pos", "1") + for col in prob_cols_sorted: + suffix = col.replace("label_probabilities_", "").lower() + if any(k in suffix for k in preferred_keys): + return col, suffix + if len(prob_cols_sorted) == 2: + col = prob_cols_sorted[1] + return col, col.replace("label_probabilities_", "") + col = prob_cols_sorted[0] + return col, col.replace("label_probabilities_", "") - y_true = None - candidate_cols = [ + pos_prob_col, pos_label_hint = _select_positive_prob() + pos_prob_series = df_pred[pos_prob_col] if pos_prob_col and pos_prob_col in df_pred else None + + # Confidence series: prefer label_probability, otherwise positive prob, otherwise max prob + confidence_series = None + if "label_probability" in df_pred.columns: + confidence_series = df_pred["label_probability"] + elif pos_prob_series is not None: + confidence_series = pos_prob_series + elif prob_cols_sorted: + confidence_series = df_pred[prob_cols_sorted].max(axis=1) + + # True labels + def _extract_labels(): + candidates = [ LABEL_COLUMN_NAME, f"{LABEL_COLUMN_NAME}_ground_truth", f"{LABEL_COLUMN_NAME}__ground_truth", f"{LABEL_COLUMN_NAME}_target", f"{LABEL_COLUMN_NAME}__target", + "label", + "label_true", ] - candidate_cols.extend( + candidates.extend( [ col for col in df_pred.columns @@ -230,174 +938,182 @@ and "predictions" not in col ] ) - for col in candidate_cols: - if col in df_pred.columns and col not in prob_cols: - y_true = df_pred[col].values - break + for col in candidates: + if col in df_pred.columns and col not in prob_cols_sorted: + return df_pred[col] + if label_data_path and Path(label_data_path).exists(): + try: + df_all = pd.read_csv(label_data_path) + if SPLIT_COLUMN_NAME in df_all.columns: + df_all = df_all[df_all[SPLIT_COLUMN_NAME] == split_value].reset_index(drop=True) + if LABEL_COLUMN_NAME in df_all.columns: + return df_all[LABEL_COLUMN_NAME].reset_index(drop=True) + except Exception as exc: + print(f"Warning: Unable to load labels from dataset: {exc}") + return None - if y_true is None: - desc_path = exp_dir / "description.json" - if desc_path.exists(): - try: - with open(desc_path, 'r') as f: - desc = json.load(f) - dataset_path = desc.get('dataset', '') - if dataset_path and Path(dataset_path).exists(): - df_orig = pd.read_csv(dataset_path) - if SPLIT_COLUMN_NAME in df_orig.columns: - df_orig = df_orig[df_orig[SPLIT_COLUMN_NAME] == 2].reset_index(drop=True) - if LABEL_COLUMN_NAME in df_orig.columns: - y_true = df_orig[LABEL_COLUMN_NAME].values - if len(y_true) != len(df_pred): - print( - f"Warning: Test set size mismatch. Truncating to {len(df_pred)} samples for ROC plot." - ) - y_true = y_true[:len(df_pred)] - else: - print("Warning: Original dataset referenced in description.json is unavailable.") - except Exception as exc: # pragma: no cover - defensive - print(f"Warning: Failed to recover labels from dataset: {exc}") - - if y_true is None or len(y_true) == 0: - print("Warning: Unable to locate ground-truth labels for ROC plot.") - return None - - if len(y_true) != len(y_score): - limit = min(len(y_true), len(y_score)) - if limit == 0: - return None - print(f"Warning: Aligning prediction and label lengths to {limit} samples for ROC plot.") - y_true = y_true[:limit] - y_score = y_score[:limit] + labels_series = _extract_labels() - # Get actual class names from probability column names - actual_classes = [col.replace('label_probabilities_', '') for col in prob_cols] - display_classes = class_labels if len(class_labels) == n_classes else actual_classes - - # Binarize the output following sklearn example - # Use actual class names if they're strings, otherwise use range - if isinstance(y_true[0], str): - y_test = label_binarize(y_true, classes=actual_classes) - else: - y_test = label_binarize(y_true, classes=list(range(n_classes))) - - # Handle binary classification case - if y_test.ndim != 2: - y_test = np.atleast_2d(y_test) + # Plot 1: Confidence Histogram + if confidence_series is not None: + fig_conf = go.Figure() + fig_conf.add_trace( + go.Histogram( + x=confidence_series, + nbinsx=20, + marker=dict(color="#1f77b4", line=dict(color="#ffffff", width=1)), + opacity=0.8, + histnorm="percent", + ) + ) + fig_conf.update_layout( + title=dict(text="Prediction Confidence Distribution", x=0.5), + xaxis_title="Predicted probability (confidence)", + yaxis_title="Percentage (%)", + bargap=0.05, + width=700, + height=500, + ) + _style_fig(fig_conf) + plots.append({ + "title": "Prediction Confidence Distribution", + "html": pio.to_html(fig_conf, full_html=False, include_plotlyjs=False), + }) - if n_classes == 2: - if y_test.shape[1] == 1: - y_test = np.hstack([1 - y_test, y_test]) - elif y_test.shape[1] != 2: - print("Warning: Unexpected label binarization shape for binary ROC plot.") - return None - elif y_test.shape[1] != n_classes: - print("Warning: Label binarization did not produce expected class dimension; skipping ROC plot.") - return None + # The remaining plots require true labels and a positive-class probability + if labels_series is None or pos_prob_series is None: + return plots + + # Align lengths + min_len = min(len(labels_series), len(pos_prob_series)) + if min_len == 0: + return plots + y_true_raw = labels_series.iloc[:min_len] + y_score = np.array(pos_prob_series.iloc[:min_len], dtype=float) - # Compute ROC curve and ROC area for each class (following sklearn example) - fpr = dict() - tpr = dict() - roc_auc = dict() + # Determine positive label + unique_labels = pd.unique(y_true_raw) + unique_labels_list = list(unique_labels) + positive_label = None + if pos_label_hint and str(pos_label_hint) in [str(u) for u in unique_labels_list]: + positive_label = pos_label_hint + elif len(unique_labels_list) == 2: + positive_label = unique_labels_list[1] + else: + positive_label = unique_labels_list[0] - for i in range(n_classes): - if np.sum(y_test[:, i]) > 0: # Check if class exists in test set - fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) - roc_auc[i] = auc(fpr[i], tpr[i]) - - # Compute micro-average ROC curve and ROC area (sklearn example) - fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) - roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) - - # Create ROC curve plot - fig_roc = go.Figure() + y_true = (y_true_raw == positive_label).astype(int).values - # Colors for different classes - colors = [ - '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', - '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' - ] - - # Plot micro-average ROC curve first (most important) - fig_roc.add_trace(go.Scatter( - x=fpr["micro"], - y=tpr["micro"], - mode='lines', - name=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})', - line=dict(color='deeppink', width=3, dash='dot'), - hovertemplate=('<b>Micro-average ROC</b><br>' - 'FPR: %{x:.3f}<br>' - 'TPR: %{y:.3f}<br>' - f'AUC: {roc_auc["micro"]:.3f}<extra></extra>') - )) - - # Plot ROC curve for each class - for i in range(n_classes): - if i in roc_auc: # Only plot if class exists in test set - class_name = display_classes[i] if i < len(display_classes) else f"Class {i}" - color = colors[i % len(colors)] - - fig_roc.add_trace(go.Scatter( - x=fpr[i], - y=tpr[i], - mode='lines', - name=f'{class_name} (AUC = {roc_auc[i]:.3f})', - line=dict(color=color, width=2), - hovertemplate=(f'<b>{class_name}</b><br>' - 'FPR: %{x:.3f}<br>' - 'TPR: %{y:.3f}<br>' - f'AUC: {roc_auc[i]:.3f}<extra></extra>') - )) + # Plot 2: Calibration Curve + bins = np.linspace(0.0, 1.0, 11) + bin_ids = np.digitize(y_score, bins, right=True) + bin_centers = [] + frac_positives = [] + for b in range(1, len(bins)): + mask = bin_ids == b + if not np.any(mask): + continue + bin_centers.append(y_score[mask].mean()) + frac_positives.append(y_true[mask].mean()) + if bin_centers and frac_positives: + fig_cal = go.Figure() + fig_cal.add_trace( + go.Scatter( + x=bin_centers, + y=frac_positives, + mode="lines+markers", + name="Calibration", + line=dict(color="#2ca02c", width=4), + ) + ) + fig_cal.add_trace( + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Perfect Calibration", + line=dict(color="gray", width=2, dash="dash"), + ) + ) + fig_cal.update_layout( + title=dict(text="Calibration Curve", x=0.5), + xaxis_title="Predicted probability", + yaxis_title="Observed frequency", + width=700, + height=500, + ) + _style_fig(fig_cal) + plots.append({ + "title": "Calibration Curve (Predicted Probability vs Observed Frequency)", + "html": pio.to_html(fig_cal, full_html=False, include_plotlyjs=False), + }) - # Add diagonal line (random classifier) - fig_roc.add_trace(go.Scatter( - x=[0, 1], - y=[0, 1], - mode='lines', - name='Random Classifier', - line=dict(color='gray', width=1, dash='dash'), - hovertemplate='Random Classifier<br>AUC = 0.500<extra></extra>' - )) - - # Calculate macro-average AUC - class_aucs = [roc_auc[i] for i in range(n_classes) if i in roc_auc] - if class_aucs: - macro_auc = np.mean(class_aucs) - title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f}, Macro-avg = {macro_auc:.3f})" - else: - title_text = f"ROC Curves (Micro-avg = {roc_auc['micro']:.3f})" + # Plot 3: Threshold vs Metrics + thresholds = np.linspace(0.0, 1.0, 21) + accs, f1s, sens, specs = [], [], [], [] + for t in thresholds: + y_pred = (y_score >= t).astype(int) + tp = np.sum((y_true == 1) & (y_pred == 1)) + tn = np.sum((y_true == 0) & (y_pred == 0)) + fp = np.sum((y_true == 0) & (y_pred == 1)) + fn = np.sum((y_true == 1) & (y_pred == 0)) + acc = (tp + tn) / max(len(y_true), 1) + prec = tp / max(tp + fp, 1e-9) + rec = tp / max(tp + fn, 1e-9) + f1 = 0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) + sensitivity = rec + specificity = tn / max(tn + fp, 1e-9) + accs.append(acc) + f1s.append(f1) + sens.append(sensitivity) + specs.append(specificity) - fig_roc.update_layout( - title=dict(text=title_text, x=0.5), - xaxis_title="False Positive Rate", - yaxis_title="True Positive Rate", - width=700, - height=600, - margin=dict(t=80, l=80, r=80, b=80), - legend=dict( - x=0.6, - y=0.1, - bgcolor="rgba(255,255,255,0.9)", - bordercolor="rgba(0,0,0,0.2)", - borderwidth=1 - ), - hovermode='closest' - ) + fig_thresh = go.Figure() + fig_thresh.add_trace(go.Scatter(x=thresholds, y=accs, mode="lines", name="Accuracy", line=dict(width=4))) + fig_thresh.add_trace(go.Scatter(x=thresholds, y=f1s, mode="lines", name="F1", line=dict(width=4))) + fig_thresh.add_trace(go.Scatter(x=thresholds, y=sens, mode="lines", name="Sensitivity", line=dict(width=4))) + fig_thresh.add_trace(go.Scatter(x=thresholds, y=specs, mode="lines", name="Specificity", line=dict(width=4))) + fig_thresh.update_layout( + title=dict(text="Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", x=0.5), + xaxis_title="Decision threshold", + yaxis_title="Metric value", + width=700, + height=500, + legend=dict( + x=0.7, + y=0.2, + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(0,0,0,0.2)", + borderwidth=1, + ), + shapes=[ + dict( + type="line", + x0=threshold, + x1=threshold, + y0=0, + y1=1, + xref="x", + yref="paper", + line=dict(color="#d62728", width=2, dash="dash"), + ) + ] if isinstance(threshold, (int, float)) else [], + annotations=[ + dict( + x=threshold, + y=1.02, + xref="x", + yref="paper", + showarrow=False, + text=f"Threshold = {threshold:.2f}", + font=dict(size=11, color="#d62728"), + ) + ] if isinstance(threshold, (int, float)) else [], + ) + _style_fig(fig_thresh) + plots.append({ + "title": "Threshold Sweep: Accuracy, F1, Sensitivity, Specificity", + "html": pio.to_html(fig_thresh, full_html=False, include_plotlyjs=False), + }) - # Set equal aspect ratio and proper range - fig_roc.update_xaxes(range=[0, 1.0]) - fig_roc.update_yaxes(range=[0, 1.05]) - - return { - "title": "ROC-AUC Curves", - "html": pio.to_html( - fig_roc, - full_html=False, - include_plotlyjs=False, - config=config - ) - } - - except Exception as e: - print(f"Error building ROC-AUC plot: {e}") - return None + return plots
