plot_logic.py @ 0:375c36923da1 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| Field | Value |
|---|---|
| author | goeckslab |
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
from __future__ import annotations

import html
import os
from html import escape as _escape
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import shap
from feature_help_modal import get_metrics_help_modal
from report_utils import build_tabbed_html, get_html_closing, get_html_template
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    auc,
    average_precision_score,
    classification_report,
    confusion_matrix,
    log_loss,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import learning_curve as skl_learning_curve
from sklearn.preprocessing import label_binarize


# =========================
# Utilities
# =========================


def plot_with_table_style_title(fig, title: str) -> str:
    """
    Render a Plotly figure with a report-style <h2> header so it matches
    the green table section headers.
    """
    # kill Plotly’s built-in title
    fig.update_layout(title=None)

    # figure HTML without PlotlyJS (we load it once globally)
    plot_html = fig.to_html(full_html=False, include_plotlyjs=False)

    # use <h2> — your CSS already styles <h2> like the table headers
    return f"""
<h2>{html.escape(title)}</h2>
<div class="plotly-center">{plot_html}</div>
""".strip()


def _save_plotly(fig: go.Figure, path: Optional[str]) -> None:
    """
    Save a Plotly figure.
    If `path` ends with `.html`, save interactive HTML.
    If it ends with a raster extension (png/jpg/jpeg/webp), uses Kaleido.
    If None, do nothing (caller may choose to display in notebook).
    """
    if not path:
        return
    ext = os.path.splitext(path)[1].lower()
    if ext == ".html":
        fig.write_html(path, include_plotlyjs="cdn", full_html=True)
    else:
        # Requires kaleido: pip install -U kaleido
        fig.write_image(path)


def _save_matplotlib(path: Optional[str]) -> None:
    """Save current Matplotlib figure if `path` is provided, else show()."""
    if path:
        plt.savefig(path, bbox_inches="tight")
        plt.close()
    else:
        plt.show()


# =========================
# Classification Plots
# =========================


def generate_confusion_matrix_plot(
    y_true,
    y_pred,
    title: str = "Confusion Matrix",
) -> go.Figure:
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Class order (works for strings or numbers)
    labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist()
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    max_val = cm.max() if cm.size else 0

    # Use categorical axes by passing string labels for x/y
    cats = [str(label) for label in labels]
    total = int(cm.sum())

    fig = go.Figure(
        data=go.Heatmap(
            z=cm,
            x=cats,  # categorical x
            y=cats,  # categorical y
            colorscale="Blues",
            showscale=True,
            colorbar=dict(title="Count"),
            xgap=2,
            ygap=2,
            hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>",
            zmin=0,
        )
    )

    # Add annotations with count (bold) and percentage; text flips to white
    # on dark cells for readability (layout matches sample_output.html)
    annotations = []
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            val = int(cm[i, j])
            pct = (val / total * 100) if total > 0 else 0
            text_color = "white" if max_val and val > (max_val / 2) else "black"
            # Count annotation (bold, bottom)
            annotations.append(
                dict(
                    x=cats[j],
                    y=cats[i],
                    text=f"<b>{val}</b>",
                    showarrow=False,
                    font=dict(color=text_color, size=14),
                    xanchor="center",
                    yanchor="bottom",
                    yshift=2,
                )
            )
            # Percentage annotation (top)
            annotations.append(
                dict(
                    x=cats[j],
                    y=cats[i],
                    text=f"{pct:.1f}%",
                    showarrow=False,
                    font=dict(color=text_color, size=13),
xanchor="center", yanchor="top", yshift=-2 ) ) fig.update_layout( title=None, xaxis_title="Predicted label", yaxis_title="True label", xaxis=dict(type="category"), yaxis=dict(type="category", autorange="reversed"), # typical CM orientation margin=dict(l=80, r=20, t=40, b=80), template="plotly_white", plot_bgcolor="white", paper_bgcolor="white", annotations=annotations ) return fig def generate_roc_curve_plot( y_true_bin: np.ndarray, y_score: np.ndarray, title: str = "ROC Curve", marker_threshold: float | None = None, ) -> go.Figure: y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) y_score = np.asarray(y_score).astype(float).reshape(-1) fpr, tpr, thr = roc_curve(y_true_bin, y_score) roc_auc = auc(fpr, tpr) fig = go.Figure() fig.add_trace(go.Scatter( x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})", line=dict(width=3) )) # 45° chance line (no legend to keep it clean) fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash", width=2, color="#888"), showlegend=False)) # Optional marker at the user threshold if marker_threshold is not None and len(thr): # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0 finite = np.isfinite(thr) if np.any(finite): idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold)))) idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices x_m, y_m = float(fpr[idx]), float(tpr[idx]) fig.add_trace( go.Scatter( x=[x_m], y=[y_m], mode="markers", name=f"@ {float(marker_threshold):.2f}", marker=dict(size=12, color="red", symbol="x") ) ) fig.add_annotation( x=0.02, y=0.98, xref="paper", yref="paper", text=f"threshold = {float(marker_threshold):.2f}", showarrow=False, font=dict(color="black", size=12), align="left" ) fig.update_layout( title=None, xaxis_title="False Positive Rate", yaxis_title="True Positive Rate", template="plotly_white", legend=dict(x=1, y=0, xanchor="right"), margin=dict(l=60, r=20, t=60, b=60), ) return fig def generate_pr_curve_plot( y_true_bin: np.ndarray, y_score: np.ndarray, title: str = "Precision–Recall Curve", marker_threshold: float | None = None, ) -> go.Figure: y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) y_score = np.asarray(y_score).astype(float).reshape(-1) precision, recall, thr = precision_recall_curve(y_true_bin, y_score) pr_auc = auc(recall, precision) fig = go.Figure() fig.add_trace(go.Scatter( x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})", line=dict(width=3) )) # Optional marker at the user threshold if marker_threshold is not None and len(thr): # In PR, thresholds has length len(precision)-1. The point for thr[j] is (recall[j+1], precision[j+1]). 
        j = int(np.argmin(np.abs(thr - float(marker_threshold))))
        j = int(np.clip(j, 0, len(thr) - 1))
        x_m, y_m = float(recall[j + 1]), float(precision[j + 1])
        fig.add_trace(
            go.Scatter(
                x=[x_m], y=[y_m], mode="markers",
                name=f"@ {float(marker_threshold):.2f}",
                marker=dict(size=12, color="red", symbol="x"),
            )
        )
        fig.add_annotation(
            x=0.02, y=0.98, xref="paper", yref="paper",
            text=f"threshold = {float(marker_threshold):.2f}",
            showarrow=False,
            font=dict(color="black", size=12),
            align="left",
        )

    fig.update_layout(
        title=None,
        xaxis_title="Recall",
        yaxis_title="Precision",
        template="plotly_white",
        legend=dict(x=1, y=0, xanchor="right"),
        margin=dict(l=60, r=20, t=60, b=60),
    )
    return fig


def generate_calibration_plot(
    y_true_bin: np.ndarray,
    y_prob: np.ndarray,
    n_bins: int = 10,
    title: str = "Calibration Plot",
    path: Optional[str] = None,
) -> go.Figure:
    """
    Binary calibration curve (Plotly).
    """
    prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform")

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=prob_pred, y=prob_true, mode="lines+markers", name="Model",
        line=dict(color="#1f77b4", width=3),
        marker=dict(size=7, color="#1f77b4"),
    ))
    fig.add_trace(
        go.Scatter(
            x=[0, 1], y=[0, 1], mode="lines",
            line=dict(dash="dash", color="#808080", width=2),
            name="Perfect",
        )
    )
    fig.update_layout(
        title=None,
        xaxis_title="Predicted Probability",
        yaxis_title="Observed Probability",
        yaxis=dict(range=[0, 1]),
        xaxis=dict(range=[0, 1]),
        template="plotly_white",
        margin=dict(l=60, r=40, t=50, b=50),
    )
    _save_plotly(fig, path)
    return fig


def generate_threshold_plot(
    y_true_bin: np.ndarray,
    y_prob: np.ndarray,
    title: str = "Threshold Plot",
    user_threshold: float | None = None,
) -> go.Figure:
    y_true = np.asarray(y_true_bin, dtype=int).ravel()
    p = np.asarray(y_prob, dtype=float).ravel()
    p = np.nan_to_num(p, nan=0.0)
    p = np.clip(p, 0.0, 1.0)

    def _compute_metrics(thresholds: np.ndarray):
        """Helper to compute precision/recall/F1/queue-rate arrays over thresholds."""
        prec, rec, f1, qrate = [], [], [], []
        for t in thresholds:
            yhat = (p >= t).astype(int)
            tp = int(((yhat == 1) & (y_true == 1)).sum())
            fp = int(((yhat == 1) & (y_true == 0)).sum())
            fn = int(((yhat == 0) & (y_true == 1)).sum())
            pr = tp / (tp + fp) if (tp + fp) else np.nan  # undefined when no predicted positives
            rc = tp / (tp + fn) if (tp + fn) else 0.0
            f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0
            q = float(yhat.mean())
            prec.append(pr)
            rec.append(rc)
            f1.append(f)
            qrate.append(q)
        return (
            np.asarray(prec, dtype=float),
            np.asarray(rec, dtype=float),
            np.asarray(f1, dtype=float),
            np.asarray(qrate, dtype=float),
        )

    # Use a uniform threshold grid for plotting (0 to 1 in steps of 0.01)
    th = np.linspace(0.0, 1.0, 101)
    prec, rec, f1_arr, qrate = _compute_metrics(th)

    # Compute F1-optimal threshold using the actual score distribution (more precise than the grid)
    cand_th = np.unique(np.concatenate(([0.0, 1.0], p)))
    # cap to a reasonable size by sampling if extremely large
    if cand_th.size > 2000:
        cand_th = np.linspace(0.0, 1.0, 2001)
    _, _, f1_cand, _ = _compute_metrics(cand_th)
    if np.all(np.isnan(f1_cand)):
        t_star = 0.5  # fallback when no valid F1 can be computed
    else:
        f1_max = np.nanmax(f1_cand)
        best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0]
        # pick the middle of the best candidates to avoid biasing toward 0
        best_idx = int(best_idxs[len(best_idxs) // 2])
        t_star = float(cand_th[best_idx])

    # Replace NaNs for plotting (set to 0 where precision is undefined)
    prec_plot = np.nan_to_num(prec, nan=0.0)

    fig = go.Figure()
    # Precision (blue line)
    fig.add_trace(go.Scatter(
        x=th, y=prec_plot, mode="lines", name="Precision",
        line=dict(width=3, color="#1f77b4"),
        hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>",
    ))
    # Recall (orange line)
    fig.add_trace(go.Scatter(
        x=th, y=rec, mode="lines", name="Recall",
        line=dict(width=3, color="#ff7f0e"),
        hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>",
    ))
    # F1 (green line)
    fig.add_trace(go.Scatter(
        x=th, y=f1_arr, mode="lines", name="F1",
        line=dict(width=3, color="#2ca02c"),
        hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>",
    ))
    # Queue Rate (grey dashed line)
    fig.add_trace(go.Scatter(
        x=th, y=qrate, mode="lines", name="Queue Rate",
        line=dict(width=2, color="#808080", dash="dash"),
        hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>",
    ))

    # F1-optimal threshold marker (dashed vertical line)
    fig.add_vline(
        x=t_star,
        line_width=2,
        line_dash="dash",
        line_color="black",
        annotation_text=f"t* = {t_star:.2f}",
        annotation_position="top",
    )

    # User threshold (solid red line) if provided
    if user_threshold is not None:
        fig.add_vline(
            x=float(user_threshold),
            line_width=2,
            line_color="red",
            annotation_text=f"threshold = {float(user_threshold):.2f}",
            annotation_position="top",
        )

    fig.update_layout(
        title=None,
        template="plotly_white",
        xaxis=dict(
            title="Discrimination Threshold",
            range=[0, 1],
            gridcolor="#e0e0e0",
            showgrid=True,
            zeroline=False,
        ),
        yaxis=dict(
            title="Score",
            range=[0, 1],
            gridcolor="#e0e0e0",
            showgrid=True,
            zeroline=False,
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1.0,
        ),
        margin=dict(l=60, r=20, t=40, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig


def generate_per_class_metrics_plot(
    y_true: Sequence,
    y_pred: Sequence,
    metrics: Sequence[str] = ("precision", "recall", "f1_score"),
    title: str = "Classification Report",
    path: Optional[str] = None,
) -> go.Figure:
    """
    Per-class metrics heatmap (Plotly), similar to sklearn's classification report.
    Rows = classes, columns = metrics; cell text shows the value (0–1).
""" # Map display names -> sklearn keys key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"} report = classification_report( y_true, y_pred, output_dict=True, zero_division=0 ) # Order classes sensibly (numeric if possible, else lexical) def _sort_key(x): try: return (0, float(x)) except Exception: return (1, str(x)) # Use all classes seen in y_true or y_pred (so rows don't jump around) uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key) classes = [str(c) for c in uniq] # Build Z matrix (rows=classes, cols=metrics) used_metrics = [key_map.get(m, m) for m in metrics] z = [] for c in classes: row = report.get(c, {}) z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics]) z = np.array(z, dtype=float) # Pretty cell labels z_text = [[f"{v:.2f}" for v in r] for r in z] fig = go.Figure( data=go.Heatmap( z=z, x=list(metrics), # keep display names ("precision", "recall", "f1_score") y=classes, # classes as strings colorscale="Reds", zmin=0.0, zmax=1.0, colorbar=dict(title="Value"), text=z_text, texttemplate="%{text}", hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>", ) ) fig.update_layout( title=None, xaxis_title="", yaxis_title="Class", template="plotly_white", margin=dict(l=60, r=60, t=70, b=40), ) _save_plotly(fig, path) return fig def generate_multiclass_roc_curve_plot( y_true: Sequence, y_prob: np.ndarray, classes: Sequence, title: str = "Multiclass ROC Curve", path: Optional[str] = None, ) -> go.Figure: """ One-vs-rest ROC curves for multiclass (Plotly). Handles binary passed as 2-column probs as well. """ y_true = np.asarray(y_true) y_prob = np.asarray(y_prob) # Normalize to shape (n_samples, n_classes) if y_prob.ndim == 1 or y_prob.shape[1] == 1: y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)]) y_true_bin = label_binarize(y_true, classes=classes) if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2: y_true_bin = np.hstack([1 - y_true_bin, y_true_bin]) if y_prob.shape[1] != y_true_bin.shape[1]: raise ValueError( f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}." ) fig = go.Figure() for i, cls in enumerate(classes): fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i]) auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i]) fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})")) fig.add_trace( go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False) ) fig.update_layout( title=None, xaxis_title="False Positive Rate", yaxis_title="True Positive Rate", template="plotly_white", ) _save_plotly(fig, path) return fig def generate_multiclass_pr_curve_plot( y_true: Sequence, y_prob: np.ndarray, classes: Optional[Sequence] = None, title: str = "Precision–Recall Curve", path: Optional[str] = None, ) -> go.Figure: """ Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR. 
""" y_true = np.asarray(y_true) y_prob = np.asarray(y_prob) fig = go.Figure() if not classes or len(classes) == 2: precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1]) ap = average_precision_score(y_true, y_prob[:, 1]) fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}")) else: for i, cls in enumerate(classes): y_true_bin = (y_true == cls).astype(int) y_prob_cls = y_prob[:, i] precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls) ap = average_precision_score(y_true_bin, y_prob_cls) fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})")) fig.update_layout( title=None, xaxis_title="Recall", yaxis_title="Precision", yaxis=dict(range=[0, 1]), xaxis=dict(range=[0, 1]), template="plotly_white", ) _save_plotly(fig, path) return fig def generate_metric_comparison_bar( metrics_scores: Mapping[str, Sequence[float]], phases: Sequence[str] = ("train", "val", "test"), title: str = "Metric Comparison Across Phases", path: Optional[str] = None, ) -> go.Figure: """ Grouped bar chart comparing metrics across phases (Plotly). metrics_scores: {metric_name: [train, val, test]} """ df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"}) df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score") fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None) ymax = max(1.0, df_m["Score"].max() * 1.05) fig.update_yaxes(range=[0, ymax]) fig.update_layout(template="plotly_white") _save_plotly(fig, path) return fig # ========================= # Regression Plots # ========================= def generate_scatter_plot( y_true: Sequence[float], y_pred: Sequence[float], title: str = "Predicted vs Actual", path: Optional[str] = None, ) -> go.Figure: """ Predicted vs. Actual scatter with y=x reference (Plotly). """ y_true = np.asarray(y_true) y_pred = np.asarray(y_pred) vmin = float(min(np.min(y_true), np.min(y_pred))) vmax = float(max(np.max(y_true), np.max(y_pred))) fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None) fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal")) fig.update_layout(template="plotly_white") _save_plotly(fig, path) return fig def generate_residual_plot( y_true: Sequence[float], y_pred: Sequence[float], title: str = "Residual Plot", path: Optional[str] = None, ) -> go.Figure: """ Residuals vs Predicted (Plotly). """ y_true = np.asarray(y_true) y_pred = np.asarray(y_pred) residuals = y_true - y_pred fig = px.scatter(x=y_pred, y=residuals, opacity=0.6, labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"}, title=None) fig.add_hline(y=0, line_dash="dash") fig.update_layout(template="plotly_white") _save_plotly(fig, path) return fig def generate_residual_histogram( y_true: Sequence[float], y_pred: Sequence[float], bins: int = 30, title: str = "Residual Histogram", path: Optional[str] = None, ) -> go.Figure: """ Residuals histogram (Plotly). 
""" residuals = np.asarray(y_true) - np.asarray(y_pred) fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None) fig.update_layout(yaxis_title="Frequency", template="plotly_white") _save_plotly(fig, path) return fig def generate_regression_calibration_plot( y_true: Sequence[float], y_pred: Sequence[float], num_bins: int = 10, title: str = "Regression Calibration Plot", path: Optional[str] = None, ) -> go.Figure: """ Binned Actual vs Predicted means (Plotly). """ y_true = np.asarray(y_true) y_pred = np.asarray(y_pred) order = np.argsort(y_pred) y_true_sorted = y_true[order] y_pred_sorted = y_pred[order] bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins) bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)] bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)] vmin = float(min(np.min(y_pred), np.min(y_true))) vmax = float(max(np.max(y_pred), np.max(y_true))) fig = go.Figure() fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers", name="Binned Actual vs Predicted")) fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal")) fig.update_layout( title=None, xaxis_title="Mean Predicted per bin", yaxis_title="Mean Actual per bin", template="plotly_white", ) _save_plotly(fig, path) return fig # ========================= # Confidence / Diagnostics # ========================= def plot_error_vs_confidence( y_true: Union[Sequence[int], np.ndarray], y_proba: Union[Sequence[float], np.ndarray], n_bins: int = 10, title: str = "Error vs Confidence", path: Optional[str] = None, ) -> go.Figure: """ Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly. """ y_true = np.asarray(y_true) y_proba = np.asarray(y_proba).reshape(-1) y_pred = (y_proba >= 0.5).astype(int) confidence = np.maximum(y_proba, 1 - y_proba) error = (y_pred != y_true).astype(int) bins = np.linspace(0.0, 1.0, n_bins + 1) idx = np.digitize(confidence, bins, right=True) centers, err_rates = [], [] for i in range(1, len(bins)): mask = (idx == i) if mask.any(): centers.append(float(confidence[mask].mean())) err_rates.append(float(error[mask].mean())) fig = go.Figure() fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate")) fig.update_layout( title=None, xaxis_title="Confidence (max predicted probability)", yaxis_title="Error Rate", yaxis=dict(range=[0, 1]), template="plotly_white", ) _save_plotly(fig, path) return fig def plot_confidence_histogram( y_proba: np.ndarray, bins: int = 20, title: str = "Confidence Histogram", path: Optional[str] = None, ) -> go.Figure: """ Histogram of max predicted probabilities (Plotly). Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C). 
""" y_proba = np.asarray(y_proba) if y_proba.ndim == 1: confidences = np.maximum(y_proba, 1 - y_proba) else: confidences = np.max(y_proba, axis=1) fig = px.histogram( x=confidences, nbins=bins, range_x=(0, 1), histnorm="percent", labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"}, title=None, ) if fig.data: fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>") fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white") _save_plotly(fig, path) return fig # ========================= # Learning Curve # ========================= def generate_learning_curve_from_predictions( y_true, y_pred=None, y_proba=None, classes=None, metric: str = "accuracy", train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10), n_repeats: int = 5, seed: int = 42, title: str = "Learning Curve", path: str | None = None, return_stats: bool = False, ) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]: rng = np.random.default_rng(seed) y_true = np.asarray(y_true) N = len(y_true) if metric == "accuracy" and y_pred is None: raise ValueError("accuracy curve requires y_pred") if metric == "log_loss" and y_proba is None: raise ValueError("log_loss curve requires y_proba") if y_proba is not None: y_proba = np.asarray(y_proba) if y_pred is not None: y_pred = np.asarray(y_pred) sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist() means, stds = [], [] for n in sizes: vals = [] for _ in range(n_repeats): idx = rng.choice(N, size=n, replace=False) if metric == "accuracy": vals.append(float((y_true[idx] == y_pred[idx]).mean())) else: if y_proba.ndim == 1: p = y_proba[idx] pp = np.column_stack([1 - p, p]) else: pp = y_proba[idx] vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes))) means.append(np.mean(vals)) stds.append(np.std(vals)) if return_stats: return sizes, means, stds fig = go.Figure() fig.add_trace(go.Scatter( x=sizes, y=means, mode="lines+markers", name="Train", line=dict(width=3, shape="spline"), marker=dict(size=7), error_y=dict(type="data", array=stds, visible=True) )) fig.update_layout( title=None, template="plotly_white", xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"), yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"), legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), margin=dict(l=50, r=20, t=60, b=50), ) if path: _save_plotly(fig, path) return fig def build_train_html_and_plots( predictor, problem_type: str, df_train: pd.DataFrame, label_column: str, tmpdir: str, df_val: Optional[pd.DataFrame] = None, seed: int = 42, perf_table_html: str | None = None, threshold: Optional[float] = None, section_tile: str = "Training Diagnostics", ) -> str: y_true = df_train[label_column].values y_true_val = df_val[label_column].values if df_val is not None else None # predictions on TRAIN pred_labels, pred_proba = None, None try: pred_labels = predictor.predict(df_train) except Exception: pass try: proba_raw = predictor.predict_proba(df_train) pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw) except Exception: pred_proba = None # predictions on VAL (if provided) pred_labels_val, pred_proba_val = None, None if df_val is not None: try: pred_labels_val = predictor.predict(df_val) except Exception: pred_labels_val = None try: proba_raw_val = predictor.predict_proba(df_val) pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, 
                (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val)
        except Exception:
            pred_proba_val = None

    pos_scores_train: Optional[np.ndarray] = None
    pos_scores_val: Optional[np.ndarray] = None
    if problem_type == "binary":
        if pred_proba is not None:
            pos_scores_train = (
                pred_proba.reshape(-1)
                if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1)
                else pred_proba[:, -1]
            )
        if pred_proba_val is not None:
            pos_scores_val = (
                pred_proba_val.reshape(-1)
                if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1)
                else pred_proba_val[:, -1]
            )

    # Collect plots then append in desired order
    perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None
    acc_plot = loss_plot = None
    cm_train = pc_train = cm_val = pc_val = None
    threshold_val_plot = None
    roc_combined = pr_combined = cal_combined = None
    mc_roc_val = None
    conf_train = conf_val = None
    bar_train = bar_val = None

    # 1) Learning Curve — Accuracy
    if problem_type in ("binary", "multiclass"):
        acc_fig = go.Figure()
        added_acc = False
        if pred_labels is not None:
            train_sizes, train_means, train_stds = generate_learning_curve_from_predictions(
                y_true=y_true,
                y_pred=np.asarray(pred_labels),
                metric="accuracy",
                title="Learning Curves — Label Accuracy",
                seed=seed,
                return_stats=True,
            )
            acc_fig.add_trace(go.Scatter(
                x=train_sizes, y=train_means, mode="lines+markers", name="Train",
                line=dict(color="#1f77b4", width=3, shape="spline"),
                marker=dict(size=7),
                error_y=dict(type="data", array=train_stds, visible=True),
            ))
            added_acc = True
        if pred_labels_val is not None and y_true_val is not None:
            val_sizes, val_means, val_stds = generate_learning_curve_from_predictions(
                y_true=y_true_val,
                y_pred=np.asarray(pred_labels_val),
                metric="accuracy",
                title="Learning Curves — Label Accuracy",
                seed=seed,
                return_stats=True,
            )
            acc_fig.add_trace(go.Scatter(
                x=val_sizes, y=val_means, mode="lines+markers", name="Validation",
                line=dict(color="#ff7f0e", width=3, shape="spline"),
                marker=dict(size=7),
                error_y=dict(type="data", array=val_stds, visible=True),
            ))
            added_acc = True
        if added_acc:
            acc_fig.update_layout(
                title=None,
                template="plotly_white",
                xaxis=dict(title="samples", gridcolor="#eee"),
                yaxis=dict(title="accuracy", gridcolor="#eee"),
                legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
                margin=dict(l=50, r=20, t=60, b=50),
            )
            acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy")

    # 2) Learning Curve — Loss
    if problem_type in ("binary", "multiclass"):
        classes = np.unique(y_true)
        loss_fig = go.Figure()
        added_loss = False
        if pred_proba is not None:
            pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba
            train_sizes, train_means, train_stds = generate_learning_curve_from_predictions(
                y_true=y_true,
                y_proba=pp,
                classes=classes,
                metric="log_loss",
                title="Learning Curves — Label Loss",
                seed=seed,
                return_stats=True,
            )
            loss_fig.add_trace(go.Scatter(
                x=train_sizes, y=train_means, mode="lines+markers", name="Train",
                line=dict(color="#1f77b4", width=3, shape="spline"),
                marker=dict(size=7),
                error_y=dict(type="data", array=train_stds, visible=True),
            ))
            added_loss = True
        if pred_proba_val is not None and y_true_val is not None:
            pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val
            val_sizes, val_means, val_stds = generate_learning_curve_from_predictions(
                y_true=y_true_val,
                y_proba=pp_val,
                classes=classes,
metric="log_loss", title="Learning Curves — Label Loss", seed=seed, return_stats=True, ) loss_fig.add_trace(go.Scatter( x=val_sizes, y=val_means, mode="lines+markers", name="Validation", line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), error_y=dict(type="data", array=val_stds, visible=True), )) added_loss = True if added_loss: loss_fig.update_layout( title=None, template="plotly_white", xaxis=dict(title="epoch", gridcolor="#eee"), yaxis=dict(title="loss", gridcolor="#eee"), legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), margin=dict(l=50, r=20, t=60, b=50), ) loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss") # Confusion matrices & per-class metrics cm_train = pc_train = cm_val = pc_val = None # Probability diagnostics (binary) if problem_type == "binary": # Combined Calibration (Train/Val) cal_fig = go.Figure() added_cal = False if pos_scores_train is not None: y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform") cal_fig.add_trace(go.Scatter( x=prob_pred, y=prob_true, mode="lines+markers", name="Train", line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4"), )) added_cal = True if pos_scores_val is not None and y_true_val is not None: y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform") cal_fig.add_trace(go.Scatter( x=prob_pred_v, y=prob_true_v, mode="lines+markers", name="Validation", line=dict(color="#ff7f0e", width=3), marker=dict(size=7, color="#ff7f0e"), )) added_cal = True if added_cal: cal_fig.add_trace(go.Scatter( x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash", color="#808080", width=2), name="Perfect", showlegend=True, )) cal_fig.update_layout( title=None, xaxis_title="Predicted Probability", yaxis_title="Observed Probability", xaxis=dict(range=[0, 1]), yaxis=dict(range=[0, 1]), template="plotly_white", legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), margin=dict(l=60, r=40, t=50, b=50), ) cal_combined = plot_with_table_style_title(cal_fig, "Calibration Curve (Train vs Validation)") # Combined ROC (Train/Val) roc_fig = go.Figure() added_roc = False if pos_scores_train is not None: y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train) roc_fig.add_trace(go.Scatter( x=fpr_tr, y=tpr_tr, mode="lines", name="Train", line=dict(color="#1f77b4", width=3), )) if threshold is not None and np.isfinite(thr_tr).any(): finite = np.isfinite(thr_tr) idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold)))) idx = int(np.nonzero(finite)[0][idx_local]) roc_fig.add_trace(go.Scatter( x=[fpr_tr[idx]], y=[tpr_tr[idx]], mode="markers", name="Train @ threshold", marker=dict(size=12, color="#1f77b4", symbol="x") )) added_roc = True if pos_scores_val is not None and y_true_val is not None: y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val) roc_fig.add_trace(go.Scatter( x=fpr_v, y=tpr_v, mode="lines", name="Validation", line=dict(color="#ff7f0e", width=3), )) if threshold is not None and np.isfinite(thr_v).any(): finite = np.isfinite(thr_v) idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold)))) idx = int(np.nonzero(finite)[0][idx_local]) 
                roc_fig.add_trace(go.Scatter(
                    x=[fpr_v[idx]], y=[tpr_v[idx]], mode="markers", name="Val @ threshold",
                    marker=dict(size=12, color="#ff7f0e", symbol="x")
                ))
            added_roc = True
        if added_roc:
            roc_fig.add_trace(go.Scatter(
                x=[0, 1], y=[0, 1], mode="lines",
                line=dict(dash="dash", width=2, color="#808080"),
                showlegend=False
            ))
            roc_fig.update_layout(
                title=None,
                xaxis_title="False Positive Rate",
                yaxis_title="True Positive Rate",
                template="plotly_white",
                legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
                margin=dict(l=60, r=20, t=60, b=60),
            )
            roc_combined = plot_with_table_style_title(roc_fig, "ROC Curve (Train vs Validation)")

        # Combined PR (Train/Val)
        pr_fig = go.Figure()
        added_pr = False
        if pos_scores_train is not None:
            y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
            prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train)
            pr_auc_tr = auc(rec_tr, prec_tr)
            pr_fig.add_trace(go.Scatter(
                x=rec_tr, y=prec_tr, mode="lines", name=f"Train (AUC={pr_auc_tr:.3f})",
                line=dict(color="#1f77b4", width=3),
            ))
            if threshold is not None and len(thr_tr):
                j = int(np.argmin(np.abs(thr_tr - float(threshold))))
                j = int(np.clip(j, 0, len(thr_tr) - 1))
                pr_fig.add_trace(go.Scatter(
                    x=[rec_tr[j + 1]], y=[prec_tr[j + 1]], mode="markers", name="Train @ threshold",
                    marker=dict(size=12, color="#1f77b4", symbol="x")
                ))
            added_pr = True
        if pos_scores_val is not None and y_true_val is not None:
            y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
            prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val)
            pr_auc_v = auc(rec_v, prec_v)
            pr_fig.add_trace(go.Scatter(
                x=rec_v, y=prec_v, mode="lines", name=f"Validation (AUC={pr_auc_v:.3f})",
                line=dict(color="#ff7f0e", width=3),
            ))
            if threshold is not None and len(thr_v):
                j = int(np.argmin(np.abs(thr_v - float(threshold))))
                j = int(np.clip(j, 0, len(thr_v) - 1))
                pr_fig.add_trace(go.Scatter(
                    x=[rec_v[j + 1]], y=[prec_v[j + 1]], mode="markers", name="Val @ threshold",
                    marker=dict(size=12, color="#ff7f0e", symbol="x")
                ))
            added_pr = True
        if added_pr:
            pr_fig.update_layout(
                title=None,
                xaxis_title="Recall",
                yaxis_title="Precision",
                template="plotly_white",
                legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
                margin=dict(l=60, r=20, t=60, b=60),
            )
            pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)")

        if pos_scores_val is not None and y_true_val is not None:
            y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
            fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val,
                                                  title="Threshold Plot (Validation)", user_threshold=threshold)
            threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)")

    # Multiclass OVR ROC (validation)
    if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None:
        classes_val = np.unique(y_true_val)
        fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val,
                                                            title="One-vs-Rest ROC (Validation)")
        mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)")

    # Prediction Confidence Histogram (train/val)
    conf_train = conf_val = None

    # Per-class accuracy bars
    if problem_type in ("binary", "multiclass") and pred_labels is not None:
        classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist()
        acc_vals = []
        for c in classes_for_bar:
            mask = y_true == c
            acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if
                            mask.any() else 0.0)
        bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4"))
        bar_fig.update_layout(
            title=None,
            template="plotly_white",
            xaxis=dict(title="Label", gridcolor="#eee"),
            yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]),
            margin=dict(l=50, r=20, t=60, b=50),
        )
        bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy")

    if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None:
        classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist()
        acc_vals_val = []
        for c in classes_for_bar_val:
            mask = y_true_val == c
            acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0)
        bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, marker_color="#ff7f0e"))
        bar_fig_val.update_layout(
            title=None,
            template="plotly_white",
            xaxis=dict(title="Label", gridcolor="#eee"),
            yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]),
            margin=dict(l=50, r=20, t=60, b=50),
        )
        bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy")

    # Assemble in requested order
    pieces: list[str] = []
    if perf_card:
        pieces.append(perf_card)
    for block in (threshold_val_plot, roc_combined, pr_combined):
        if block:
            pieces.append(block)
    # Remaining plots (keep existing order)
    for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, bar_train, bar_val):
        if block:
            pieces.append(block)
    # Learning curves should appear last in the tab
    for block in (acc_plot, loss_plot):
        if block:
            pieces.append(block)

    if not pieces:
        return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>"

    return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces)


def generate_learning_curve(
    estimator,
    X,
    y,
    scoring: str = "r2",
    cv_folds: int = 5,
    n_jobs: int = -1,
    train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10),
    title: str = "Learning Curve",
    path: Optional[str] = None,
) -> go.Figure:
    """
    Learning curve using sklearn.learning_curve, visualized with Plotly.
    """
    sizes, train_scores, test_scores = skl_learning_curve(
        estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes
    )
    train_mean = train_scores.mean(axis=1)
    train_std = train_scores.std(axis=1)
    test_mean = test_scores.mean(axis=1)
    test_std = test_scores.std(axis=1)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=sizes, y=train_mean, mode="lines+markers", name="Training score",
        error_y=dict(type="data", array=train_std, visible=True)
    ))
    fig.add_trace(go.Scatter(
        x=sizes, y=test_mean, mode="lines+markers", name="CV score",
        error_y=dict(type="data", array=test_std, visible=True)
    ))
    fig.update_layout(
        title=None,
        xaxis_title="Training examples",
        yaxis_title=scoring,
        template="plotly_white",
    )
    _save_plotly(fig, path)
    return fig


# =========================
# SHAP (Matplotlib-based)
# =========================


def generate_shap_summary_plot(
    shap_values,
    features: pd.DataFrame,
    title: str = "SHAP Summary Plot",
    path: Optional[str] = None
) -> None:
    """
    SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited;
    keep matplotlib for clarity and stability.
""" plt.figure(figsize=(10, 8)) shap.summary_plot(shap_values, features, show=False) plt.title(title) _save_matplotlib(path) def generate_shap_force_plot( explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None ) -> None: """ SHAP force plot (Matplotlib). """ shap_values = explainer(instance) plt.figure(figsize=(10, 4)) shap.plots.force(shap_values[0], show=False) plt.title(title) _save_matplotlib(path) def generate_shap_waterfall_plot( explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None ) -> None: """ SHAP waterfall plot (Matplotlib). """ shap_values = explainer(instance) plt.figure(figsize=(10, 6)) shap.plots.waterfall(shap_values[0], show=False) plt.title(title) _save_matplotlib(path) def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str: """ Return 'binary', 'multiclass', or 'regression'. Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques. """ # AutoGluon predictors usually expose .problem_type; be defensive. pt = getattr(predictor, "problem_type", None) if isinstance(pt, str): pt_l = pt.lower() if "regression" in pt_l: return "regression" if "binary" in pt_l: return "binary" if "multiclass" in pt_l or "multiclass" in pt_l: return "multiclass" y = df_train_full[label_column] if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10: return "regression" return "binary" if y.nunique() == 2 else "multiclass" def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]: """Make evaluate() outputs JSON/csv friendly floats.""" out = {} for k, v in d.items(): try: out[k] = float(v) except Exception: # keep only real-valued scalars pass return out def evaluate_all( predictor, df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame, label_column: str, problem_type: str, ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]: """ Run predictor.evaluate on train/val/test and normalize the result dicts to floats. MultiModalPredictor does not accept the `silent` kwarg, so call defensively. 
""" def _evaluate(df): try: return predictor.evaluate(df, silent=True) except TypeError: return predictor.evaluate(df) train_scores = _safe_floatify(_evaluate(df_train)) val_scores = _safe_floatify(_evaluate(df_val)) test_scores = _safe_floatify(_evaluate(df_test)) return train_scores, val_scores, test_scores def build_summary_html( predictor, df_train: pd.DataFrame, df_val: Optional[pd.DataFrame], df_test: Optional[pd.DataFrame], label_column: str, extra_run_rows: Optional[list[tuple[str, str]]] = None, class_balance_html: Optional[str] = None, perf_table_html: Optional[str] = None, ) -> str: sections = [] # Dataset Overview (first section in the tab) if class_balance_html: sections.append(f""" <section class="section"> <h2 class="section-title">Dataset Overview</h2> <div class="card"> {class_balance_html} </div> </section> """.strip()) # Performance Summary if perf_table_html: sections.append(f""" <section class="section"> <h2 class="section-title">Model Performance Summary</h2> <div class="card"> {perf_table_html} </div> </section> """.strip()) # Model Configuration # Remove Predictor type and Framework, and ensure Model Architecture is present base_rows: list[tuple[str, str]] = [] if extra_run_rows: # Remove any rows with keys 'Predictor type' or 'Framework' base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")]) def _fmt(v): if v is None or v == "": return "—" return _escape(str(v)) rows_html = "\n".join( f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>" for k, v in base_rows ) sections.append(f""" <section class="section"> <h2 class="section-title">Model Configuration</h2> <div class="card"> <table class="kv-table"> <thead><tr><th>Key</th><th>Value</th></tr></thead> <tbody> {rows_html} </tbody> </table> </div> </section> """.strip()) return "\n".join(sections).strip() def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str: """Build a visualization of feature importance.""" try: # Try to get feature importance from predictor fi = None if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance): try: fi = predictor.feature_importance(df_train) except Exception as e: return f"<p>Could not compute feature importance: {e}</p>" if fi is None or (isinstance(fi, pd.DataFrame) and fi.empty): return "<p>Feature importance not available for this model.</p>" # Format as a sortable table rows = [] if isinstance(fi, pd.DataFrame): fi = fi.sort_values("importance", ascending=False) for _, row in fi.iterrows(): feat = row.index[0] if isinstance(row.index, pd.Index) else row["feature"] imp = float(row["importance"]) rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>") else: # Handle other formats (dict, etc) for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True): rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>") if not rows: return "<p>No feature importance values available.</p>" table_html = f""" <table class="performance-summary"> <thead> <tr> <th class="sortable">Feature</th> <th class="sortable">Importance</th> </tr> </thead> <tbody> {"".join(rows)} </tbody> </table> """ return table_html except Exception as e: return f"<p>Error building feature importance visualization: {e}</p>" def build_test_html_and_plots( predictor, problem_type: str, df_test: pd.DataFrame, label_column: str, tmpdir: str, threshold: Optional[float] = None, ) -> Tuple[str, List[str]]: """ Create a test-summary section (with a placeholder 
    for metric rows) and a list of Plotly HTML divs.
    Returns: (html_template_with_{}, list_of_plot_divs)
    """
    plots: List[str] = []
    y_true = df_test[label_column].values
    classes = np.unique(y_true)

    # Try proba/labels where meaningful
    pred_labels = None
    pred_proba = None
    try:
        pred_labels = predictor.predict(df_test)
    except Exception:
        pass
    try:
        # MultiModalPredictor exposes predict_proba for classification problems.
        pred_proba = predictor.predict_proba(df_test)
    except Exception:
        pred_proba = None

    proba_arr = None
    if pred_proba is not None:
        if isinstance(pred_proba, pd.Series):
            proba_arr = pred_proba.to_numpy().reshape(-1, 1)
        elif isinstance(pred_proba, pd.DataFrame):
            proba_arr = pred_proba.to_numpy()
        else:
            proba_arr = np.asarray(pred_proba)

    # Thresholded labels for binary
    if problem_type == "binary" and threshold is not None and proba_arr is not None:
        pos_label, neg_label = classes.max(), classes.min()
        pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1]
        pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label)

    # Confusion matrix / per-class now reflect thresholded labels
    if problem_type in ("binary", "multiclass") and pred_labels is not None:
        cm_title = "Confusion Matrix"
        if threshold is not None and problem_type == "binary":
            thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".")
            cm_title = f"Confusion Matrix (Threshold = {thr_str})"
        fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title)
        plots.append(plot_with_table_style_title(fig_cm, cm_title))

        fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics")
        plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics"))

    # ROC/PR where possible — choose positive-class scores safely
    pos_label = classes.max()  # or set explicitly, e.g., 1 or "yes"

    if isinstance(pred_proba, pd.DataFrame):
        proba_arr = pred_proba.to_numpy()
        if pos_label in pred_proba.columns:
            pos_idx = list(pred_proba.columns).index(pos_label)
        else:
            pos_idx = -1  # fallback to last column
    elif isinstance(pred_proba, pd.Series):
        proba_arr = pred_proba.to_numpy().reshape(-1, 1)
        pos_idx = 0
    else:
        proba_arr = np.asarray(pred_proba) if pred_proba is not None else None
        pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0

    if proba_arr is not None:
        y_bin = (y_true == pos_label).astype(int)
        pos_scores = (
            proba_arr.reshape(-1)
            if proba_arr.ndim == 1 or proba_arr.shape[1] == 1
            else proba_arr[:, pos_idx]
        )
        fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold)
        plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))

        fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold)
        plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))

        # Additional diagnostics aligned with ImageLearner style
        if problem_type == "binary":
            conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)")
            plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)"))
        else:
            conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)")
            plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)"))

    if (problem_type == "multiclass" and proba_arr is not None and
            proba_arr.ndim >= 2):
        fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)")
        plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)"))

    # Regression visuals
    if problem_type == "regression":
        if pred_labels is None:
            pred_labels = predictor.predict(df_test)

        fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual")
        plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual"))

        fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot")
        plots.append(plot_with_table_style_title(fig_res, "Residual Plot"))

        fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram")
        plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram"))

        fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration")
        plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration"))

    # Small HTML template with placeholder for metric rows the caller fills in
    test_html_template = """
<h2>Test Performance Summary</h2>
<table class="performance-summary">
  <thead><tr><th>Metric</th><th>Test</th></tr></thead>
  <tbody>{}</tbody>
</table>
"""
    return test_html_template, plots


def build_feature_html(
    predictor,
    df_train: pd.DataFrame,
    label_column: str,
    include_modalities: bool = True,  # ← NEW
    include_class_balance: bool = True,  # ← NEW
) -> str:
    sections = []

    # (Typical feature importance content…)
    fi_html = build_feature_importance_html(predictor, df_train, label_column)
    sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>")

    # Previously: Modalities & Inputs and/or Class Balance may have been here.
    # Only render them if flags are True.
    if include_modalities:
        from report_utils import build_modalities_html
        modalities_html = build_modalities_html(predictor, df_train, label_column)
        sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>")

    if include_class_balance:
        from report_utils import build_class_balance_html
        cb_html = build_class_balance_html(df_train, label_column)
        sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>")

    return "\n".join(sections)


def assemble_full_html_report(
    summary_html: str,
    train_html: str,
    test_html: str,
    plots: List[str],
    feature_html: str,
) -> str:
    """
    Wrap the four tabs using report_utils.build_tabbed_html and return full HTML.
    """
    # Append plots under the Test tab (already wrapped with titles)
    test_full = test_html + "".join(plots)

    tabs = build_tabbed_html(summary_html, train_html, test_full, feature_html, explainer_html=None)

    html_out = get_html_template()
    # 🔧 Ensure Plotly JS is available (we render plots with include_plotlyjs=False)
    html_out += '\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n'
    # Optional: centering tweaks
    html_out += """
<style>
.plotly-center { display: flex; justify-content: center; }
.plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
.js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
</style>
"""
    # Help modal HTML/JS
    html_out += get_metrics_help_modal()
    html_out += tabs
    html_out += get_html_closing()
    return html_out
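
Usage sketch (not part of the module): a minimal, self-contained way the standalone plotting helpers above might be exercised, assuming this file is importable as `plot_logic` and using synthetic labels/scores as placeholder inputs.

```python
# Hedged example: the module name `plot_logic`, the synthetic arrays, the 0.5
# threshold, and the output file name are assumptions for illustration only.
import numpy as np

from plot_logic import (
    _save_plotly,
    generate_confusion_matrix_plot,
    generate_roc_curve_plot,
    generate_threshold_plot,
    plot_with_table_style_title,
)

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)                               # synthetic binary labels
y_score = np.clip(y_true * 0.6 + rng.normal(0.3, 0.25, 200), 0, 1)  # noisy positive-class scores
y_pred = (y_score >= 0.5).astype(int)

# Each helper returns a plotly Figure with the title left blank on purpose;
# plot_with_table_style_title adds the report-style <h2> header around it.
fig_cm = generate_confusion_matrix_plot(y_true, y_pred)
fig_roc = generate_roc_curve_plot(y_true, y_score, marker_threshold=0.5)
fig_thr = generate_threshold_plot(y_true, y_score, user_threshold=0.5)

# _save_plotly writes interactive HTML for .html paths (or an image via kaleido otherwise).
_save_plotly(fig_roc, "roc_curve.html")

# HTML snippet suitable for embedding in the tabbed report.
snippet = plot_with_table_style_title(fig_cm, "Confusion Matrix")
```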
