diff plot_logic.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/plot_logic.py Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,1697 @@ +from __future__ import annotations + +import html +import os +from html import escape as _escape +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import shap +from feature_help_modal import get_metrics_help_modal +from report_utils import build_tabbed_html, get_html_closing, get_html_template +from sklearn.calibration import calibration_curve +from sklearn.metrics import ( + auc, + average_precision_score, + classification_report, + confusion_matrix, + log_loss, + precision_recall_curve, + roc_auc_score, + roc_curve, +) +from sklearn.model_selection import learning_curve as skl_learning_curve +from sklearn.preprocessing import label_binarize + +# ========================= +# Utilities +# ========================= + + +def plot_with_table_style_title(fig, title: str) -> str: + """ + Render a Plotly figure with a report-style <h2> header so it matches the + green table section headers. + """ + # kill Plotly's built-in title + fig.update_layout(title=None) + + # figure HTML without PlotlyJS (we load it once globally) + plot_html = fig.to_html(full_html=False, include_plotlyjs=False) + + # use <h2> — your CSS already styles <h2> like the table headers + return f""" +<h2>{html.escape(title)}</h2> +<div class="plotly-center">{plot_html}</div> +""".strip() + + +def _save_plotly(fig: go.Figure, path: Optional[str]) -> None: + """ + Save a Plotly figure. If `path` ends with `.html`, save interactive HTML. + If it ends with a raster extension (png/jpg/jpeg/webp), use Kaleido. + If None, do nothing (caller may choose to display in notebook).
+ """ + if not path: + return + ext = os.path.splitext(path)[1].lower() + if ext == ".html": + fig.write_html(path, include_plotlyjs="cdn", full_html=True) + else: + # Requires kaleido: pip install -U kaleido + fig.write_image(path) + + +def _save_matplotlib(path: Optional[str]) -> None: + """Save current Matplotlib figure if `path` is provided, else show().""" + if path: + plt.savefig(path, bbox_inches="tight") + plt.close() + else: + plt.show() + +# ========================= +# Classification Plots +# ========================= + + +def generate_confusion_matrix_plot( + y_true, + y_pred, + title: str = "Confusion Matrix", +) -> go.Figure: + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + + # Class order (works for strings or numbers) + labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist() + cm = confusion_matrix(y_true, y_pred, labels=labels) + max_val = cm.max() if cm.size else 0 + + # Use categorical axes by passing string labels for x/y + cats = [str(label) for label in labels] + total = int(cm.sum()) + + fig = go.Figure( + data=go.Heatmap( + z=cm, + x=cats, # categorical x + y=cats, # categorical y + colorscale="Blues", + showscale=True, + colorbar=dict(title="Count"), + xgap=2, + ygap=2, + hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>", + zmin=0 + ) + ) + + # Add annotations with count and percentage (all white text, matching sample_output.html) + annotations = [] + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + val = int(cm[i, j]) + pct = (val / total * 100) if total > 0 else 0 + text_color = "white" if max_val and val > (max_val / 2) else "black" + # Count annotation (bold, bottom) + annotations.append( + dict( + x=cats[j], + y=cats[i], + text=f"<b>{val}</b>", + showarrow=False, + font=dict(color=text_color, size=14), + xanchor="center", + yanchor="bottom", + yshift=2 + ) + ) + # Percentage annotation (top) + annotations.append( + dict( + x=cats[j], + y=cats[i], + text=f"{pct:.1f}%", + showarrow=False, + font=dict(color=text_color, size=13), + xanchor="center", + yanchor="top", + yshift=-2 + ) + ) + + fig.update_layout( + title=None, + xaxis_title="Predicted label", + yaxis_title="True label", + xaxis=dict(type="category"), + yaxis=dict(type="category", autorange="reversed"), # typical CM orientation + margin=dict(l=80, r=20, t=40, b=80), + template="plotly_white", + plot_bgcolor="white", + paper_bgcolor="white", + annotations=annotations + ) + return fig + + +def generate_roc_curve_plot( + y_true_bin: np.ndarray, + y_score: np.ndarray, + title: str = "ROC Curve", + marker_threshold: float | None = None, +) -> go.Figure: + y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) + y_score = np.asarray(y_score).astype(float).reshape(-1) + + fpr, tpr, thr = roc_curve(y_true_bin, y_score) + roc_auc = auc(fpr, tpr) + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=fpr, y=tpr, mode="lines", + name=f"ROC (AUC={roc_auc:.3f})", + line=dict(width=3) + )) + + # 45° chance line (no legend to keep it clean) + fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", + line=dict(dash="dash", width=2, color="#888"), showlegend=False)) + + # Optional marker at the user threshold + if marker_threshold is not None and len(thr): + # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0 + finite = np.isfinite(thr) + if np.any(finite): + idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold)))) + idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices + 
x_m, y_m = float(fpr[idx]), float(tpr[idx]) + + fig.add_trace( + go.Scatter( + x=[x_m], y=[y_m], + mode="markers", + name=f"@ {float(marker_threshold):.2f}", + marker=dict(size=12, color="red", symbol="x") + ) + ) + fig.add_annotation( + x=0.02, y=0.98, xref="paper", yref="paper", + text=f"threshold = {float(marker_threshold):.2f}", + showarrow=False, + font=dict(color="black", size=12), + align="left" + ) + + fig.update_layout( + title=None, + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + template="plotly_white", + legend=dict(x=1, y=0, xanchor="right"), + margin=dict(l=60, r=20, t=60, b=60), + ) + return fig + + +def generate_pr_curve_plot( + y_true_bin: np.ndarray, + y_score: np.ndarray, + title: str = "Precision–Recall Curve", + marker_threshold: float | None = None, +) -> go.Figure: + y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) + y_score = np.asarray(y_score).astype(float).reshape(-1) + + precision, recall, thr = precision_recall_curve(y_true_bin, y_score) + pr_auc = auc(recall, precision) + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=recall, y=precision, mode="lines", + name=f"PR (AUC={pr_auc:.3f})", + line=dict(width=3) + )) + + # Optional marker at the user threshold + if marker_threshold is not None and len(thr): + # In PR, thresholds has length len(precision)-1. The point for thr[j] is (recall[j+1], precision[j+1]). + j = int(np.argmin(np.abs(thr - float(marker_threshold)))) + j = int(np.clip(j, 0, len(thr) - 1)) + x_m, y_m = float(recall[j + 1]), float(precision[j + 1]) + + fig.add_trace( + go.Scatter( + x=[x_m], y=[y_m], + mode="markers", + name=f"@ {float(marker_threshold):.2f}", + marker=dict(size=12, color="red", symbol="x") + ) + ) + fig.add_annotation( + x=0.02, y=0.98, xref="paper", yref="paper", + text=f"threshold = {float(marker_threshold):.2f}", + showarrow=False, + font=dict(color="black", size=12), + align="left" + ) + + fig.update_layout( + title=None, + xaxis_title="Recall", + yaxis_title="Precision", + template="plotly_white", + legend=dict(x=1, y=0, xanchor="right"), + margin=dict(l=60, r=20, t=60, b=60), + ) + return fig + + +def generate_calibration_plot( + y_true_bin: np.ndarray, + y_prob: np.ndarray, + n_bins: int = 10, + title: str = "Calibration Plot", + path: Optional[str] = None, +) -> go.Figure: + """ + Binary calibration curve (Plotly). 
+ """ + prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform") + fig = go.Figure() + fig.add_trace(go.Scatter( + x=prob_pred, y=prob_true, mode="lines+markers", name="Model", + line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4") + )) + fig.add_trace( + go.Scatter( + x=[0, 1], y=[0, 1], + mode="lines", + line=dict(dash="dash", color="#808080", width=2), + name="Perfect" + ) + ) + fig.update_layout( + title=None, + xaxis_title="Predicted Probability", + yaxis_title="Observed Probability", + yaxis=dict(range=[0, 1]), + xaxis=dict(range=[0, 1]), + template="plotly_white", + margin=dict(l=60, r=40, t=50, b=50), + ) + _save_plotly(fig, path) + return fig + + +def generate_threshold_plot( + y_true_bin: np.ndarray, + y_prob: np.ndarray, + title: str = "Threshold Plot", + user_threshold: float | None = None, +) -> go.Figure: + y_true = np.asarray(y_true_bin, dtype=int).ravel() + p = np.asarray(y_prob, dtype=float).ravel() + p = np.nan_to_num(p, nan=0.0) + p = np.clip(p, 0.0, 1.0) + + def _compute_metrics(thresholds: np.ndarray): + """Vectorized-ish helper to compute precision/recall/F1/queue rate arrays.""" + prec, rec, f1, qrate = [], [], [], [] + for t in thresholds: + yhat = (p >= t).astype(int) + tp = int(((yhat == 1) & (y_true == 1)).sum()) + fp = int(((yhat == 1) & (y_true == 0)).sum()) + fn = int(((yhat == 0) & (y_true == 1)).sum()) + + pr = tp / (tp + fp) if (tp + fp) else np.nan # undefined when no predicted positives + rc = tp / (tp + fn) if (tp + fn) else 0.0 + f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0 + q = float(yhat.mean()) + + prec.append(pr) + rec.append(rc) + f1.append(f) + qrate.append(q) + return np.asarray(prec, dtype=float), np.asarray(rec, dtype=float), np.asarray(f1, dtype=float), np.asarray(qrate, dtype=float) + + # Use uniform threshold grid for plotting (0 to 1 in steps of 0.01) + th = np.linspace(0.0, 1.0, 101) + prec, rec, f1_arr, qrate = _compute_metrics(th) + + # Compute F1*-optimal threshold using actual score distribution (more precise than grid) + cand_th = np.unique(np.concatenate(([0.0, 1.0], p))) + # cap to a reasonable size by sampling if extremely large + if cand_th.size > 2000: + cand_th = np.linspace(0.0, 1.0, 2001) + _, _, f1_cand, _ = _compute_metrics(cand_th) + + if np.all(np.isnan(f1_cand)): + t_star = 0.5 # fallback when no valid F1 can be computed + else: + f1_max = np.nanmax(f1_cand) + best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0] + # pick the middle of the best candidates to avoid biasing toward 0 + best_idx = int(best_idxs[len(best_idxs) // 2]) + t_star = float(cand_th[best_idx]) + + # Replace NaNs for plotting (set to 0 where precision is undefined) + prec_plot = np.nan_to_num(prec, nan=0.0) + + fig = go.Figure() + + # Precision (blue line) + fig.add_trace(go.Scatter( + x=th, y=prec_plot, mode="lines", name="Precision", + line=dict(width=3, color="#1f77b4"), + hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>" + )) + + # Recall (orange line) + fig.add_trace(go.Scatter( + x=th, y=rec, mode="lines", name="Recall", + line=dict(width=3, color="#ff7f0e"), + hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>" + )) + + # F1 (green line) + fig.add_trace(go.Scatter( + x=th, y=f1_arr, mode="lines", name="F1", + line=dict(width=3, color="#2ca02c"), + hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>" + )) + + # Queue Rate (grey dashed line) + fig.add_trace(go.Scatter( + x=th, y=qrate, 
mode="lines", name="Queue Rate", + line=dict(width=2, color="#808080", dash="dash"), + hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>" + )) + + # F1*-optimal threshold marker (dashed vertical line) + fig.add_vline( + x=t_star, + line_width=2, + line_dash="dash", + line_color="black", + annotation_text=f"t* = {t_star:.2f}", + annotation_position="top" + ) + + # User threshold (solid red line) if provided + if user_threshold is not None: + fig.add_vline( + x=float(user_threshold), + line_width=2, + line_color="red", + annotation_text=f"threshold = {float(user_threshold):.2f}", + annotation_position="top" + ) + + fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict( + title="Discrimination Threshold", + range=[0, 1], + gridcolor="#e0e0e0", + showgrid=True, + zeroline=False + ), + yaxis=dict( + title="Score", + range=[0, 1], + gridcolor="#e0e0e0", + showgrid=True, + zeroline=False + ), + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1.0 + ), + margin=dict(l=60, r=20, t=40, b=50), + plot_bgcolor="white", + paper_bgcolor="white", + ) + return fig + + +def generate_per_class_metrics_plot( + y_true: Sequence, + y_pred: Sequence, + metrics: Sequence[str] = ("precision", "recall", "f1_score"), + title: str = "Classification Report", + path: Optional[str] = None, +) -> go.Figure: + """ + Per-class metrics heatmap (Plotly), similar to sklearn's classification report. + Rows = classes, columns = metrics; cell text shows the value (0–1). + """ + # Map display names -> sklearn keys + key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"} + report = classification_report( + y_true, y_pred, output_dict=True, zero_division=0 + ) + + # Order classes sensibly (numeric if possible, else lexical) + def _sort_key(x): + try: + return (0, float(x)) + except Exception: + return (1, str(x)) + + # Use all classes seen in y_true or y_pred (so rows don't jump around) + uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key) + classes = [str(c) for c in uniq] + + # Build Z matrix (rows=classes, cols=metrics) + used_metrics = [key_map.get(m, m) for m in metrics] + z = [] + for c in classes: + row = report.get(c, {}) + z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics]) + z = np.array(z, dtype=float) + + # Pretty cell labels + z_text = [[f"{v:.2f}" for v in r] for r in z] + + fig = go.Figure( + data=go.Heatmap( + z=z, + x=list(metrics), # keep display names ("precision", "recall", "f1_score") + y=classes, # classes as strings + colorscale="Reds", + zmin=0.0, + zmax=1.0, + colorbar=dict(title="Value"), + text=z_text, + texttemplate="%{text}", + hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>", + ) + ) + fig.update_layout( + title=None, + xaxis_title="", + yaxis_title="Class", + template="plotly_white", + margin=dict(l=60, r=60, t=70, b=40), + ) + + _save_plotly(fig, path) + return fig + + +def generate_multiclass_roc_curve_plot( + y_true: Sequence, + y_prob: np.ndarray, + classes: Sequence, + title: str = "Multiclass ROC Curve", + path: Optional[str] = None, +) -> go.Figure: + """ + One-vs-rest ROC curves for multiclass (Plotly). + Handles binary passed as 2-column probs as well. 
+ """ + y_true = np.asarray(y_true) + y_prob = np.asarray(y_prob) + + # Normalize to shape (n_samples, n_classes) + if y_prob.ndim == 1 or y_prob.shape[1] == 1: + y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)]) + + y_true_bin = label_binarize(y_true, classes=classes) + if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2: + y_true_bin = np.hstack([1 - y_true_bin, y_true_bin]) + + if y_prob.shape[1] != y_true_bin.shape[1]: + raise ValueError( + f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}." + ) + + fig = go.Figure() + for i, cls in enumerate(classes): + fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i]) + auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i]) + fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})")) + + fig.add_trace( + go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False) + ) + fig.update_layout( + title=None, + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + + +def generate_multiclass_pr_curve_plot( + y_true: Sequence, + y_prob: np.ndarray, + classes: Optional[Sequence] = None, + title: str = "Precision–Recall Curve", + path: Optional[str] = None, +) -> go.Figure: + """ + Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR. + """ + y_true = np.asarray(y_true) + y_prob = np.asarray(y_prob) + fig = go.Figure() + + if not classes or len(classes) == 2: + precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1]) + ap = average_precision_score(y_true, y_prob[:, 1]) + fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}")) + else: + for i, cls in enumerate(classes): + y_true_bin = (y_true == cls).astype(int) + y_prob_cls = y_prob[:, i] + precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls) + ap = average_precision_score(y_true_bin, y_prob_cls) + fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})")) + + fig.update_layout( + title=None, + xaxis_title="Recall", + yaxis_title="Precision", + yaxis=dict(range=[0, 1]), + xaxis=dict(range=[0, 1]), + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + + +def generate_metric_comparison_bar( + metrics_scores: Mapping[str, Sequence[float]], + phases: Sequence[str] = ("train", "val", "test"), + title: str = "Metric Comparison Across Phases", + path: Optional[str] = None, +) -> go.Figure: + """ + Grouped bar chart comparing metrics across phases (Plotly). + metrics_scores: {metric_name: [train, val, test]} + """ + df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"}) + df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score") + fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None) + ymax = max(1.0, df_m["Score"].max() * 1.05) + fig.update_yaxes(range=[0, ymax]) + fig.update_layout(template="plotly_white") + _save_plotly(fig, path) + return fig + +# ========================= +# Regression Plots +# ========================= + + +def generate_scatter_plot( + y_true: Sequence[float], + y_pred: Sequence[float], + title: str = "Predicted vs Actual", + path: Optional[str] = None, +) -> go.Figure: + """ + Predicted vs. Actual scatter with y=x reference (Plotly). 
+ """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + vmin = float(min(np.min(y_true), np.min(y_pred))) + vmax = float(max(np.max(y_true), np.max(y_pred))) + + fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None) + fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal")) + fig.update_layout(template="plotly_white") + _save_plotly(fig, path) + return fig + + +def generate_residual_plot( + y_true: Sequence[float], + y_pred: Sequence[float], + title: str = "Residual Plot", + path: Optional[str] = None, +) -> go.Figure: + """ + Residuals vs Predicted (Plotly). + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + residuals = y_true - y_pred + + fig = px.scatter(x=y_pred, y=residuals, opacity=0.6, + labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"}, + title=None) + fig.add_hline(y=0, line_dash="dash") + fig.update_layout(template="plotly_white") + _save_plotly(fig, path) + return fig + + +def generate_residual_histogram( + y_true: Sequence[float], + y_pred: Sequence[float], + bins: int = 30, + title: str = "Residual Histogram", + path: Optional[str] = None, +) -> go.Figure: + """ + Residuals histogram (Plotly). + """ + residuals = np.asarray(y_true) - np.asarray(y_pred) + fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None) + fig.update_layout(yaxis_title="Frequency", template="plotly_white") + _save_plotly(fig, path) + return fig + + +def generate_regression_calibration_plot( + y_true: Sequence[float], + y_pred: Sequence[float], + num_bins: int = 10, + title: str = "Regression Calibration Plot", + path: Optional[str] = None, +) -> go.Figure: + """ + Binned Actual vs Predicted means (Plotly). + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + + order = np.argsort(y_pred) + y_true_sorted = y_true[order] + y_pred_sorted = y_pred[order] + + bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins) + bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)] + bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)] + + vmin = float(min(np.min(y_pred), np.min(y_true))) + vmax = float(max(np.max(y_pred), np.max(y_true))) + + fig = go.Figure() + fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers", + name="Binned Actual vs Predicted")) + fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), + name="Ideal")) + fig.update_layout( + title=None, + xaxis_title="Mean Predicted per bin", + yaxis_title="Mean Actual per bin", + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + +# ========================= +# Confidence / Diagnostics +# ========================= + + +def plot_error_vs_confidence( + y_true: Union[Sequence[int], np.ndarray], + y_proba: Union[Sequence[float], np.ndarray], + n_bins: int = 10, + title: str = "Error vs Confidence", + path: Optional[str] = None, +) -> go.Figure: + """ + Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly. 
+ """ + y_true = np.asarray(y_true) + y_proba = np.asarray(y_proba).reshape(-1) + y_pred = (y_proba >= 0.5).astype(int) + confidence = np.maximum(y_proba, 1 - y_proba) + error = (y_pred != y_true).astype(int) + + bins = np.linspace(0.0, 1.0, n_bins + 1) + idx = np.digitize(confidence, bins, right=True) + + centers, err_rates = [], [] + for i in range(1, len(bins)): + mask = (idx == i) + if mask.any(): + centers.append(float(confidence[mask].mean())) + err_rates.append(float(error[mask].mean())) + + fig = go.Figure() + fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate")) + fig.update_layout( + title=None, + xaxis_title="Confidence (max predicted probability)", + yaxis_title="Error Rate", + yaxis=dict(range=[0, 1]), + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + + +def plot_confidence_histogram( + y_proba: np.ndarray, + bins: int = 20, + title: str = "Confidence Histogram", + path: Optional[str] = None, +) -> go.Figure: + """ + Histogram of max predicted probabilities (Plotly). + Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C). + """ + y_proba = np.asarray(y_proba) + if y_proba.ndim == 1: + confidences = np.maximum(y_proba, 1 - y_proba) + else: + confidences = np.max(y_proba, axis=1) + + fig = px.histogram( + x=confidences, + nbins=bins, + range_x=(0, 1), + histnorm="percent", + labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"}, + title=None, + ) + if fig.data: + fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>") + fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white") + _save_plotly(fig, path) + return fig + +# ========================= +# Learning Curve +# ========================= + + +def generate_learning_curve_from_predictions( + y_true, + y_pred=None, + y_proba=None, + classes=None, + metric: str = "accuracy", + train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10), + n_repeats: int = 5, + seed: int = 42, + title: str = "Learning Curve", + path: str | None = None, + return_stats: bool = False, +) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]: + rng = np.random.default_rng(seed) + y_true = np.asarray(y_true) + N = len(y_true) + + if metric == "accuracy" and y_pred is None: + raise ValueError("accuracy curve requires y_pred") + if metric == "log_loss" and y_proba is None: + raise ValueError("log_loss curve requires y_proba") + + if y_proba is not None: + y_proba = np.asarray(y_proba) + if y_pred is not None: + y_pred = np.asarray(y_pred) + + sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist() + means, stds = [], [] + for n in sizes: + vals = [] + for _ in range(n_repeats): + idx = rng.choice(N, size=n, replace=False) + if metric == "accuracy": + vals.append(float((y_true[idx] == y_pred[idx]).mean())) + else: + if y_proba.ndim == 1: + p = y_proba[idx] + pp = np.column_stack([1 - p, p]) + else: + pp = y_proba[idx] + vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes))) + means.append(np.mean(vals)) + stds.append(np.std(vals)) + + if return_stats: + return sizes, means, stds + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=sizes, y=means, mode="lines+markers", name="Train", + line=dict(width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=stds, visible=True) + )) + fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"), + 
yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=50, r=20, t=60, b=50), + ) + if path: + _save_plotly(fig, path) + return fig + + +def build_train_html_and_plots( + predictor, + problem_type: str, + df_train: pd.DataFrame, + label_column: str, + tmpdir: str, + df_val: Optional[pd.DataFrame] = None, + seed: int = 42, + perf_table_html: str | None = None, + threshold: Optional[float] = None, + section_tile: str = "Training Diagnostics", +) -> str: + y_true = df_train[label_column].values + y_true_val = df_val[label_column].values if df_val is not None else None + # predictions on TRAIN + pred_labels, pred_proba = None, None + try: + pred_labels = predictor.predict(df_train) + except Exception: + pass + try: + proba_raw = predictor.predict_proba(df_train) + pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw) + except Exception: + pred_proba = None + + # predictions on VAL (if provided) + pred_labels_val, pred_proba_val = None, None + if df_val is not None: + try: + pred_labels_val = predictor.predict(df_val) + except Exception: + pred_labels_val = None + try: + proba_raw_val = predictor.predict_proba(df_val) + pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val) + except Exception: + pred_proba_val = None + + pos_scores_train: Optional[np.ndarray] = None + pos_scores_val: Optional[np.ndarray] = None + if problem_type == "binary": + if pred_proba is not None: + pos_scores_train = ( + pred_proba.reshape(-1) + if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) + else pred_proba[:, -1] + ) + if pred_proba_val is not None: + pos_scores_val = ( + pred_proba_val.reshape(-1) + if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) + else pred_proba_val[:, -1] + ) + + # Collect plots then append in desired order + perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None + acc_plot = loss_plot = None + cm_train = pc_train = cm_val = pc_val = None + threshold_val_plot = None + roc_combined = pr_combined = cal_combined = None + mc_roc_val = None + conf_train = conf_val = None + bar_train = bar_val = None + + # 1) Learning Curve — Accuracy + if problem_type in ("binary", "multiclass"): + acc_fig = go.Figure() + added_acc = False + if pred_labels is not None: + train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( + y_true=y_true, + y_pred=np.asarray(pred_labels), + metric="accuracy", + title="Learning Curves — Label Accuracy", + seed=seed, + return_stats=True, + ) + acc_fig.add_trace(go.Scatter( + x=train_sizes, y=train_means, mode="lines+markers", name="Train", + line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=train_stds, visible=True), + )) + added_acc = True + if pred_labels_val is not None and y_true_val is not None: + val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( + y_true=y_true_val, + y_pred=np.asarray(pred_labels_val), + metric="accuracy", + title="Learning Curves — Label Accuracy", + seed=seed, + return_stats=True, + ) + acc_fig.add_trace(go.Scatter( + x=val_sizes, y=val_means, mode="lines+markers", name="Validation", + line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", 
array=val_stds, visible=True), + )) + added_acc = True + if added_acc: + acc_fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="samples", gridcolor="#eee"), + yaxis=dict(title="accuracy", gridcolor="#eee"), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=50, r=20, t=60, b=50), + ) + acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy") + + # 2) Learning Curve — Loss + if problem_type in ("binary", "multiclass"): + classes = np.unique(y_true) + loss_fig = go.Figure() + added_loss = False + if pred_proba is not None: + pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba + train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( + y_true=y_true, + y_proba=pp, + classes=classes, + metric="log_loss", + title="Learning Curves — Label Loss", + seed=seed, + return_stats=True, + ) + loss_fig.add_trace(go.Scatter( + x=train_sizes, y=train_means, mode="lines+markers", name="Train", + line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=train_stds, visible=True), + )) + added_loss = True + if pred_proba_val is not None and y_true_val is not None: + pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val + val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( + y_true=y_true_val, + y_proba=pp_val, + classes=classes, + metric="log_loss", + title="Learning Curves — Label Loss", + seed=seed, + return_stats=True, + ) + loss_fig.add_trace(go.Scatter( + x=val_sizes, y=val_means, mode="lines+markers", name="Validation", + line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=val_stds, visible=True), + )) + added_loss = True + if added_loss: + loss_fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="epoch", gridcolor="#eee"), + yaxis=dict(title="loss", gridcolor="#eee"), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=50, r=20, t=60, b=50), + ) + loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss") + + # Confusion matrices & per-class metrics + cm_train = pc_train = cm_val = pc_val = None + + # Probability diagnostics (binary) + if problem_type == "binary": + # Combined Calibration (Train/Val) + cal_fig = go.Figure() + added_cal = False + if pos_scores_train is not None: + y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) + prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform") + cal_fig.add_trace(go.Scatter( + x=prob_pred, y=prob_true, mode="lines+markers", + name="Train", + line=dict(color="#1f77b4", width=3), + marker=dict(size=7, color="#1f77b4"), + )) + added_cal = True + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform") + cal_fig.add_trace(go.Scatter( + x=prob_pred_v, y=prob_true_v, mode="lines+markers", + name="Validation", + line=dict(color="#ff7f0e", width=3), + marker=dict(size=7, color="#ff7f0e"), + )) + added_cal = True + if added_cal: + cal_fig.add_trace(go.Scatter( + x=[0, 1], y=[0, 1], + mode="lines", + line=dict(dash="dash", 
color="#808080", width=2), + name="Perfect", + showlegend=True, + )) + cal_fig.update_layout( + title=None, + xaxis_title="Predicted Probability", + yaxis_title="Observed Probability", + xaxis=dict(range=[0, 1]), + yaxis=dict(range=[0, 1]), + template="plotly_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=60, r=40, t=50, b=50), + ) + cal_combined = plot_with_table_style_title(cal_fig, "Calibration Curve (Train vs Validation)") + + # Combined ROC (Train/Val) + roc_fig = go.Figure() + added_roc = False + if pos_scores_train is not None: + y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) + fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train) + roc_fig.add_trace(go.Scatter( + x=fpr_tr, y=tpr_tr, mode="lines", + name="Train", + line=dict(color="#1f77b4", width=3), + )) + if threshold is not None and np.isfinite(thr_tr).any(): + finite = np.isfinite(thr_tr) + idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold)))) + idx = int(np.nonzero(finite)[0][idx_local]) + roc_fig.add_trace(go.Scatter( + x=[fpr_tr[idx]], y=[tpr_tr[idx]], + mode="markers", + name="Train @ threshold", + marker=dict(size=12, color="#1f77b4", symbol="x") + )) + added_roc = True + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val) + roc_fig.add_trace(go.Scatter( + x=fpr_v, y=tpr_v, mode="lines", + name="Validation", + line=dict(color="#ff7f0e", width=3), + )) + if threshold is not None and np.isfinite(thr_v).any(): + finite = np.isfinite(thr_v) + idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold)))) + idx = int(np.nonzero(finite)[0][idx_local]) + roc_fig.add_trace(go.Scatter( + x=[fpr_v[idx]], y=[tpr_v[idx]], + mode="markers", + name="Val @ threshold", + marker=dict(size=12, color="#ff7f0e", symbol="x") + )) + added_roc = True + if added_roc: + roc_fig.add_trace(go.Scatter( + x=[0, 1], y=[0, 1], mode="lines", + line=dict(dash="dash", width=2, color="#808080"), + showlegend=False + )) + roc_fig.update_layout( + title=None, + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + template="plotly_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=60, r=20, t=60, b=60), + ) + roc_combined = plot_with_table_style_title(roc_fig, "ROC Curve (Train vs Validation)") + + # Combined PR (Train/Val) + pr_fig = go.Figure() + added_pr = False + if pos_scores_train is not None: + y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) + prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train) + pr_auc_tr = auc(rec_tr, prec_tr) + pr_fig.add_trace(go.Scatter( + x=rec_tr, y=prec_tr, mode="lines", + name=f"Train (AUC={pr_auc_tr:.3f})", + line=dict(color="#1f77b4", width=3), + )) + if threshold is not None and len(thr_tr): + j = int(np.argmin(np.abs(thr_tr - float(threshold)))) + j = int(np.clip(j, 0, len(thr_tr) - 1)) + pr_fig.add_trace(go.Scatter( + x=[rec_tr[j + 1]], y=[prec_tr[j + 1]], + mode="markers", + name="Train @ threshold", + marker=dict(size=12, color="#1f77b4", symbol="x") + )) + added_pr = True + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val) + pr_auc_v = auc(rec_v, prec_v) + pr_fig.add_trace(go.Scatter( + x=rec_v, y=prec_v, 
mode="lines", + name=f"Validation (AUC={pr_auc_v:.3f})", + line=dict(color="#ff7f0e", width=3), + )) + if threshold is not None and len(thr_v): + j = int(np.argmin(np.abs(thr_v - float(threshold)))) + j = int(np.clip(j, 0, len(thr_v) - 1)) + pr_fig.add_trace(go.Scatter( + x=[rec_v[j + 1]], y=[prec_v[j + 1]], + mode="markers", + name="Val @ threshold", + marker=dict(size=12, color="#ff7f0e", symbol="x") + )) + added_pr = True + if added_pr: + pr_fig.update_layout( + title=None, + xaxis_title="Recall", + yaxis_title="Precision", + template="plotly_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=60, r=20, t=60, b=60), + ) + pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)") + + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val, title="Threshold Plot (Validation)", + user_threshold=threshold) + threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)") + + # Multiclass OVR ROC (validation) + if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None: + classes_val = np.unique(y_true_val) + fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val, title="One-vs-Rest ROC (Validation)") + mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)") + + # Prediction Confidence Histogram (train/val) + conf_train = conf_val = None + + # Per-class accuracy bars + if problem_type in ("binary", "multiclass") and pred_labels is not None: + classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist() + acc_vals = [] + for c in classes_for_bar: + mask = y_true == c + acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if mask.any() else 0.0) + bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4")) + bar_fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="Label", gridcolor="#eee"), + yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), + margin=dict(l=50, r=20, t=60, b=50), + ) + bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy") + if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None: + classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist() + acc_vals_val = [] + for c in classes_for_bar_val: + mask = y_true_val == c + acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0) + bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, marker_color="#ff7f0e")) + bar_fig_val.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="Label", gridcolor="#eee"), + yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), + margin=dict(l=50, r=20, t=60, b=50), + ) + bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy") + + # Assemble in requested order + pieces: list[str] = [] + if perf_card: + pieces.append(perf_card) + for block in (threshold_val_plot, roc_combined, pr_combined): + if block: + pieces.append(block) + # Remaining plots (keep existing order) + for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, 
bar_train, bar_val): + if block: + pieces.append(block) + # Learning curves should appear last in the tab + for block in (acc_plot, loss_plot): + if block: + pieces.append(block) + + if not pieces: + return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>" + + return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces) + + +def generate_learning_curve( + estimator, + X, + y, + scoring: str = "r2", + cv_folds: int = 5, + n_jobs: int = -1, + train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10), + title: str = "Learning Curve", + path: Optional[str] = None, +) -> go.Figure: + """ + Learning curve using sklearn.learning_curve, visualized with Plotly. + """ + sizes, train_scores, test_scores = skl_learning_curve( + estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes + ) + train_mean = train_scores.mean(axis=1) + train_std = train_scores.std(axis=1) + test_mean = test_scores.mean(axis=1) + test_std = test_scores.std(axis=1) + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=sizes, y=train_mean, mode="lines+markers", name="Training score", + error_y=dict(type="data", array=train_std, visible=True) + )) + fig.add_trace(go.Scatter( + x=sizes, y=test_mean, mode="lines+markers", name="CV score", + error_y=dict(type="data", array=test_std, visible=True) + )) + fig.update_layout( + title=None, + xaxis_title="Training examples", + yaxis_title=scoring, + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + +# ========================= +# SHAP (Matplotlib-based) +# ========================= + + +def generate_shap_summary_plot( + shap_values, features: pd.DataFrame, title: str = "SHAP Summary Plot", path: Optional[str] = None +) -> None: + """ + SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited; + keep matplotlib for clarity and stability. + """ + plt.figure(figsize=(10, 8)) + shap.summary_plot(shap_values, features, show=False) + plt.title(title) + _save_matplotlib(path) + + +def generate_shap_force_plot( + explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None +) -> None: + """ + SHAP force plot (Matplotlib). + """ + shap_values = explainer(instance) + plt.figure(figsize=(10, 4)) + shap.plots.force(shap_values[0], show=False) + plt.title(title) + _save_matplotlib(path) + + +def generate_shap_waterfall_plot( + explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None +) -> None: + """ + SHAP waterfall plot (Matplotlib). + """ + shap_values = explainer(instance) + plt.figure(figsize=(10, 6)) + shap.plots.waterfall(shap_values[0], show=False) + plt.title(title) + _save_matplotlib(path) + + +def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str: + """ + Return 'binary', 'multiclass', or 'regression'. + Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques. + """ + # AutoGluon predictors usually expose .problem_type; be defensive. 
+ pt = getattr(predictor, "problem_type", None) + if isinstance(pt, str): + pt_l = pt.lower() + if "regression" in pt_l: + return "regression" + if "binary" in pt_l: + return "binary" + if "multiclass" in pt_l or "multiclass" in pt_l: + return "multiclass" + + y = df_train_full[label_column] + if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10: + return "regression" + return "binary" if y.nunique() == 2 else "multiclass" + + +def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]: + """Make evaluate() outputs JSON/csv friendly floats.""" + out = {} + for k, v in d.items(): + try: + out[k] = float(v) + except Exception: + # keep only real-valued scalars + pass + return out + + +def evaluate_all( + predictor, + df_train: pd.DataFrame, + df_val: pd.DataFrame, + df_test: pd.DataFrame, + label_column: str, + problem_type: str, +) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]: + """ + Run predictor.evaluate on train/val/test and normalize the result dicts to floats. + MultiModalPredictor does not accept the `silent` kwarg, so call defensively. + """ + def _evaluate(df): + try: + return predictor.evaluate(df, silent=True) + except TypeError: + return predictor.evaluate(df) + + train_scores = _safe_floatify(_evaluate(df_train)) + val_scores = _safe_floatify(_evaluate(df_val)) + test_scores = _safe_floatify(_evaluate(df_test)) + return train_scores, val_scores, test_scores + + +def build_summary_html( + predictor, + df_train: pd.DataFrame, + df_val: Optional[pd.DataFrame], + df_test: Optional[pd.DataFrame], + label_column: str, + extra_run_rows: Optional[list[tuple[str, str]]] = None, + class_balance_html: Optional[str] = None, + perf_table_html: Optional[str] = None, +) -> str: + sections = [] + + # Dataset Overview (first section in the tab) + if class_balance_html: + sections.append(f""" +<section class="section"> + <h2 class="section-title">Dataset Overview</h2> + <div class="card"> + {class_balance_html} + </div> +</section> +""".strip()) + + # Performance Summary + if perf_table_html: + sections.append(f""" +<section class="section"> + <h2 class="section-title">Model Performance Summary</h2> + <div class="card"> + {perf_table_html} + </div> +</section> +""".strip()) + + # Model Configuration + + # Remove Predictor type and Framework, and ensure Model Architecture is present + base_rows: list[tuple[str, str]] = [] + if extra_run_rows: + # Remove any rows with keys 'Predictor type' or 'Framework' + base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")]) + + def _fmt(v): + if v is None or v == "": + return "—" + return _escape(str(v)) + + rows_html = "\n".join( + f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>" + for k, v in base_rows + ) + + sections.append(f""" +<section class="section"> + <h2 class="section-title">Model Configuration</h2> + <div class="card"> + <table class="kv-table"> + <thead><tr><th>Key</th><th>Value</th></tr></thead> + <tbody> + {rows_html} + </tbody> + </table> + </div> +</section> +""".strip()) + + return "\n".join(sections).strip() + + +def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str: + """Build a visualization of feature importance.""" + try: + # Try to get feature importance from predictor + fi = None + if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance): + try: + fi = predictor.feature_importance(df_train) + except Exception as e: + return f"<p>Could not compute feature importance: {e}</p>" + + if fi is None 
or (isinstance(fi, pd.DataFrame) and fi.empty): + return "<p>Feature importance not available for this model.</p>" + + # Format as a sortable table + rows = [] + if isinstance(fi, pd.DataFrame): + fi = fi.sort_values("importance", ascending=False) + for idx, row in fi.iterrows(): + # feature name is the DataFrame index (or a "feature" column), not row.index[0] + feat = row["feature"] if "feature" in row.index else idx + imp = float(row["importance"]) + rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>") + else: + # Handle other formats (dict, etc.) + for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True): + rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>") + + if not rows: + return "<p>No feature importance values available.</p>" + + table_html = f""" + <table class="performance-summary"> + <thead> + <tr> + <th class="sortable">Feature</th> + <th class="sortable">Importance</th> + </tr> + </thead> + <tbody> + {"".join(rows)} + </tbody> + </table> + """ + return table_html + + except Exception as e: + return f"<p>Error building feature importance visualization: {e}</p>" + + +def build_test_html_and_plots( + predictor, + problem_type: str, + df_test: pd.DataFrame, + label_column: str, + tmpdir: str, + threshold: Optional[float] = None, +) -> Tuple[str, List[str]]: + """ + Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs. + Returns: (html_template_with_{}, list_of_plot_divs) + """ + plots: List[str] = [] + + y_true = df_test[label_column].values + classes = np.unique(y_true) + + # Try proba/labels where meaningful + pred_labels = None + pred_proba = None + try: + pred_labels = predictor.predict(df_test) + except Exception: + pass + try: + # MultiModalPredictor exposes predict_proba for classification problems.
+ pred_proba = predictor.predict_proba(df_test) + except Exception: + pred_proba = None + + proba_arr = None + if pred_proba is not None: + if isinstance(pred_proba, pd.Series): + proba_arr = pred_proba.to_numpy().reshape(-1, 1) + elif isinstance(pred_proba, pd.DataFrame): + proba_arr = pred_proba.to_numpy() + else: + proba_arr = np.asarray(pred_proba) + + # Thresholded labels for binary + if problem_type == "binary" and threshold is not None and proba_arr is not None: + pos_label, neg_label = classes.max(), classes.min() + pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1] + pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label) + + # Confusion matrix / per-class now reflect thresholded labels + if problem_type in ("binary", "multiclass") and pred_labels is not None: + cm_title = "Confusion Matrix" + if threshold is not None and problem_type == "binary": + thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".") + cm_title = f"Confusion Matrix (Threshold = {thr_str})" + fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title) + plots.append(plot_with_table_style_title(fig_cm, cm_title)) + + fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics") + plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics")) + + # ROC/PR where possible — choose positive-class scores safely + pos_label = classes.max() # or set explicitly, e.g., 1 or "yes" + + if isinstance(pred_proba, pd.DataFrame): + proba_arr = pred_proba.to_numpy() + if pos_label in pred_proba.columns: + pos_idx = list(pred_proba.columns).index(pos_label) + else: + pos_idx = -1 # fallback to last column + elif isinstance(pred_proba, pd.Series): + proba_arr = pred_proba.to_numpy().reshape(-1, 1) + pos_idx = 0 + else: + proba_arr = np.asarray(pred_proba) if pred_proba is not None else None + pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0 + + if proba_arr is not None: + y_bin = (y_true == pos_label).astype(int) + pos_scores = ( + proba_arr.reshape(-1) + if proba_arr.ndim == 1 or proba_arr.shape[1] == 1 + else proba_arr[:, pos_idx] + ) + + fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold) + plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}")) + + fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold) + plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}")) + + # Additional diagnostics aligned with ImageLearner style + if problem_type == "binary": + conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)") + plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)")) + else: + conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)") + plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)")) + + if problem_type == "multiclass" and proba_arr is not None and proba_arr.ndim >= 2: + fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)") + plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)")) + + # Regression visuals + if problem_type == "regression": + if 
pred_labels is None: + pred_labels = predictor.predict(df_test) + fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual") + plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual")) + + fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot") + plots.append(plot_with_table_style_title(fig_res, "Residual Plot")) + + fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram") + plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram")) + + fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration") + plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration")) + + # Small HTML template with placeholder for metric rows the caller fills in + test_html_template = """ + <h2>Test Performance Summary</h2> + <table class="performance-summary"> + <thead><tr><th>Metric</th><th>Test</th></tr></thead> + <tbody>{}</tbody> + </table> + """ + return test_html_template, plots + + +def build_feature_html( + predictor, + df_train: pd.DataFrame, + label_column: str, + include_modalities: bool = True, # ← NEW + include_class_balance: bool = True, # ← NEW +) -> str: + sections = [] + + # (Typical feature importance content…) + fi_html = build_feature_importance_html(predictor, df_train, label_column) + sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>") + + # Previously: Modalities & Inputs and/or Class Balance may have been here. + # Only render them if flags are True. + if include_modalities: + from report_utils import build_modalities_html + modalities_html = build_modalities_html(predictor, df_train, label_column) + sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>") + + if include_class_balance: + from report_utils import build_class_balance_html + cb_html = build_class_balance_html(df_train, label_column) + sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>") + + return "\n".join(sections) + + +def assemble_full_html_report( + summary_html: str, + train_html: str, + test_html: str, + plots: List[str], + feature_html: str, +) -> str: + """ + Wrap the four tabs using utils.build_tabbed_html and return full HTML. + """ + # Append plots under the Test tab (already wrapped with titles) + test_full = test_html + "".join(plots) + + tabs = build_tabbed_html(summary_html, train_html, test_full, feature_html, explainer_html=None) + + html_out = get_html_template() + + # 🔧 Ensure Plotly JS is available (we render plots with include_plotlyjs=False) + html_out += '\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n' + + # Optional: centering tweaks + html_out += """ +<style> + .plotly-center { display: flex; justify-content: center; } + .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } + .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } +</style> +""" + # Help modal HTML/JS + html_out += get_metrics_help_modal() + + html_out += tabs + html_out += get_html_closing() + return html_out
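For orientation, the sketch below shows how these builders could be wired into a single report. It is a minimal sketch, not part of the file above: the AutoGluon `MultiModalPredictor` import, the CSV paths, the `target` label column, and the 0.5 threshold are placeholder assumptions; any predictor object exposing `predict`, `predict_proba`, and `evaluate` should slot in the same way.

```python
# Minimal usage sketch (assumptions: AutoGluon MultiModalPredictor, placeholder
# file paths / label column / threshold). The report itself only uses functions
# defined in plot_logic.py.
import pandas as pd
from autogluon.multimodal import MultiModalPredictor  # assumed dependency

from plot_logic import (
    assemble_full_html_report,
    build_feature_html,
    build_summary_html,
    build_test_html_and_plots,
    build_train_html_and_plots,
    evaluate_all,
    infer_problem_type,
)

df_train = pd.read_csv("train.csv")   # placeholder inputs
df_val = pd.read_csv("val.csv")
df_test = pd.read_csv("test.csv")
label = "target"                      # placeholder label column

predictor = MultiModalPredictor(label=label).fit(df_train, tuning_data=df_val)

problem_type = infer_problem_type(predictor, df_train, label)
train_scores, val_scores, test_scores = evaluate_all(
    predictor, df_train, df_val, df_test, label, problem_type
)

# Simple metric table reused by the Summary and Train tabs.
metric_names = sorted(set(train_scores) | set(val_scores) | set(test_scores))
perf_rows = "".join(
    f"<tr><td>{m}</td>"
    f"<td>{train_scores.get(m, float('nan')):.4f}</td>"
    f"<td>{val_scores.get(m, float('nan')):.4f}</td>"
    f"<td>{test_scores.get(m, float('nan')):.4f}</td></tr>"
    for m in metric_names
)
perf_table_html = (
    "<table class='performance-summary'>"
    "<thead><tr><th>Metric</th><th>Train</th><th>Val</th><th>Test</th></tr></thead>"
    f"<tbody>{perf_rows}</tbody></table>"
)

summary_html = build_summary_html(
    predictor, df_train, df_val, df_test, label, perf_table_html=perf_table_html
)
train_html = build_train_html_and_plots(
    predictor, problem_type, df_train, label, tmpdir="/tmp",
    df_val=df_val, perf_table_html=perf_table_html,
    threshold=0.5,  # threshold only affects binary problems
)
test_template, test_plots = build_test_html_and_plots(
    predictor, problem_type, df_test, label, tmpdir="/tmp", threshold=0.5
)
# build_test_html_and_plots returns a template with a single "{}" placeholder
# for the metric rows; fill it before assembling the tabs.
test_rows = "".join(f"<tr><td>{k}</td><td>{v:.4f}</td></tr>" for k, v in test_scores.items())
test_html = test_template.format(test_rows)

feature_html = build_feature_html(predictor, df_train, label)

with open("report.html", "w") as fh:
    fh.write(
        assemble_full_html_report(summary_html, train_html, test_html, test_plots, feature_html)
    )
```

Because each figure is rendered with `include_plotlyjs=False`, the assembled page relies on the Plotly CDN `<script>` tag that `assemble_full_html_report` injects once for the whole report.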
