Mercurial repository: goeckslab / multimodal_learner
plot_logic.py @ 0:375c36923da1 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
| plot_logic.py: new file at 0:375c36923da1 (compared against the empty parent -1:000000000000) | |
|---|---|
| 1 from __future__ import annotations | |
| 2 | |
| 3 import html | |
| 4 import os | |
| 5 from html import escape as _escape | |
| 6 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union | |
| 7 | |
| 8 import matplotlib.pyplot as plt | |
| 9 import numpy as np | |
| 10 import pandas as pd | |
| 11 import plotly.express as px | |
| 12 import plotly.graph_objects as go | |
| 13 import shap | |
| 14 from feature_help_modal import get_metrics_help_modal | |
| 15 from report_utils import build_tabbed_html, get_html_closing, get_html_template | |
| 16 from sklearn.calibration import calibration_curve | |
| 17 from sklearn.metrics import ( | |
| 18 auc, | |
| 19 average_precision_score, | |
| 20 classification_report, | |
| 21 confusion_matrix, | |
| 22 log_loss, | |
| 23 precision_recall_curve, | |
| 24 roc_auc_score, | |
| 25 roc_curve, | |
| 26 ) | |
| 27 from sklearn.model_selection import learning_curve as skl_learning_curve | |
| 28 from sklearn.preprocessing import label_binarize | |
| 29 | |
| 30 # ========================= | |
| 31 # Utilities | |
| 32 # ========================= | |
| 33 | |
| 34 | |
| 35 def plot_with_table_style_title(fig, title: str) -> str: | |
| 36 """ | |
| 37 Render a Plotly figure with a report-style <h2> header so it matches the | |
| 38 green table section headers. | |
| 39 """ | |
| 40 # kill Plotly’s built-in title | |
| 41 fig.update_layout(title=None) | |
| 42 | |
| 43 # figure HTML without PlotlyJS (we load it once globally) | |
| 44 plot_html = fig.to_html(full_html=False, include_plotlyjs=False) | |
| 45 | |
| 46 # use <h2> — your CSS already styles <h2> like the table headers | |
| 47 return f""" | |
| 48 <h2>{html.escape(title)}</h2> | |
| 49 <div class="plotly-center">{plot_html}</div> | |
| 50 """.strip() | |
| 51 | |
| 52 | |
| 53 def _save_plotly(fig: go.Figure, path: Optional[str]) -> None: | |
| 54 """ | |
| 55 Save a Plotly figure. If `path` ends with `.html`, write interactive HTML. | |
| 56 Otherwise write a static image via Kaleido (png/jpg/jpeg/webp, etc.). | |
| 57 If `path` is None, do nothing (the caller may choose to display the figure instead). | |
| 58 """ | |
| 59 if not path: | |
| 60 return | |
| 61 ext = os.path.splitext(path)[1].lower() | |
| 62 if ext == ".html": | |
| 63 fig.write_html(path, include_plotlyjs="cdn", full_html=True) | |
| 64 else: | |
| 65 # Requires kaleido: pip install -U kaleido | |
| 66 fig.write_image(path) | |
| 67 | |
| 68 | |
| 69 def _save_matplotlib(path: Optional[str]) -> None: | |
| 70 """Save current Matplotlib figure if `path` is provided, else show().""" | |
| 71 if path: | |
| 72 plt.savefig(path, bbox_inches="tight") | |
| 73 plt.close() | |
| 74 else: | |
| 75 plt.show() | |
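| # Usage sketch for the helpers above (hypothetical figure and output paths): | |
| # fig = go.Figure(go.Scatter(x=[0, 1], y=[0, 1], mode="lines")) | |
| # section_html = plot_with_table_style_title(fig, "Example Section")  # <h2> header + centered plot div | |
| # _save_plotly(fig, "example.html")  # interactive HTML (Plotly.js from CDN) | |
| # _save_plotly(fig, "example.png")  # static export; requires the kaleido package | |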
| 76 | |
| 77 # ========================= | |
| 78 # Classification Plots | |
| 79 # ========================= | |
| 80 | |
| 81 | |
| 82 def generate_confusion_matrix_plot( | |
| 83 y_true, | |
| 84 y_pred, | |
| 85 title: str = "Confusion Matrix", | |
| 86 ) -> go.Figure: | |
| 87 y_true = np.asarray(y_true) | |
| 88 y_pred = np.asarray(y_pred) | |
| 89 | |
| 90 # Class order (works for strings or numbers) | |
| 91 labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist() | |
| 92 cm = confusion_matrix(y_true, y_pred, labels=labels) | |
| 93 max_val = cm.max() if cm.size else 0 | |
| 94 | |
| 95 # Use categorical axes by passing string labels for x/y | |
| 96 cats = [str(label) for label in labels] | |
| 97 total = int(cm.sum()) | |
| 98 | |
| 99 fig = go.Figure( | |
| 100 data=go.Heatmap( | |
| 101 z=cm, | |
| 102 x=cats, # categorical x | |
| 103 y=cats, # categorical y | |
| 104 colorscale="Blues", | |
| 105 showscale=True, | |
| 106 colorbar=dict(title="Count"), | |
| 107 xgap=2, | |
| 108 ygap=2, | |
| 109 hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>", | |
| 110 zmin=0 | |
| 111 ) | |
| 112 ) | |
| 113 | |
| 114 # Add annotations with count and percentage (text switches to white on dark cells for contrast) | |
| 115 annotations = [] | |
| 116 for i in range(cm.shape[0]): | |
| 117 for j in range(cm.shape[1]): | |
| 118 val = int(cm[i, j]) | |
| 119 pct = (val / total * 100) if total > 0 else 0 | |
| 120 text_color = "white" if max_val and val > (max_val / 2) else "black" | |
| 121 # Count annotation (bold, bottom) | |
| 122 annotations.append( | |
| 123 dict( | |
| 124 x=cats[j], | |
| 125 y=cats[i], | |
| 126 text=f"<b>{val}</b>", | |
| 127 showarrow=False, | |
| 128 font=dict(color=text_color, size=14), | |
| 129 xanchor="center", | |
| 130 yanchor="bottom", | |
| 131 yshift=2 | |
| 132 ) | |
| 133 ) | |
| 134 # Percentage annotation (top) | |
| 135 annotations.append( | |
| 136 dict( | |
| 137 x=cats[j], | |
| 138 y=cats[i], | |
| 139 text=f"{pct:.1f}%", | |
| 140 showarrow=False, | |
| 141 font=dict(color=text_color, size=13), | |
| 142 xanchor="center", | |
| 143 yanchor="top", | |
| 144 yshift=-2 | |
| 145 ) | |
| 146 ) | |
| 147 | |
| 148 fig.update_layout( | |
| 149 title=None, | |
| 150 xaxis_title="Predicted label", | |
| 151 yaxis_title="True label", | |
| 152 xaxis=dict(type="category"), | |
| 153 yaxis=dict(type="category", autorange="reversed"), # typical CM orientation | |
| 154 margin=dict(l=80, r=20, t=40, b=80), | |
| 155 template="plotly_white", | |
| 156 plot_bgcolor="white", | |
| 157 paper_bgcolor="white", | |
| 158 annotations=annotations | |
| 159 ) | |
| 160 return fig | |
| 161 | |
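| # Usage sketch (toy labels; illustrative only): | |
| # cm_fig = generate_confusion_matrix_plot(["cat", "dog", "dog", "cat"], ["cat", "dog", "cat", "cat"]) | |
| # cm_html = plot_with_table_style_title(cm_fig, "Confusion Matrix (Train)") | |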
| 162 | |
| 163 def generate_roc_curve_plot( | |
| 164 y_true_bin: np.ndarray, | |
| 165 y_score: np.ndarray, | |
| 166 title: str = "ROC Curve", | |
| 167 marker_threshold: float | None = None, | |
| 168 ) -> go.Figure: | |
| 169 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) | |
| 170 y_score = np.asarray(y_score).astype(float).reshape(-1) | |
| 171 | |
| 172 fpr, tpr, thr = roc_curve(y_true_bin, y_score) | |
| 173 roc_auc = auc(fpr, tpr) | |
| 174 | |
| 175 fig = go.Figure() | |
| 176 fig.add_trace(go.Scatter( | |
| 177 x=fpr, y=tpr, mode="lines", | |
| 178 name=f"ROC (AUC={roc_auc:.3f})", | |
| 179 line=dict(width=3) | |
| 180 )) | |
| 181 | |
| 182 # 45° chance line (no legend to keep it clean) | |
| 183 fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", | |
| 184 line=dict(dash="dash", width=2, color="#888"), showlegend=False)) | |
| 185 | |
| 186 # Optional marker at the user threshold | |
| 187 if marker_threshold is not None and len(thr): | |
| 188 # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0 | |
| 189 finite = np.isfinite(thr) | |
| 190 if np.any(finite): | |
| 191 idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold)))) | |
| 192 idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices | |
| 193 x_m, y_m = float(fpr[idx]), float(tpr[idx]) | |
| 194 | |
| 195 fig.add_trace( | |
| 196 go.Scatter( | |
| 197 x=[x_m], y=[y_m], | |
| 198 mode="markers", | |
| 199 name=f"@ {float(marker_threshold):.2f}", | |
| 200 marker=dict(size=12, color="red", symbol="x") | |
| 201 ) | |
| 202 ) | |
| 203 fig.add_annotation( | |
| 204 x=0.02, y=0.98, xref="paper", yref="paper", | |
| 205 text=f"threshold = {float(marker_threshold):.2f}", | |
| 206 showarrow=False, | |
| 207 font=dict(color="black", size=12), | |
| 208 align="left" | |
| 209 ) | |
| 210 | |
| 211 fig.update_layout( | |
| 212 title=None, | |
| 213 xaxis_title="False Positive Rate", | |
| 214 yaxis_title="True Positive Rate", | |
| 215 template="plotly_white", | |
| 216 legend=dict(x=1, y=0, xanchor="right"), | |
| 217 margin=dict(l=60, r=20, t=60, b=60), | |
| 218 ) | |
| 219 return fig | |
| 220 | |
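| # Usage sketch (synthetic binary labels/scores; marker_threshold is optional): | |
| # rng = np.random.default_rng(0) | |
| # y = rng.integers(0, 2, size=200) | |
| # scores = np.clip(0.6 * y + rng.normal(0.2, 0.2, size=200), 0.0, 1.0) | |
| # roc_fig = generate_roc_curve_plot(y, scores, marker_threshold=0.5) | |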
| 221 | |
| 222 def generate_pr_curve_plot( | |
| 223 y_true_bin: np.ndarray, | |
| 224 y_score: np.ndarray, | |
| 225 title: str = "Precision–Recall Curve", | |
| 226 marker_threshold: float | None = None, | |
| 227 ) -> go.Figure: | |
| 228 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) | |
| 229 y_score = np.asarray(y_score).astype(float).reshape(-1) | |
| 230 | |
| 231 precision, recall, thr = precision_recall_curve(y_true_bin, y_score) | |
| 232 pr_auc = auc(recall, precision) | |
| 233 | |
| 234 fig = go.Figure() | |
| 235 fig.add_trace(go.Scatter( | |
| 236 x=recall, y=precision, mode="lines", | |
| 237 name=f"PR (AUC={pr_auc:.3f})", | |
| 238 line=dict(width=3) | |
| 239 )) | |
| 240 | |
| 241 # Optional marker at the user threshold | |
| 242 if marker_threshold is not None and len(thr): | |
| 243 # In PR, thresholds has length len(precision)-1; sklearn defines precision[j]/recall[j] as the point for thr[j]. | |
| 244 j = int(np.argmin(np.abs(thr - float(marker_threshold)))) | |
| 245 j = int(np.clip(j, 0, len(thr) - 1)) | |
| 246 x_m, y_m = float(recall[j]), float(precision[j]) | |
| 247 | |
| 248 fig.add_trace( | |
| 249 go.Scatter( | |
| 250 x=[x_m], y=[y_m], | |
| 251 mode="markers", | |
| 252 name=f"@ {float(marker_threshold):.2f}", | |
| 253 marker=dict(size=12, color="red", symbol="x") | |
| 254 ) | |
| 255 ) | |
| 256 fig.add_annotation( | |
| 257 x=0.02, y=0.98, xref="paper", yref="paper", | |
| 258 text=f"threshold = {float(marker_threshold):.2f}", | |
| 259 showarrow=False, | |
| 260 font=dict(color="black", size=12), | |
| 261 align="left" | |
| 262 ) | |
| 263 | |
| 264 fig.update_layout( | |
| 265 title=None, | |
| 266 xaxis_title="Recall", | |
| 267 yaxis_title="Precision", | |
| 268 template="plotly_white", | |
| 269 legend=dict(x=1, y=0, xanchor="right"), | |
| 270 margin=dict(l=60, r=20, t=60, b=60), | |
| 271 ) | |
| 272 return fig | |
| 273 | |
| 274 | |
| 275 def generate_calibration_plot( | |
| 276 y_true_bin: np.ndarray, | |
| 277 y_prob: np.ndarray, | |
| 278 n_bins: int = 10, | |
| 279 title: str = "Calibration Plot", | |
| 280 path: Optional[str] = None, | |
| 281 ) -> go.Figure: | |
| 282 """ | |
| 283 Binary calibration curve (Plotly). | |
| 284 """ | |
| 285 prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform") | |
| 286 fig = go.Figure() | |
| 287 fig.add_trace(go.Scatter( | |
| 288 x=prob_pred, y=prob_true, mode="lines+markers", name="Model", | |
| 289 line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4") | |
| 290 )) | |
| 291 fig.add_trace( | |
| 292 go.Scatter( | |
| 293 x=[0, 1], y=[0, 1], | |
| 294 mode="lines", | |
| 295 line=dict(dash="dash", color="#808080", width=2), | |
| 296 name="Perfect" | |
| 297 ) | |
| 298 ) | |
| 299 fig.update_layout( | |
| 300 title=None, | |
| 301 xaxis_title="Predicted Probability", | |
| 302 yaxis_title="Observed Probability", | |
| 303 yaxis=dict(range=[0, 1]), | |
| 304 xaxis=dict(range=[0, 1]), | |
| 305 template="plotly_white", | |
| 306 margin=dict(l=60, r=40, t=50, b=50), | |
| 307 ) | |
| 308 _save_plotly(fig, path) | |
| 309 return fig | |
| 310 | |
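| # Usage sketch (binary 0/1 labels and predicted probabilities, e.g. the synthetic y/scores above): | |
| # cal_fig = generate_calibration_plot(y, scores, n_bins=10) | |
| # generate_calibration_plot(y, scores, path="calibration.html")  # optionally also writes HTML | |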
| 311 | |
| 312 def generate_threshold_plot( | |
| 313 y_true_bin: np.ndarray, | |
| 314 y_prob: np.ndarray, | |
| 315 title: str = "Threshold Plot", | |
| 316 user_threshold: float | None = None, | |
| 317 ) -> go.Figure: | |
| 318 y_true = np.asarray(y_true_bin, dtype=int).ravel() | |
| 319 p = np.asarray(y_prob, dtype=float).ravel() | |
| 320 p = np.nan_to_num(p, nan=0.0) | |
| 321 p = np.clip(p, 0.0, 1.0) | |
| 322 | |
| 323 def _compute_metrics(thresholds: np.ndarray): | |
| 324 """Vectorized-ish helper to compute precision/recall/F1/queue rate arrays.""" | |
| 325 prec, rec, f1, qrate = [], [], [], [] | |
| 326 for t in thresholds: | |
| 327 yhat = (p >= t).astype(int) | |
| 328 tp = int(((yhat == 1) & (y_true == 1)).sum()) | |
| 329 fp = int(((yhat == 1) & (y_true == 0)).sum()) | |
| 330 fn = int(((yhat == 0) & (y_true == 1)).sum()) | |
| 331 | |
| 332 pr = tp / (tp + fp) if (tp + fp) else np.nan # undefined when no predicted positives | |
| 333 rc = tp / (tp + fn) if (tp + fn) else 0.0 | |
| 334 f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0 | |
| 335 q = float(yhat.mean()) | |
| 336 | |
| 337 prec.append(pr) | |
| 338 rec.append(rc) | |
| 339 f1.append(f) | |
| 340 qrate.append(q) | |
| 341 return np.asarray(prec, dtype=float), np.asarray(rec, dtype=float), np.asarray(f1, dtype=float), np.asarray(qrate, dtype=float) | |
| 342 | |
| 343 # Use uniform threshold grid for plotting (0 to 1 in steps of 0.01) | |
| 344 th = np.linspace(0.0, 1.0, 101) | |
| 345 prec, rec, f1_arr, qrate = _compute_metrics(th) | |
| 346 | |
| 347 # Compute F1*-optimal threshold using actual score distribution (more precise than grid) | |
| 348 cand_th = np.unique(np.concatenate(([0.0, 1.0], p))) | |
| 349 # cap to a reasonable size by falling back to a uniform grid when there are many distinct scores | |
| 350 if cand_th.size > 2000: | |
| 351 cand_th = np.linspace(0.0, 1.0, 2001) | |
| 352 _, _, f1_cand, _ = _compute_metrics(cand_th) | |
| 353 | |
| 354 if np.all(np.isnan(f1_cand)): | |
| 355 t_star = 0.5 # fallback when no valid F1 can be computed | |
| 356 else: | |
| 357 f1_max = np.nanmax(f1_cand) | |
| 358 best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0] | |
| 359 # pick the middle of the best candidates to avoid biasing toward 0 | |
| 360 best_idx = int(best_idxs[len(best_idxs) // 2]) | |
| 361 t_star = float(cand_th[best_idx]) | |
| 362 | |
| 363 # Replace NaNs for plotting (set to 0 where precision is undefined) | |
| 364 prec_plot = np.nan_to_num(prec, nan=0.0) | |
| 365 | |
| 366 fig = go.Figure() | |
| 367 | |
| 368 # Precision (blue line) | |
| 369 fig.add_trace(go.Scatter( | |
| 370 x=th, y=prec_plot, mode="lines", name="Precision", | |
| 371 line=dict(width=3, color="#1f77b4"), | |
| 372 hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>" | |
| 373 )) | |
| 374 | |
| 375 # Recall (orange line) | |
| 376 fig.add_trace(go.Scatter( | |
| 377 x=th, y=rec, mode="lines", name="Recall", | |
| 378 line=dict(width=3, color="#ff7f0e"), | |
| 379 hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>" | |
| 380 )) | |
| 381 | |
| 382 # F1 (green line) | |
| 383 fig.add_trace(go.Scatter( | |
| 384 x=th, y=f1_arr, mode="lines", name="F1", | |
| 385 line=dict(width=3, color="#2ca02c"), | |
| 386 hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>" | |
| 387 )) | |
| 388 | |
| 389 # Queue Rate (grey dashed line) | |
| 390 fig.add_trace(go.Scatter( | |
| 391 x=th, y=qrate, mode="lines", name="Queue Rate", | |
| 392 line=dict(width=2, color="#808080", dash="dash"), | |
| 393 hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>" | |
| 394 )) | |
| 395 | |
| 396 # F1*-optimal threshold marker (dashed vertical line) | |
| 397 fig.add_vline( | |
| 398 x=t_star, | |
| 399 line_width=2, | |
| 400 line_dash="dash", | |
| 401 line_color="black", | |
| 402 annotation_text=f"t* = {t_star:.2f}", | |
| 403 annotation_position="top" | |
| 404 ) | |
| 405 | |
| 406 # User threshold (solid red line) if provided | |
| 407 if user_threshold is not None: | |
| 408 fig.add_vline( | |
| 409 x=float(user_threshold), | |
| 410 line_width=2, | |
| 411 line_color="red", | |
| 412 annotation_text=f"threshold = {float(user_threshold):.2f}", | |
| 413 annotation_position="top" | |
| 414 ) | |
| 415 | |
| 416 fig.update_layout( | |
| 417 title=None, | |
| 418 template="plotly_white", | |
| 419 xaxis=dict( | |
| 420 title="Discrimination Threshold", | |
| 421 range=[0, 1], | |
| 422 gridcolor="#e0e0e0", | |
| 423 showgrid=True, | |
| 424 zeroline=False | |
| 425 ), | |
| 426 yaxis=dict( | |
| 427 title="Score", | |
| 428 range=[0, 1], | |
| 429 gridcolor="#e0e0e0", | |
| 430 showgrid=True, | |
| 431 zeroline=False | |
| 432 ), | |
| 433 legend=dict( | |
| 434 orientation="h", | |
| 435 yanchor="bottom", | |
| 436 y=1.02, | |
| 437 xanchor="right", | |
| 438 x=1.0 | |
| 439 ), | |
| 440 margin=dict(l=60, r=20, t=40, b=50), | |
| 441 plot_bgcolor="white", | |
| 442 paper_bgcolor="white", | |
| 443 ) | |
| 444 return fig | |
| 445 | |
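| # Usage sketch: curves are evaluated on a 0..1 grid and t* marks the F1-optimal threshold. | |
| # thr_fig = generate_threshold_plot(y, scores, user_threshold=0.5) | |
| # thr_html = plot_with_table_style_title(thr_fig, "Threshold Plot (Validation)") | |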
| 446 | |
| 447 def generate_per_class_metrics_plot( | |
| 448 y_true: Sequence, | |
| 449 y_pred: Sequence, | |
| 450 metrics: Sequence[str] = ("precision", "recall", "f1_score"), | |
| 451 title: str = "Classification Report", | |
| 452 path: Optional[str] = None, | |
| 453 ) -> go.Figure: | |
| 454 """ | |
| 455 Per-class metrics heatmap (Plotly), similar to sklearn's classification report. | |
| 456 Rows = classes, columns = metrics; cell text shows the value (0–1). | |
| 457 """ | |
| 458 # Map display names -> sklearn keys | |
| 459 key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"} | |
| 460 report = classification_report( | |
| 461 y_true, y_pred, output_dict=True, zero_division=0 | |
| 462 ) | |
| 463 | |
| 464 # Order classes sensibly (numeric if possible, else lexical) | |
| 465 def _sort_key(x): | |
| 466 try: | |
| 467 return (0, float(x)) | |
| 468 except Exception: | |
| 469 return (1, str(x)) | |
| 470 | |
| 471 # Use all classes seen in y_true or y_pred (so rows don't jump around) | |
| 472 uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key) | |
| 473 classes = [str(c) for c in uniq] | |
| 474 | |
| 475 # Build Z matrix (rows=classes, cols=metrics) | |
| 476 used_metrics = [key_map.get(m, m) for m in metrics] | |
| 477 z = [] | |
| 478 for c in classes: | |
| 479 row = report.get(c, {}) | |
| 480 z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics]) | |
| 481 z = np.array(z, dtype=float) | |
| 482 | |
| 483 # Pretty cell labels | |
| 484 z_text = [[f"{v:.2f}" for v in r] for r in z] | |
| 485 | |
| 486 fig = go.Figure( | |
| 487 data=go.Heatmap( | |
| 488 z=z, | |
| 489 x=list(metrics), # keep display names ("precision", "recall", "f1_score") | |
| 490 y=classes, # classes as strings | |
| 491 colorscale="Reds", | |
| 492 zmin=0.0, | |
| 493 zmax=1.0, | |
| 494 colorbar=dict(title="Value"), | |
| 495 text=z_text, | |
| 496 texttemplate="%{text}", | |
| 497 hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>", | |
| 498 ) | |
| 499 ) | |
| 500 fig.update_layout( | |
| 501 title=None, | |
| 502 xaxis_title="", | |
| 503 yaxis_title="Class", | |
| 504 template="plotly_white", | |
| 505 margin=dict(l=60, r=60, t=70, b=40), | |
| 506 ) | |
| 507 | |
| 508 _save_plotly(fig, path) | |
| 509 return fig | |
| 510 | |
| 511 | |
| 512 def generate_multiclass_roc_curve_plot( | |
| 513 y_true: Sequence, | |
| 514 y_prob: np.ndarray, | |
| 515 classes: Sequence, | |
| 516 title: str = "Multiclass ROC Curve", | |
| 517 path: Optional[str] = None, | |
| 518 ) -> go.Figure: | |
| 519 """ | |
| 520 One-vs-rest ROC curves for multiclass (Plotly). | |
| 521 Handles binary passed as 2-column probs as well. | |
| 522 """ | |
| 523 y_true = np.asarray(y_true) | |
| 524 y_prob = np.asarray(y_prob) | |
| 525 | |
| 526 # Normalize to shape (n_samples, n_classes) | |
| 527 if y_prob.ndim == 1 or y_prob.shape[1] == 1: | |
| 528 y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)]) | |
| 529 | |
| 530 y_true_bin = label_binarize(y_true, classes=classes) | |
| 531 if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2: | |
| 532 y_true_bin = np.hstack([1 - y_true_bin, y_true_bin]) | |
| 533 | |
| 534 if y_prob.shape[1] != y_true_bin.shape[1]: | |
| 535 raise ValueError( | |
| 536 f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}." | |
| 537 ) | |
| 538 | |
| 539 fig = go.Figure() | |
| 540 for i, cls in enumerate(classes): | |
| 541 fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i]) | |
| 542 auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i]) | |
| 543 fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})")) | |
| 544 | |
| 545 fig.add_trace( | |
| 546 go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False) | |
| 547 ) | |
| 548 fig.update_layout( | |
| 549 title=None, | |
| 550 xaxis_title="False Positive Rate", | |
| 551 yaxis_title="True Positive Rate", | |
| 552 template="plotly_white", | |
| 553 ) | |
| 554 _save_plotly(fig, path) | |
| 555 return fig | |
| 556 | |
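| # Usage sketch (3-class toy problem; rows of y_prob sum to ~1): | |
| # classes = ["a", "b", "c"] | |
| # y_true_mc = ["a", "b", "c", "a"] | |
| # y_prob_mc = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6], [0.5, 0.3, 0.2]]) | |
| # mc_roc_fig = generate_multiclass_roc_curve_plot(y_true_mc, y_prob_mc, classes) | |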
| 557 | |
| 558 def generate_multiclass_pr_curve_plot( | |
| 559 y_true: Sequence, | |
| 560 y_prob: np.ndarray, | |
| 561 classes: Optional[Sequence] = None, | |
| 562 title: str = "Precision–Recall Curve", | |
| 563 path: Optional[str] = None, | |
| 564 ) -> go.Figure: | |
| 565 """ | |
| 566 Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR. | |
| 567 """ | |
| 568 y_true = np.asarray(y_true) | |
| 569 y_prob = np.asarray(y_prob) | |
| 570 fig = go.Figure() | |
| 571 | |
| 572 if classes is None or len(classes) == 2: | |
| 573 precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1]) | |
| 574 ap = average_precision_score(y_true, y_prob[:, 1]) | |
| 575 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}")) | |
| 576 else: | |
| 577 for i, cls in enumerate(classes): | |
| 578 y_true_bin = (y_true == cls).astype(int) | |
| 579 y_prob_cls = y_prob[:, i] | |
| 580 precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls) | |
| 581 ap = average_precision_score(y_true_bin, y_prob_cls) | |
| 582 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})")) | |
| 583 | |
| 584 fig.update_layout( | |
| 585 title=None, | |
| 586 xaxis_title="Recall", | |
| 587 yaxis_title="Precision", | |
| 588 yaxis=dict(range=[0, 1]), | |
| 589 xaxis=dict(range=[0, 1]), | |
| 590 template="plotly_white", | |
| 591 ) | |
| 592 _save_plotly(fig, path) | |
| 593 return fig | |
| 594 | |
| 595 | |
| 596 def generate_metric_comparison_bar( | |
| 597 metrics_scores: Mapping[str, Sequence[float]], | |
| 598 phases: Sequence[str] = ("train", "val", "test"), | |
| 599 title: str = "Metric Comparison Across Phases", | |
| 600 path: Optional[str] = None, | |
| 601 ) -> go.Figure: | |
| 602 """ | |
| 603 Grouped bar chart comparing metrics across phases (Plotly). | |
| 604 metrics_scores: {metric_name: [train, val, test]} | |
| 605 """ | |
| 606 df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"}) | |
| 607 df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score") | |
| 608 fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None) | |
| 609 ymax = max(1.0, df_m["Score"].max() * 1.05) | |
| 610 fig.update_yaxes(range=[0, ymax]) | |
| 611 fig.update_layout(template="plotly_white") | |
| 612 _save_plotly(fig, path) | |
| 613 return fig | |
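| # Usage sketch of the expected input shape, {metric_name: [train, val, test]}: | |
| # scores_by_metric = {"accuracy": [0.95, 0.91, 0.90], "f1": [0.94, 0.90, 0.89]} | |
| # bar_fig = generate_metric_comparison_bar(scores_by_metric, phases=("train", "val", "test")) | |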
| 614 | |
| 615 # ========================= | |
| 616 # Regression Plots | |
| 617 # ========================= | |
| 618 | |
| 619 | |
| 620 def generate_scatter_plot( | |
| 621 y_true: Sequence[float], | |
| 622 y_pred: Sequence[float], | |
| 623 title: str = "Predicted vs Actual", | |
| 624 path: Optional[str] = None, | |
| 625 ) -> go.Figure: | |
| 626 """ | |
| 627 Predicted vs. Actual scatter with y=x reference (Plotly). | |
| 628 """ | |
| 629 y_true = np.asarray(y_true) | |
| 630 y_pred = np.asarray(y_pred) | |
| 631 vmin = float(min(np.min(y_true), np.min(y_pred))) | |
| 632 vmax = float(max(np.max(y_true), np.max(y_pred))) | |
| 633 | |
| 634 fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None) | |
| 635 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal")) | |
| 636 fig.update_layout(template="plotly_white") | |
| 637 _save_plotly(fig, path) | |
| 638 return fig | |
| 639 | |
| 640 | |
| 641 def generate_residual_plot( | |
| 642 y_true: Sequence[float], | |
| 643 y_pred: Sequence[float], | |
| 644 title: str = "Residual Plot", | |
| 645 path: Optional[str] = None, | |
| 646 ) -> go.Figure: | |
| 647 """ | |
| 648 Residuals vs Predicted (Plotly). | |
| 649 """ | |
| 650 y_true = np.asarray(y_true) | |
| 651 y_pred = np.asarray(y_pred) | |
| 652 residuals = y_true - y_pred | |
| 653 | |
| 654 fig = px.scatter(x=y_pred, y=residuals, opacity=0.6, | |
| 655 labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"}, | |
| 656 title=None) | |
| 657 fig.add_hline(y=0, line_dash="dash") | |
| 658 fig.update_layout(template="plotly_white") | |
| 659 _save_plotly(fig, path) | |
| 660 return fig | |
| 661 | |
| 662 | |
| 663 def generate_residual_histogram( | |
| 664 y_true: Sequence[float], | |
| 665 y_pred: Sequence[float], | |
| 666 bins: int = 30, | |
| 667 title: str = "Residual Histogram", | |
| 668 path: Optional[str] = None, | |
| 669 ) -> go.Figure: | |
| 670 """ | |
| 671 Residuals histogram (Plotly). | |
| 672 """ | |
| 673 residuals = np.asarray(y_true) - np.asarray(y_pred) | |
| 674 fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None) | |
| 675 fig.update_layout(yaxis_title="Frequency", template="plotly_white") | |
| 676 _save_plotly(fig, path) | |
| 677 return fig | |
| 678 | |
| 679 | |
| 680 def generate_regression_calibration_plot( | |
| 681 y_true: Sequence[float], | |
| 682 y_pred: Sequence[float], | |
| 683 num_bins: int = 10, | |
| 684 title: str = "Regression Calibration Plot", | |
| 685 path: Optional[str] = None, | |
| 686 ) -> go.Figure: | |
| 687 """ | |
| 688 Binned Actual vs Predicted means (Plotly). | |
| 689 """ | |
| 690 y_true = np.asarray(y_true) | |
| 691 y_pred = np.asarray(y_pred) | |
| 692 | |
| 693 order = np.argsort(y_pred) | |
| 694 y_true_sorted = y_true[order] | |
| 695 y_pred_sorted = y_pred[order] | |
| 696 | |
| 697 bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins) | |
| 698 bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)] | |
| 699 bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)] | |
| 700 | |
| 701 vmin = float(min(np.min(y_pred), np.min(y_true))) | |
| 702 vmax = float(max(np.max(y_pred), np.max(y_true))) | |
| 703 | |
| 704 fig = go.Figure() | |
| 705 fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers", | |
| 706 name="Binned Actual vs Predicted")) | |
| 707 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), | |
| 708 name="Ideal")) | |
| 709 fig.update_layout( | |
| 710 title=None, | |
| 711 xaxis_title="Mean Predicted per bin", | |
| 712 yaxis_title="Mean Actual per bin", | |
| 713 template="plotly_white", | |
| 714 ) | |
| 715 _save_plotly(fig, path) | |
| 716 return fig | |
| 717 | |
| 718 # ========================= | |
| 719 # Confidence / Diagnostics | |
| 720 # ========================= | |
| 721 | |
| 722 | |
| 723 def plot_error_vs_confidence( | |
| 724 y_true: Union[Sequence[int], np.ndarray], | |
| 725 y_proba: Union[Sequence[float], np.ndarray], | |
| 726 n_bins: int = 10, | |
| 727 title: str = "Error vs Confidence", | |
| 728 path: Optional[str] = None, | |
| 729 ) -> go.Figure: | |
| 730 """ | |
| 731 Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly. | |
| 732 """ | |
| 733 y_true = np.asarray(y_true) | |
| 734 y_proba = np.asarray(y_proba).reshape(-1) | |
| 735 y_pred = (y_proba >= 0.5).astype(int) | |
| 736 confidence = np.maximum(y_proba, 1 - y_proba) | |
| 737 error = (y_pred != y_true).astype(int) | |
| 738 | |
| 739 bins = np.linspace(0.0, 1.0, n_bins + 1) | |
| 740 idx = np.digitize(confidence, bins, right=True) | |
| 741 | |
| 742 centers, err_rates = [], [] | |
| 743 for i in range(1, len(bins)): | |
| 744 mask = (idx == i) | |
| 745 if mask.any(): | |
| 746 centers.append(float(confidence[mask].mean())) | |
| 747 err_rates.append(float(error[mask].mean())) | |
| 748 | |
| 749 fig = go.Figure() | |
| 750 fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate")) | |
| 751 fig.update_layout( | |
| 752 title=None, | |
| 753 xaxis_title="Confidence (max predicted probability)", | |
| 754 yaxis_title="Error Rate", | |
| 755 yaxis=dict(range=[0, 1]), | |
| 756 template="plotly_white", | |
| 757 ) | |
| 758 _save_plotly(fig, path) | |
| 759 return fig | |
| 760 | |
| 761 | |
| 762 def plot_confidence_histogram( | |
| 763 y_proba: np.ndarray, | |
| 764 bins: int = 20, | |
| 765 title: str = "Confidence Histogram", | |
| 766 path: Optional[str] = None, | |
| 767 ) -> go.Figure: | |
| 768 """ | |
| 769 Histogram of max predicted probabilities (Plotly). | |
| 770 Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C). | |
| 771 """ | |
| 772 y_proba = np.asarray(y_proba) | |
| 773 if y_proba.ndim == 1: | |
| 774 confidences = np.maximum(y_proba, 1 - y_proba) | |
| 775 else: | |
| 776 confidences = np.max(y_proba, axis=1) | |
| 777 | |
| 778 fig = px.histogram( | |
| 779 x=confidences, | |
| 780 nbins=bins, | |
| 781 range_x=(0, 1), | |
| 782 histnorm="percent", | |
| 783 labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"}, | |
| 784 title=None, | |
| 785 ) | |
| 786 if fig.data: | |
| 787 fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>") | |
| 788 fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white") | |
| 789 _save_plotly(fig, path) | |
| 790 return fig | |
| 791 | |
| 792 # ========================= | |
| 793 # Learning Curve | |
| 794 # ========================= | |
| 795 | |
| 796 | |
| 797 def generate_learning_curve_from_predictions( | |
| 798 y_true, | |
| 799 y_pred=None, | |
| 800 y_proba=None, | |
| 801 classes=None, | |
| 802 metric: str = "accuracy", | |
| 803 train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10), | |
| 804 n_repeats: int = 5, | |
| 805 seed: int = 42, | |
| 806 title: str = "Learning Curve", | |
| 807 path: str | None = None, | |
| 808 return_stats: bool = False, | |
| 809 ) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]: | |
| 810 rng = np.random.default_rng(seed) | |
| 811 y_true = np.asarray(y_true) | |
| 812 N = len(y_true) | |
| 813 | |
| 814 if metric == "accuracy" and y_pred is None: | |
| 815 raise ValueError("accuracy curve requires y_pred") | |
| 816 if metric == "log_loss" and y_proba is None: | |
| 817 raise ValueError("log_loss curve requires y_proba") | |
| 818 | |
| 819 if y_proba is not None: | |
| 820 y_proba = np.asarray(y_proba) | |
| 821 if y_pred is not None: | |
| 822 y_pred = np.asarray(y_pred) | |
| 823 | |
| 824 sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist() | |
| 825 means, stds = [], [] | |
| 826 for n in sizes: | |
| 827 vals = [] | |
| 828 for _ in range(n_repeats): | |
| 829 idx = rng.choice(N, size=n, replace=False) | |
| 830 if metric == "accuracy": | |
| 831 vals.append(float((y_true[idx] == y_pred[idx]).mean())) | |
| 832 else: | |
| 833 if y_proba.ndim == 1: | |
| 834 p = y_proba[idx] | |
| 835 pp = np.column_stack([1 - p, p]) | |
| 836 else: | |
| 837 pp = y_proba[idx] | |
| 838 vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes))) | |
| 839 means.append(np.mean(vals)) | |
| 840 stds.append(np.std(vals)) | |
| 841 | |
| 842 if return_stats: | |
| 843 return sizes, means, stds | |
| 844 | |
| 845 fig = go.Figure() | |
| 846 fig.add_trace(go.Scatter( | |
| 847 x=sizes, y=means, mode="lines+markers", name="Train", | |
| 848 line=dict(width=3, shape="spline"), marker=dict(size=7), | |
| 849 error_y=dict(type="data", array=stds, visible=True) | |
| 850 )) | |
| 851 fig.update_layout( | |
| 852 title=None, | |
| 853 template="plotly_white", | |
| 854 xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"), | |
| 855 yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"), | |
| 856 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), | |
| 857 margin=dict(l=50, r=20, t=60, b=50), | |
| 858 ) | |
| 859 if path: | |
| 860 _save_plotly(fig, path) | |
| 861 return fig | |
| 862 | |
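| # Usage sketch: with return_stats=True the helper returns (sizes, means, stds) instead of a figure, | |
| # which is how build_train_html_and_plots below overlays Train and Validation traces. | |
| # sizes, means, stds = generate_learning_curve_from_predictions( | |
| # y_true=y, y_pred=(scores >= 0.5).astype(int), metric="accuracy", return_stats=True) | |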
| 863 | |
| 864 def build_train_html_and_plots( | |
| 865 predictor, | |
| 866 problem_type: str, | |
| 867 df_train: pd.DataFrame, | |
| 868 label_column: str, | |
| 869 tmpdir: str, | |
| 870 df_val: Optional[pd.DataFrame] = None, | |
| 871 seed: int = 42, | |
| 872 perf_table_html: str | None = None, | |
| 873 threshold: Optional[float] = None, | |
| 874 section_title: str = "Training Diagnostics", | |
| 875 ) -> str: | |
| 876 y_true = df_train[label_column].values | |
| 877 y_true_val = df_val[label_column].values if df_val is not None else None | |
| 878 # predictions on TRAIN | |
| 879 pred_labels, pred_proba = None, None | |
| 880 try: | |
| 881 pred_labels = predictor.predict(df_train) | |
| 882 except Exception: | |
| 883 pass | |
| 884 try: | |
| 885 proba_raw = predictor.predict_proba(df_train) | |
| 886 pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw) | |
| 887 except Exception: | |
| 888 pred_proba = None | |
| 889 | |
| 890 # predictions on VAL (if provided) | |
| 891 pred_labels_val, pred_proba_val = None, None | |
| 892 if df_val is not None: | |
| 893 try: | |
| 894 pred_labels_val = predictor.predict(df_val) | |
| 895 except Exception: | |
| 896 pred_labels_val = None | |
| 897 try: | |
| 898 proba_raw_val = predictor.predict_proba(df_val) | |
| 899 pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val) | |
| 900 except Exception: | |
| 901 pred_proba_val = None | |
| 902 | |
| 903 pos_scores_train: Optional[np.ndarray] = None | |
| 904 pos_scores_val: Optional[np.ndarray] = None | |
| 905 if problem_type == "binary": | |
| 906 if pred_proba is not None: | |
| 907 pos_scores_train = ( | |
| 908 pred_proba.reshape(-1) | |
| 909 if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) | |
| 910 else pred_proba[:, -1] | |
| 911 ) | |
| 912 if pred_proba_val is not None: | |
| 913 pos_scores_val = ( | |
| 914 pred_proba_val.reshape(-1) | |
| 915 if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) | |
| 916 else pred_proba_val[:, -1] | |
| 917 ) | |
| 918 | |
| 919 # Collect plots then append in desired order | |
| 920 perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None | |
| 921 acc_plot = loss_plot = None | |
| 922 cm_train = pc_train = cm_val = pc_val = None | |
| 923 threshold_val_plot = None | |
| 924 roc_combined = pr_combined = cal_combined = None | |
| 925 mc_roc_val = None | |
| 926 conf_train = conf_val = None | |
| 927 bar_train = bar_val = None | |
| 928 | |
| 929 # 1) Learning Curve — Accuracy | |
| 930 if problem_type in ("binary", "multiclass"): | |
| 931 acc_fig = go.Figure() | |
| 932 added_acc = False | |
| 933 if pred_labels is not None: | |
| 934 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( | |
| 935 y_true=y_true, | |
| 936 y_pred=np.asarray(pred_labels), | |
| 937 metric="accuracy", | |
| 938 title="Learning Curves — Label Accuracy", | |
| 939 seed=seed, | |
| 940 return_stats=True, | |
| 941 ) | |
| 942 acc_fig.add_trace(go.Scatter( | |
| 943 x=train_sizes, y=train_means, mode="lines+markers", name="Train", | |
| 944 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), | |
| 945 error_y=dict(type="data", array=train_stds, visible=True), | |
| 946 )) | |
| 947 added_acc = True | |
| 948 if pred_labels_val is not None and y_true_val is not None: | |
| 949 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( | |
| 950 y_true=y_true_val, | |
| 951 y_pred=np.asarray(pred_labels_val), | |
| 952 metric="accuracy", | |
| 953 title="Learning Curves — Label Accuracy", | |
| 954 seed=seed, | |
| 955 return_stats=True, | |
| 956 ) | |
| 957 acc_fig.add_trace(go.Scatter( | |
| 958 x=val_sizes, y=val_means, mode="lines+markers", name="Validation", | |
| 959 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), | |
| 960 error_y=dict(type="data", array=val_stds, visible=True), | |
| 961 )) | |
| 962 added_acc = True | |
| 963 if added_acc: | |
| 964 acc_fig.update_layout( | |
| 965 title=None, | |
| 966 template="plotly_white", | |
| 967 xaxis=dict(title="samples", gridcolor="#eee"), | |
| 968 yaxis=dict(title="accuracy", gridcolor="#eee"), | |
| 969 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), | |
| 970 margin=dict(l=50, r=20, t=60, b=50), | |
| 971 ) | |
| 972 acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy") | |
| 973 | |
| 974 # 2) Learning Curve — Loss | |
| 975 if problem_type in ("binary", "multiclass"): | |
| 976 classes = np.unique(y_true) | |
| 977 loss_fig = go.Figure() | |
| 978 added_loss = False | |
| 979 if pred_proba is not None: | |
| 980 pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba | |
| 981 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( | |
| 982 y_true=y_true, | |
| 983 y_proba=pp, | |
| 984 classes=classes, | |
| 985 metric="log_loss", | |
| 986 title="Learning Curves — Label Loss", | |
| 987 seed=seed, | |
| 988 return_stats=True, | |
| 989 ) | |
| 990 loss_fig.add_trace(go.Scatter( | |
| 991 x=train_sizes, y=train_means, mode="lines+markers", name="Train", | |
| 992 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), | |
| 993 error_y=dict(type="data", array=train_stds, visible=True), | |
| 994 )) | |
| 995 added_loss = True | |
| 996 if pred_proba_val is not None and y_true_val is not None: | |
| 997 pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val | |
| 998 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( | |
| 999 y_true=y_true_val, | |
| 1000 y_proba=pp_val, | |
| 1001 classes=classes, | |
| 1002 metric="log_loss", | |
| 1003 title="Learning Curves — Label Loss", | |
| 1004 seed=seed, | |
| 1005 return_stats=True, | |
| 1006 ) | |
| 1007 loss_fig.add_trace(go.Scatter( | |
| 1008 x=val_sizes, y=val_means, mode="lines+markers", name="Validation", | |
| 1009 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), | |
| 1010 error_y=dict(type="data", array=val_stds, visible=True), | |
| 1011 )) | |
| 1012 added_loss = True | |
| 1013 if added_loss: | |
| 1014 loss_fig.update_layout( | |
| 1015 title=None, | |
| 1016 template="plotly_white", | |
| 1017 xaxis=dict(title="epoch", gridcolor="#eee"), | |
| 1018 yaxis=dict(title="loss", gridcolor="#eee"), | |
| 1019 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), | |
| 1020 margin=dict(l=50, r=20, t=60, b=50), | |
| 1021 ) | |
| 1022 loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss") | |
| 1023 | |
| 1024 # Confusion matrices & per-class metrics | |
| 1025 cm_train = pc_train = cm_val = pc_val = None | |
| 1026 | |
| 1027 # Probability diagnostics (binary) | |
| 1028 if problem_type == "binary": | |
| 1029 # Combined Calibration (Train/Val) | |
| 1030 cal_fig = go.Figure() | |
| 1031 added_cal = False | |
| 1032 if pos_scores_train is not None: | |
| 1033 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) | |
| 1034 prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform") | |
| 1035 cal_fig.add_trace(go.Scatter( | |
| 1036 x=prob_pred, y=prob_true, mode="lines+markers", | |
| 1037 name="Train", | |
| 1038 line=dict(color="#1f77b4", width=3), | |
| 1039 marker=dict(size=7, color="#1f77b4"), | |
| 1040 )) | |
| 1041 added_cal = True | |
| 1042 if pos_scores_val is not None and y_true_val is not None: | |
| 1043 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) | |
| 1044 prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform") | |
| 1045 cal_fig.add_trace(go.Scatter( | |
| 1046 x=prob_pred_v, y=prob_true_v, mode="lines+markers", | |
| 1047 name="Validation", | |
| 1048 line=dict(color="#ff7f0e", width=3), | |
| 1049 marker=dict(size=7, color="#ff7f0e"), | |
| 1050 )) | |
| 1051 added_cal = True | |
| 1052 if added_cal: | |
| 1053 cal_fig.add_trace(go.Scatter( | |
| 1054 x=[0, 1], y=[0, 1], | |
| 1055 mode="lines", | |
| 1056 line=dict(dash="dash", color="#808080", width=2), | |
| 1057 name="Perfect", | |
| 1058 showlegend=True, | |
| 1059 )) | |
| 1060 cal_fig.update_layout( | |
| 1061 title=None, | |
| 1062 xaxis_title="Predicted Probability", | |
| 1063 yaxis_title="Observed Probability", | |
| 1064 xaxis=dict(range=[0, 1]), | |
| 1065 yaxis=dict(range=[0, 1]), | |
| 1066 template="plotly_white", | |
| 1067 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), | |
| 1068 margin=dict(l=60, r=40, t=50, b=50), | |
| 1069 ) | |
| 1070 cal_combined = plot_with_table_style_title(cal_fig, "Calibration Curve (Train vs Validation)") | |
| 1071 | |
| 1072 # Combined ROC (Train/Val) | |
| 1073 roc_fig = go.Figure() | |
| 1074 added_roc = False | |
| 1075 if pos_scores_train is not None: | |
| 1076 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) | |
| 1077 fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train) | |
| 1078 roc_fig.add_trace(go.Scatter( | |
| 1079 x=fpr_tr, y=tpr_tr, mode="lines", | |
| 1080 name="Train", | |
| 1081 line=dict(color="#1f77b4", width=3), | |
| 1082 )) | |
| 1083 if threshold is not None and np.isfinite(thr_tr).any(): | |
| 1084 finite = np.isfinite(thr_tr) | |
| 1085 idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold)))) | |
| 1086 idx = int(np.nonzero(finite)[0][idx_local]) | |
| 1087 roc_fig.add_trace(go.Scatter( | |
| 1088 x=[fpr_tr[idx]], y=[tpr_tr[idx]], | |
| 1089 mode="markers", | |
| 1090 name="Train @ threshold", | |
| 1091 marker=dict(size=12, color="#1f77b4", symbol="x") | |
| 1092 )) | |
| 1093 added_roc = True | |
| 1094 if pos_scores_val is not None and y_true_val is not None: | |
| 1095 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) | |
| 1096 fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val) | |
| 1097 roc_fig.add_trace(go.Scatter( | |
| 1098 x=fpr_v, y=tpr_v, mode="lines", | |
| 1099 name="Validation", | |
| 1100 line=dict(color="#ff7f0e", width=3), | |
| 1101 )) | |
| 1102 if threshold is not None and np.isfinite(thr_v).any(): | |
| 1103 finite = np.isfinite(thr_v) | |
| 1104 idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold)))) | |
| 1105 idx = int(np.nonzero(finite)[0][idx_local]) | |
| 1106 roc_fig.add_trace(go.Scatter( | |
| 1107 x=[fpr_v[idx]], y=[tpr_v[idx]], | |
| 1108 mode="markers", | |
| 1109 name="Val @ threshold", | |
| 1110 marker=dict(size=12, color="#ff7f0e", symbol="x") | |
| 1111 )) | |
| 1112 added_roc = True | |
| 1113 if added_roc: | |
| 1114 roc_fig.add_trace(go.Scatter( | |
| 1115 x=[0, 1], y=[0, 1], mode="lines", | |
| 1116 line=dict(dash="dash", width=2, color="#808080"), | |
| 1117 showlegend=False | |
| 1118 )) | |
| 1119 roc_fig.update_layout( | |
| 1120 title=None, | |
| 1121 xaxis_title="False Positive Rate", | |
| 1122 yaxis_title="True Positive Rate", | |
| 1123 template="plotly_white", | |
| 1124 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), | |
| 1125 margin=dict(l=60, r=20, t=60, b=60), | |
| 1126 ) | |
| 1127 roc_combined = plot_with_table_style_title(roc_fig, "ROC Curve (Train vs Validation)") | |
| 1128 | |
| 1129 # Combined PR (Train/Val) | |
| 1130 pr_fig = go.Figure() | |
| 1131 added_pr = False | |
| 1132 if pos_scores_train is not None: | |
| 1133 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) | |
| 1134 prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train) | |
| 1135 pr_auc_tr = auc(rec_tr, prec_tr) | |
| 1136 pr_fig.add_trace(go.Scatter( | |
| 1137 x=rec_tr, y=prec_tr, mode="lines", | |
| 1138 name=f"Train (AUC={pr_auc_tr:.3f})", | |
| 1139 line=dict(color="#1f77b4", width=3), | |
| 1140 )) | |
| 1141 if threshold is not None and len(thr_tr): | |
| 1142 j = int(np.argmin(np.abs(thr_tr - float(threshold)))) | |
| 1143 j = int(np.clip(j, 0, len(thr_tr) - 1)) | |
| 1144 pr_fig.add_trace(go.Scatter( | |
| 1145 x=[rec_tr[j]], y=[prec_tr[j]], | |
| 1146 mode="markers", | |
| 1147 name="Train @ threshold", | |
| 1148 marker=dict(size=12, color="#1f77b4", symbol="x") | |
| 1149 )) | |
| 1150 added_pr = True | |
| 1151 if pos_scores_val is not None and y_true_val is not None: | |
| 1152 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) | |
| 1153 prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val) | |
| 1154 pr_auc_v = auc(rec_v, prec_v) | |
| 1155 pr_fig.add_trace(go.Scatter( | |
| 1156 x=rec_v, y=prec_v, mode="lines", | |
| 1157 name=f"Validation (AUC={pr_auc_v:.3f})", | |
| 1158 line=dict(color="#ff7f0e", width=3), | |
| 1159 )) | |
| 1160 if threshold is not None and len(thr_v): | |
| 1161 j = int(np.argmin(np.abs(thr_v - float(threshold)))) | |
| 1162 j = int(np.clip(j, 0, len(thr_v) - 1)) | |
| 1163 pr_fig.add_trace(go.Scatter( | |
| 1164 x=[rec_v[j]], y=[prec_v[j]], | |
| 1165 mode="markers", | |
| 1166 name="Val @ threshold", | |
| 1167 marker=dict(size=12, color="#ff7f0e", symbol="x") | |
| 1168 )) | |
| 1169 added_pr = True | |
| 1170 if added_pr: | |
| 1171 pr_fig.update_layout( | |
| 1172 title=None, | |
| 1173 xaxis_title="Recall", | |
| 1174 yaxis_title="Precision", | |
| 1175 template="plotly_white", | |
| 1176 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), | |
| 1177 margin=dict(l=60, r=20, t=60, b=60), | |
| 1178 ) | |
| 1179 pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)") | |
| 1180 | |
| 1181 if pos_scores_val is not None and y_true_val is not None: | |
| 1182 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) | |
| 1183 fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val, title="Threshold Plot (Validation)", | |
| 1184 user_threshold=threshold) | |
| 1185 threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)") | |
| 1186 | |
| 1187 # Multiclass OVR ROC (validation) | |
| 1188 if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None: | |
| 1189 classes_val = np.unique(y_true_val) | |
| 1190 fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val, title="One-vs-Rest ROC (Validation)") | |
| 1191 mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)") | |
| 1192 | |
| 1193 # Prediction Confidence Histogram (train/val) | |
| 1194 conf_train = conf_val = None | |
| 1195 | |
| 1196 # Per-class accuracy bars | |
| 1197 if problem_type in ("binary", "multiclass") and pred_labels is not None: | |
| 1198 classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist() | |
| 1199 acc_vals = [] | |
| 1200 for c in classes_for_bar: | |
| 1201 mask = y_true == c | |
| 1202 acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if mask.any() else 0.0) | |
| 1203 bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4")) | |
| 1204 bar_fig.update_layout( | |
| 1205 title=None, | |
| 1206 template="plotly_white", | |
| 1207 xaxis=dict(title="Label", gridcolor="#eee"), | |
| 1208 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), | |
| 1209 margin=dict(l=50, r=20, t=60, b=50), | |
| 1210 ) | |
| 1211 bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy") | |
| 1212 if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None: | |
| 1213 classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist() | |
| 1214 acc_vals_val = [] | |
| 1215 for c in classes_for_bar_val: | |
| 1216 mask = y_true_val == c | |
| 1217 acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0) | |
| 1218 bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, marker_color="#ff7f0e")) | |
| 1219 bar_fig_val.update_layout( | |
| 1220 title=None, | |
| 1221 template="plotly_white", | |
| 1222 xaxis=dict(title="Label", gridcolor="#eee"), | |
| 1223 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), | |
| 1224 margin=dict(l=50, r=20, t=60, b=50), | |
| 1225 ) | |
| 1226 bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy") | |
| 1227 | |
| 1228 # Assemble in requested order | |
| 1229 pieces: list[str] = [] | |
| 1230 if perf_card: | |
| 1231 pieces.append(perf_card) | |
| 1232 for block in (threshold_val_plot, roc_combined, pr_combined): | |
| 1233 if block: | |
| 1234 pieces.append(block) | |
| 1235 # Remaining plots (keep existing order) | |
| 1236 for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, bar_train, bar_val): | |
| 1237 if block: | |
| 1238 pieces.append(block) | |
| 1239 # Learning curves should appear last in the tab | |
| 1240 for block in (acc_plot, loss_plot): | |
| 1241 if block: | |
| 1242 pieces.append(block) | |
| 1243 | |
| 1244 if not pieces: | |
| 1245 return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>" | |
| 1246 | |
| 1247 return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces) | |
| 1248 | |
| 1249 | |
| 1250 def generate_learning_curve( | |
| 1251 estimator, | |
| 1252 X, | |
| 1253 y, | |
| 1254 scoring: str = "r2", | |
| 1255 cv_folds: int = 5, | |
| 1256 n_jobs: int = -1, | |
| 1257 train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10), | |
| 1258 title: str = "Learning Curve", | |
| 1259 path: Optional[str] = None, | |
| 1260 ) -> go.Figure: | |
| 1261 """ | |
| 1262 Learning curve using sklearn.learning_curve, visualized with Plotly. | |
| 1263 """ | |
| 1264 sizes, train_scores, test_scores = skl_learning_curve( | |
| 1265 estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes | |
| 1266 ) | |
| 1267 train_mean = train_scores.mean(axis=1) | |
| 1268 train_std = train_scores.std(axis=1) | |
| 1269 test_mean = test_scores.mean(axis=1) | |
| 1270 test_std = test_scores.std(axis=1) | |
| 1271 | |
| 1272 fig = go.Figure() | |
| 1273 fig.add_trace(go.Scatter( | |
| 1274 x=sizes, y=train_mean, mode="lines+markers", name="Training score", | |
| 1275 error_y=dict(type="data", array=train_std, visible=True) | |
| 1276 )) | |
| 1277 fig.add_trace(go.Scatter( | |
| 1278 x=sizes, y=test_mean, mode="lines+markers", name="CV score", | |
| 1279 error_y=dict(type="data", array=test_std, visible=True) | |
| 1280 )) | |
| 1281 fig.update_layout( | |
| 1282 title=None, | |
| 1283 xaxis_title="Training examples", | |
| 1284 yaxis_title=scoring, | |
| 1285 template="plotly_white", | |
| 1286 ) | |
| 1287 _save_plotly(fig, path) | |
| 1288 return fig | |
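| # Usage sketch (any scikit-learn compatible estimator; Ridge and the X/y_reg arrays are hypothetical): | |
| # from sklearn.linear_model import Ridge | |
| # lc_fig = generate_learning_curve(Ridge(), X, y_reg, scoring="r2", cv_folds=5) | |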
| 1289 | |
| 1290 # ========================= | |
| 1291 # SHAP (Matplotlib-based) | |
| 1292 # ========================= | |
| 1293 | |
| 1294 | |
| 1295 def generate_shap_summary_plot( | |
| 1296 shap_values, features: pd.DataFrame, title: str = "SHAP Summary Plot", path: Optional[str] = None | |
| 1297 ) -> None: | |
| 1298 """ | |
| 1299 SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited; | |
| 1300 keep matplotlib for clarity and stability. | |
| 1301 """ | |
| 1302 plt.figure(figsize=(10, 8)) | |
| 1303 shap.summary_plot(shap_values, features, show=False) | |
| 1304 plt.title(title) | |
| 1305 _save_matplotlib(path) | |
| 1306 | |
| 1307 | |
| 1308 def generate_shap_force_plot( | |
| 1309 explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None | |
| 1310 ) -> None: | |
| 1311 """ | |
| 1312 SHAP force plot (Matplotlib). | |
| 1313 """ | |
| 1314 shap_values = explainer(instance) | |
| 1315 plt.figure(figsize=(10, 4)) | |
| 1316 shap.plots.force(shap_values[0], matplotlib=True, show=False) | |
| 1317 plt.title(title) | |
| 1318 _save_matplotlib(path) | |
| 1319 | |
| 1320 | |
| 1321 def generate_shap_waterfall_plot( | |
| 1322 explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None | |
| 1323 ) -> None: | |
| 1324 """ | |
| 1325 SHAP waterfall plot (Matplotlib). | |
| 1326 """ | |
| 1327 shap_values = explainer(instance) | |
| 1328 plt.figure(figsize=(10, 6)) | |
| 1329 shap.plots.waterfall(shap_values[0], show=False) | |
| 1330 plt.title(title) | |
| 1331 _save_matplotlib(path) | |
| 1332 | |
| 1333 | |
| 1334 def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str: | |
| 1335 """ | |
| 1336 Return 'binary', 'multiclass', or 'regression'. | |
| 1337 Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques. | |
| 1338 """ | |
| 1339 # AutoGluon predictors usually expose .problem_type; be defensive. | |
| 1340 pt = getattr(predictor, "problem_type", None) | |
| 1341 if isinstance(pt, str): | |
| 1342 pt_l = pt.lower() | |
| 1343 if "regression" in pt_l: | |
| 1344 return "regression" | |
| 1345 if "binary" in pt_l: | |
| 1346 return "binary" | |
| 1347 if "multiclass" in pt_l or "multiclass" in pt_l: | |
| 1348 return "multiclass" | |
| 1349 | |
| 1350 y = df_train_full[label_column] | |
| 1351 if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10: | |
| 1352 return "regression" | |
| 1353 return "binary" if y.nunique() == 2 else "multiclass" | |
| 1354 | |
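| # Usage sketch: a numeric label with more than 10 distinct values is treated as regression. | |
| # df_demo = pd.DataFrame({"target": np.arange(25, dtype=float)}) | |
| # infer_problem_type(None, df_demo, "target")  # -> "regression" | |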
| 1355 | |
| 1356 def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]: | |
| 1357 """Make evaluate() outputs JSON/csv friendly floats.""" | |
| 1358 out = {} | |
| 1359 for k, v in d.items(): | |
| 1360 try: | |
| 1361 out[k] = float(v) | |
| 1362 except Exception: | |
| 1363 # keep only real-valued scalars | |
| 1364 pass | |
| 1365 return out | |
| 1366 | |
| 1367 | |
| 1368 def evaluate_all( | |
| 1369 predictor, | |
| 1370 df_train: pd.DataFrame, | |
| 1371 df_val: pd.DataFrame, | |
| 1372 df_test: pd.DataFrame, | |
| 1373 label_column: str, | |
| 1374 problem_type: str, | |
| 1375 ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]: | |
| 1376 """ | |
| 1377 Run predictor.evaluate on train/val/test and normalize the result dicts to floats. | |
| 1378 MultiModalPredictor does not accept the `silent` kwarg, so call defensively. | |
| 1379 """ | |
| 1380 def _evaluate(df): | |
| 1381 try: | |
| 1382 return predictor.evaluate(df, silent=True) | |
| 1383 except TypeError: | |
| 1384 return predictor.evaluate(df) | |
| 1385 | |
| 1386 train_scores = _safe_floatify(_evaluate(df_train)) | |
| 1387 val_scores = _safe_floatify(_evaluate(df_val)) | |
| 1388 test_scores = _safe_floatify(_evaluate(df_test)) | |
| 1389 return train_scores, val_scores, test_scores | |
| 1390 | |
| 1391 | |
| 1392 def build_summary_html( | |
| 1393 predictor, | |
| 1394 df_train: pd.DataFrame, | |
| 1395 df_val: Optional[pd.DataFrame], | |
| 1396 df_test: Optional[pd.DataFrame], | |
| 1397 label_column: str, | |
| 1398 extra_run_rows: Optional[list[tuple[str, str]]] = None, | |
| 1399 class_balance_html: Optional[str] = None, | |
| 1400 perf_table_html: Optional[str] = None, | |
| 1401 ) -> str: | |
| 1402 sections = [] | |
| 1403 | |
| 1404 # Dataset Overview (first section in the tab) | |
| 1405 if class_balance_html: | |
| 1406 sections.append(f""" | |
| 1407 <section class="section"> | |
| 1408 <h2 class="section-title">Dataset Overview</h2> | |
| 1409 <div class="card"> | |
| 1410 {class_balance_html} | |
| 1411 </div> | |
| 1412 </section> | |
| 1413 """.strip()) | |
| 1414 | |
| 1415 # Performance Summary | |
| 1416 if perf_table_html: | |
| 1417 sections.append(f""" | |
| 1418 <section class="section"> | |
| 1419 <h2 class="section-title">Model Performance Summary</h2> | |
| 1420 <div class="card"> | |
| 1421 {perf_table_html} | |
| 1422 </div> | |
| 1423 </section> | |
| 1424 """.strip()) | |
| 1425 | |
| 1426 # Model Configuration | |
| 1427 | |
| 1428 # Drop 'Predictor type' and 'Framework' rows from any provided run metadata | |
| 1429 base_rows: list[tuple[str, str]] = [] | |
| 1430 if extra_run_rows: | |
| 1431 # Remove any rows with keys 'Predictor type' or 'Framework' | |
| 1432 base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")]) | |
| 1433 | |
| 1434 def _fmt(v): | |
| 1435 if v is None or v == "": | |
| 1436 return "—" | |
| 1437 return _escape(str(v)) | |
| 1438 | |
| 1439 rows_html = "\n".join( | |
| 1440 f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>" | |
| 1441 for k, v in base_rows | |
| 1442 ) | |
| 1443 | |
| 1444 sections.append(f""" | |
| 1445 <section class="section"> | |
| 1446 <h2 class="section-title">Model Configuration</h2> | |
| 1447 <div class="card"> | |
| 1448 <table class="kv-table"> | |
| 1449 <thead><tr><th>Key</th><th>Value</th></tr></thead> | |
| 1450 <tbody> | |
| 1451 {rows_html} | |
| 1452 </tbody> | |
| 1453 </table> | |
| 1454 </div> | |
| 1455 </section> | |
| 1456 """.strip()) | |
| 1457 | |
| 1458 return "\n".join(sections).strip() | |
| 1459 | |
| 1460 | |
| 1461 def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str: | |
| 1462 """Build a visualization of feature importance.""" | |
| 1463 try: | |
| 1464 # Try to get feature importance from predictor | |
| 1465 fi = None | |
| 1466 if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance): | |
| 1467 try: | |
| 1468 fi = predictor.feature_importance(df_train) | |
| 1469 except Exception as e: | |
| 1470 return f"<p>Could not compute feature importance: {e}</p>" | |
| 1471 | |
| 1472 if fi is None or (isinstance(fi, pd.DataFrame) and fi.empty): | |
| 1473 return "<p>Feature importance not available for this model.</p>" | |
| 1474 | |
| 1475 # Format as a sortable table | |
| 1476 rows = [] | |
| 1477 if isinstance(fi, pd.DataFrame): | |
| 1478 fi = fi.sort_values("importance", ascending=False) | |
| 1479 for feat, row in fi.iterrows(): | |
| 1480 feat = row["feature"] if "feature" in fi.columns else feat  # otherwise the index holds the feature name | |
| 1481 imp = float(row["importance"]) | |
| 1482 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>") | |
| 1483 else: | |
| 1484 # Handle other formats (dict, etc) | |
| 1485 for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True): | |
| 1486 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>") | |
| 1487 | |
| 1488 if not rows: | |
| 1489 return "<p>No feature importance values available.</p>" | |
| 1490 | |
| 1491 table_html = f""" | |
| 1492 <table class="performance-summary"> | |
| 1493 <thead> | |
| 1494 <tr> | |
| 1495 <th class="sortable">Feature</th> | |
| 1496 <th class="sortable">Importance</th> | |
| 1497 </tr> | |
| 1498 </thead> | |
| 1499 <tbody> | |
| 1500 {"".join(rows)} | |
| 1501 </tbody> | |
| 1502 </table> | |
| 1503 """ | |
| 1504 return table_html | |
| 1505 | |
| 1506 except Exception as e: | |
| 1507 return f"<p>Error building feature importance visualization: {e}</p>" | |
| 1508 | |
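| # Example (illustrative): with an importance frame indexed by feature name, | |
| #   fi = pd.DataFrame({"importance": [0.42, 0.10]}, index=["age", "sex"]) | |
| # the table lists "age" and "sex" with their scores; dict-like inputs such as | |
| #   {"age": 0.42, "sex": 0.10} | |
| # are rendered the same way. | |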
| 1509 | |
| 1510 def build_test_html_and_plots( | |
| 1511 predictor, | |
| 1512 problem_type: str, | |
| 1513 df_test: pd.DataFrame, | |
| 1514 label_column: str, | |
| 1515 tmpdir: str, | |
| 1516 threshold: Optional[float] = None, | |
| 1517 ) -> Tuple[str, List[str]]: | |
| 1518 """ | |
| 1519 Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs. | |
| 1520 Returns: (html_template_with_{}, list_of_plot_divs) | |
| 1521 """ | |
| 1522 plots: List[str] = [] | |
| 1523 | |
| 1524 y_true = df_test[label_column].values | |
| 1525 classes = np.unique(y_true) | |
| 1526 | |
| 1527 # Try proba/labels where meaningful | |
| 1528 pred_labels = None | |
| 1529 pred_proba = None | |
| 1530 try: | |
| 1531 pred_labels = predictor.predict(df_test) | |
| 1532 except Exception: | |
| 1533 pass | |
| 1534 try: | |
| 1535 # MultiModalPredictor exposes predict_proba for classification problems. | |
| 1536 pred_proba = predictor.predict_proba(df_test) | |
| 1537 except Exception: | |
| 1538 pred_proba = None | |
| 1539 | |
| 1540 proba_arr = None | |
| 1541 if pred_proba is not None: | |
| 1542 if isinstance(pred_proba, pd.Series): | |
| 1543 proba_arr = pred_proba.to_numpy().reshape(-1, 1) | |
| 1544 elif isinstance(pred_proba, pd.DataFrame): | |
| 1545 proba_arr = pred_proba.to_numpy() | |
| 1546 else: | |
| 1547 proba_arr = np.asarray(pred_proba) | |
| 1548 | |
| 1549 # Thresholded labels for binary | |
| 1550 if problem_type == "binary" and threshold is not None and proba_arr is not None: | |
| 1551 pos_label, neg_label = classes.max(), classes.min() | |
| 1552 pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1] | |
| 1553 pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label) | |
| 1554 | |
| 1555 # Confusion matrix / per-class now reflect thresholded labels | |
| 1556 if problem_type in ("binary", "multiclass") and pred_labels is not None: | |
| 1557 cm_title = "Confusion Matrix" | |
| 1558 if threshold is not None and problem_type == "binary": | |
| 1559 thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".") | |
| 1560 cm_title = f"Confusion Matrix (Threshold = {thr_str})" | |
| 1561 fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title) | |
| 1562 plots.append(plot_with_table_style_title(fig_cm, cm_title)) | |
| 1563 | |
| 1564 fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics") | |
| 1565 plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics")) | |
| 1566 | |
| 1567 # ROC/PR where possible — choose positive-class scores safely | |
| 1568 pos_label = classes.max() # or set explicitly, e.g., 1 or "yes" | |
| 1569 | |
| 1570 if isinstance(pred_proba, pd.DataFrame): | |
| 1571 proba_arr = pred_proba.to_numpy() | |
| 1572 if pos_label in pred_proba.columns: | |
| 1573 pos_idx = list(pred_proba.columns).index(pos_label) | |
| 1574 else: | |
| 1575 pos_idx = -1 # fallback to last column | |
| 1576 elif isinstance(pred_proba, pd.Series): | |
| 1577 proba_arr = pred_proba.to_numpy().reshape(-1, 1) | |
| 1578 pos_idx = 0 | |
| 1579 else: | |
| 1580 proba_arr = np.asarray(pred_proba) if pred_proba is not None else None | |
| 1581 pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0 | |
| 1582 | |
| 1583 if proba_arr is not None: | |
| 1584 y_bin = (y_true == pos_label).astype(int) | |
| 1585 pos_scores = ( | |
| 1586 proba_arr.reshape(-1) | |
| 1587 if proba_arr.ndim == 1 or proba_arr.shape[1] == 1 | |
| 1588 else proba_arr[:, pos_idx] | |
| 1589 ) | |
| 1590 | |
| 1591 fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold) | |
| 1592 plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}")) | |
| 1593 | |
| 1594 fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold) | |
| 1595 plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}")) | |
| 1596 | |
| 1597 # Additional diagnostics aligned with ImageLearner style | |
| 1598 if problem_type == "binary": | |
| 1599 conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)") | |
| 1600 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)")) | |
| 1601 else: | |
| 1602 conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)") | |
| 1603 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)")) | |
| 1604 | |
| 1605 if problem_type == "multiclass" and proba_arr is not None and proba_arr.ndim >= 2: | |
| 1606 fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)") | |
| 1607 plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)")) | |
| 1608 | |
| 1609 # Regression visuals | |
| 1610 if problem_type == "regression": | |
| 1611 if pred_labels is None: | |
| 1612 pred_labels = predictor.predict(df_test) | |
| 1613 fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual") | |
| 1614 plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual")) | |
| 1615 | |
| 1616 fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot") | |
| 1617 plots.append(plot_with_table_style_title(fig_res, "Residual Plot")) | |
| 1618 | |
| 1619 fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram") | |
| 1620 plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram")) | |
| 1621 | |
| 1622 fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration") | |
| 1623 plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration")) | |
| 1624 | |
| 1625 # Small HTML template with placeholder for metric rows the caller fills in | |
| 1626 test_html_template = """ | |
| 1627 <h2>Test Performance Summary</h2> | |
| 1628 <table class="performance-summary"> | |
| 1629 <thead><tr><th>Metric</th><th>Test</th></tr></thead> | |
| 1630 <tbody>{}</tbody> | |
| 1631 </table> | |
| 1632 """ | |
| 1633 return test_html_template, plots | |
| 1634 | |
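| # Example (illustrative): the caller fills the `{}` placeholder with metric rows | |
| # (e.g. built from the test scores returned by `evaluate_all`) and appends the | |
| # plot divs after the table: | |
| #   template, plot_divs = build_test_html_and_plots( | |
| #       predictor, "binary", df_test, "label", tmpdir, threshold=0.5 | |
| #   ) | |
| #   rows = "".join(f"<tr><td>{k}</td><td>{v:.4f}</td></tr>" for k, v in test_scores.items()) | |
| #   test_tab_html = template.format(rows) + "".join(plot_divs) | |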
| 1635 | |
| 1636 def build_feature_html( | |
| 1637 predictor, | |
| 1638 df_train: pd.DataFrame, | |
| 1639 label_column: str, | |
| 1640 include_modalities: bool = True,  # render the "Modalities & Inputs" section | |
| 1641 include_class_balance: bool = True,  # render the "Class Balance (Train Full)" section | |
| 1642 ) -> str: | |
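| """ | |
| Build the Feature tab: a Feature Importance card, plus optional | |
| "Modalities & Inputs" and "Class Balance (Train Full)" cards controlled | |
| by the `include_modalities` / `include_class_balance` flags. | |
| """ | |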
| 1643 sections = [] | |
| 1644 | |
| 1645 # (Typical feature importance content…) | |
| 1646 fi_html = build_feature_importance_html(predictor, df_train, label_column) | |
| 1647 sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>") | |
| 1648 | |
| 1649 # Optional sections, controlled by the flags above; the report_utils imports | |
| 1650 # are deferred so they are only pulled in when these sections are rendered. | |
| 1651 if include_modalities: | |
| 1652 from report_utils import build_modalities_html | |
| 1653 modalities_html = build_modalities_html(predictor, df_train, label_column) | |
| 1654 sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>") | |
| 1655 | |
| 1656 if include_class_balance: | |
| 1657 from report_utils import build_class_balance_html | |
| 1658 cb_html = build_class_balance_html(df_train, label_column) | |
| 1659 sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>") | |
| 1660 | |
| 1661 return "\n".join(sections) | |
| 1662 | |
| 1663 | |
| 1664 def assemble_full_html_report( | |
| 1665 summary_html: str, | |
| 1666 train_html: str, | |
| 1667 test_html: str, | |
| 1668 plots: List[str], | |
| 1669 feature_html: str, | |
| 1670 ) -> str: | |
| 1671 """ | |
| 1672 Wrap the four tabs using report_utils.build_tabbed_html and return the full HTML report. | |
| 1673 """ | |
| 1674 # Append plots under the Test tab (already wrapped with titles) | |
| 1675 test_full = test_html + "".join(plots) | |
| 1676 | |
| 1677 tabs = build_tabbed_html(summary_html, train_html, test_full, feature_html, explainer_html=None) | |
| 1678 | |
| 1679 html_out = get_html_template() | |
| 1680 | |
| 1681 # 🔧 Ensure Plotly JS is available (we render plots with include_plotlyjs=False) | |
| 1682 html_out += '\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n' | |
| 1683 | |
| 1684 # Optional: centering tweaks | |
| 1685 html_out += """ | |
| 1686 <style> | |
| 1687 .plotly-center { display: flex; justify-content: center; } | |
| 1688 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } | |
| 1689 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } | |
| 1690 </style> | |
| 1691 """ | |
| 1692 # Help modal HTML/JS | |
| 1693 html_out += get_metrics_help_modal() | |
| 1694 | |
| 1695 html_out += tabs | |
| 1696 html_out += get_html_closing() | |
| 1697 return html_out |
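| | |
| # End-to-end sketch (illustrative; all lowercase names below are caller-supplied | |
| # placeholders, e.g. run metadata rows and pre-rendered HTML fragments): | |
| #   train_s, val_s, test_s = evaluate_all(predictor, df_tr, df_va, df_te, "label", ptype) | |
| #   summary = build_summary_html(predictor, df_tr, df_va, df_te, "label", | |
| #                                extra_run_rows=run_rows, perf_table_html=perf_html) | |
| #   test_tpl, plot_divs = build_test_html_and_plots(predictor, ptype, df_te, "label", tmpdir) | |
| #   feature_tab = build_feature_html(predictor, df_tr, "label") | |
| #   report = assemble_full_html_report(summary, train_html, | |
| #                                      test_tpl.format(rows_html), plot_divs, feature_tab) | |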
