plot_logic.py @ 0:375c36923da1 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000
1 from __future__ import annotations
2
3 import html
4 import os
5 from html import escape as _escape
6 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
7
8 import matplotlib.pyplot as plt
9 import numpy as np
10 import pandas as pd
11 import plotly.express as px
12 import plotly.graph_objects as go
13 import shap
14 from feature_help_modal import get_metrics_help_modal
15 from report_utils import build_tabbed_html, get_html_closing, get_html_template
16 from sklearn.calibration import calibration_curve
17 from sklearn.metrics import (
18 auc,
19 average_precision_score,
20 classification_report,
21 confusion_matrix,
22 log_loss,
23 precision_recall_curve,
24 roc_auc_score,
25 roc_curve,
26 )
27 from sklearn.model_selection import learning_curve as skl_learning_curve
28 from sklearn.preprocessing import label_binarize
29
30 # =========================
31 # Utilities
32 # =========================
33
34
35 def plot_with_table_style_title(fig, title: str) -> str:
36 """
37 Render a Plotly figure with a report-style <h2> header so it matches the
38 green table section headers.
39 """
40 # kill Plotly’s built-in title
41 fig.update_layout(title=None)
42
43 # figure HTML without PlotlyJS (we load it once globally)
44 plot_html = fig.to_html(full_html=False, include_plotlyjs=False)
45
46 # use <h2> — your CSS already styles <h2> like the table headers
47 return f"""
48 <h2>{html.escape(title)}</h2>
49 <div class="plotly-center">{plot_html}</div>
50 """.strip()
51
52
53 def _save_plotly(fig: go.Figure, path: Optional[str]) -> None:
54 """
55 Save a Plotly figure. If `path` ends with `.html`, save interactive HTML.
56 If it ends with a raster extension (png/jpg/jpeg/webp), render it with Kaleido.
57 If `path` is None, do nothing (the caller may display the figure in a notebook).
58 """
59 if not path:
60 return
61 ext = os.path.splitext(path)[1].lower()
62 if ext == ".html":
63 fig.write_html(path, include_plotlyjs="cdn", full_html=True)
64 else:
65 # Requires kaleido: pip install -U kaleido
66 fig.write_image(path)
67
68
69 def _save_matplotlib(path: Optional[str]) -> None:
70 """Save current Matplotlib figure if `path` is provided, else show()."""
71 if path:
72 plt.savefig(path, bbox_inches="tight")
73 plt.close()
74 else:
75 plt.show()
76
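
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Shows how the two save helpers above are meant to be used together.
# The output file names below are hypothetical; raster export additionally
# requires the optional `kaleido` package.
def _example_save_helpers() -> None:
    """Hedged usage sketch for `_save_plotly` and `_save_matplotlib`."""
    fig = go.Figure(go.Scatter(x=[0, 1, 2], y=[0, 1, 4], mode="lines"))
    _save_plotly(fig, "example_plot.html")   # interactive HTML
    _save_plotly(fig, "example_plot.png")    # raster export; needs kaleido

    plt.figure()
    plt.plot([0, 1, 2], [0, 1, 4])
    _save_matplotlib("example_plot_mpl.png")  # saves and closes the current figure
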
77 # =========================
78 # Classification Plots
79 # =========================
80
81
82 def generate_confusion_matrix_plot(
83 y_true,
84 y_pred,
85 title: str = "Confusion Matrix",
86 ) -> go.Figure:
87 y_true = np.asarray(y_true)
88 y_pred = np.asarray(y_pred)
89
90 # Class order (works for strings or numbers)
91 labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist()
92 cm = confusion_matrix(y_true, y_pred, labels=labels)
93 max_val = cm.max() if cm.size else 0
94
95 # Use categorical axes by passing string labels for x/y
96 cats = [str(label) for label in labels]
97 total = int(cm.sum())
98
99 fig = go.Figure(
100 data=go.Heatmap(
101 z=cm,
102 x=cats, # categorical x
103 y=cats, # categorical y
104 colorscale="Blues",
105 showscale=True,
106 colorbar=dict(title="Count"),
107 xgap=2,
108 ygap=2,
109 hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>",
110 zmin=0
111 )
112 )
113
114 # Add count and percentage annotations; text switches to white on dark cells for readability
115 annotations = []
116 for i in range(cm.shape[0]):
117 for j in range(cm.shape[1]):
118 val = int(cm[i, j])
119 pct = (val / total * 100) if total > 0 else 0
120 text_color = "white" if max_val and val > (max_val / 2) else "black"
121 # Count annotation (bold, bottom)
122 annotations.append(
123 dict(
124 x=cats[j],
125 y=cats[i],
126 text=f"<b>{val}</b>",
127 showarrow=False,
128 font=dict(color=text_color, size=14),
129 xanchor="center",
130 yanchor="bottom",
131 yshift=2
132 )
133 )
134 # Percentage annotation (top)
135 annotations.append(
136 dict(
137 x=cats[j],
138 y=cats[i],
139 text=f"{pct:.1f}%",
140 showarrow=False,
141 font=dict(color=text_color, size=13),
142 xanchor="center",
143 yanchor="top",
144 yshift=-2
145 )
146 )
147
148 fig.update_layout(
149 title=None,
150 xaxis_title="Predicted label",
151 yaxis_title="True label",
152 xaxis=dict(type="category"),
153 yaxis=dict(type="category", autorange="reversed"), # typical CM orientation
154 margin=dict(l=80, r=20, t=40, b=80),
155 template="plotly_white",
156 plot_bgcolor="white",
157 paper_bgcolor="white",
158 annotations=annotations
159 )
160 return fig
161
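
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Minimal example of feeding string labels to the confusion-matrix helper
# above; the labels and the output path are made up for demonstration.
def _example_confusion_matrix() -> None:
    """Hedged usage sketch for `generate_confusion_matrix_plot`."""
    y_true = ["cat", "cat", "dog", "dog", "dog", "cat"]
    y_pred = ["cat", "dog", "dog", "dog", "cat", "cat"]
    fig = generate_confusion_matrix_plot(y_true, y_pred)
    _save_plotly(fig, "confusion_matrix_example.html")
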
162
163 def generate_roc_curve_plot(
164 y_true_bin: np.ndarray,
165 y_score: np.ndarray,
166 title: str = "ROC Curve",
167 marker_threshold: float | None = None,
168 ) -> go.Figure:
169 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1)
170 y_score = np.asarray(y_score).astype(float).reshape(-1)
171
172 fpr, tpr, thr = roc_curve(y_true_bin, y_score)
173 roc_auc = auc(fpr, tpr)
174
175 fig = go.Figure()
176 fig.add_trace(go.Scatter(
177 x=fpr, y=tpr, mode="lines",
178 name=f"ROC (AUC={roc_auc:.3f})",
179 line=dict(width=3)
180 ))
181
182 # 45° chance line (no legend to keep it clean)
183 fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines",
184 line=dict(dash="dash", width=2, color="#888"), showlegend=False))
185
186 # Optional marker at the user threshold
187 if marker_threshold is not None and len(thr):
188 # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0
189 finite = np.isfinite(thr)
190 if np.any(finite):
191 idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold))))
192 idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices
193 x_m, y_m = float(fpr[idx]), float(tpr[idx])
194
195 fig.add_trace(
196 go.Scatter(
197 x=[x_m], y=[y_m],
198 mode="markers",
199 name=f"@ {float(marker_threshold):.2f}",
200 marker=dict(size=12, color="red", symbol="x")
201 )
202 )
203 fig.add_annotation(
204 x=0.02, y=0.98, xref="paper", yref="paper",
205 text=f"threshold = {float(marker_threshold):.2f}",
206 showarrow=False,
207 font=dict(color="black", size=12),
208 align="left"
209 )
210
211 fig.update_layout(
212 title=None,
213 xaxis_title="False Positive Rate",
214 yaxis_title="True Positive Rate",
215 template="plotly_white",
216 legend=dict(x=1, y=0, xanchor="right"),
217 margin=dict(l=60, r=20, t=60, b=60),
218 )
219 return fig
220
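
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Tiny binary example for the ROC helper above; the scores are synthetic and
# the 0.5 marker threshold is arbitrary.
def _example_roc_curve() -> None:
    """Hedged usage sketch for `generate_roc_curve_plot`."""
    rng = np.random.default_rng(0)
    y = rng.integers(0, 2, size=200)
    # noisy scores that loosely track the labels
    scores = np.clip(y * 0.6 + rng.normal(0.2, 0.25, size=200), 0.0, 1.0)
    fig = generate_roc_curve_plot(y, scores, marker_threshold=0.5)
    _save_plotly(fig, "roc_example.html")
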
221
222 def generate_pr_curve_plot(
223 y_true_bin: np.ndarray,
224 y_score: np.ndarray,
225 title: str = "Precision–Recall Curve",
226 marker_threshold: float | None = None,
227 ) -> go.Figure:
228 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1)
229 y_score = np.asarray(y_score).astype(float).reshape(-1)
230
231 precision, recall, thr = precision_recall_curve(y_true_bin, y_score)
232 pr_auc = auc(recall, precision)
233
234 fig = go.Figure()
235 fig.add_trace(go.Scatter(
236 x=recall, y=precision, mode="lines",
237 name=f"PR (AUC={pr_auc:.3f})",
238 line=dict(width=3)
239 ))
240
241 # Optional marker at the user threshold
242 if marker_threshold is not None and len(thr):
243 # precision_recall_curve returns thresholds of length len(precision) - 1;
244 # element j of precision/recall corresponds to thr[j].
245 j = int(np.argmin(np.abs(thr - float(marker_threshold))))
246 x_m, y_m = float(recall[j]), float(precision[j])
247
248 fig.add_trace(
249 go.Scatter(
250 x=[x_m], y=[y_m],
251 mode="markers",
252 name=f"@ {float(marker_threshold):.2f}",
253 marker=dict(size=12, color="red", symbol="x")
254 )
255 )
256 fig.add_annotation(
257 x=0.02, y=0.98, xref="paper", yref="paper",
258 text=f"threshold = {float(marker_threshold):.2f}",
259 showarrow=False,
260 font=dict(color="black", size=12),
261 align="left"
262 )
263
264 fig.update_layout(
265 title=None,
266 xaxis_title="Recall",
267 yaxis_title="Precision",
268 template="plotly_white",
269 legend=dict(x=1, y=0, xanchor="right"),
270 margin=dict(l=60, r=20, t=60, b=60),
271 )
272 return fig
273
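
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Same synthetic setup as the ROC example, reused for the precision-recall
# helper above.
def _example_pr_curve() -> None:
    """Hedged usage sketch for `generate_pr_curve_plot`."""
    rng = np.random.default_rng(0)
    y = rng.integers(0, 2, size=200)
    scores = np.clip(y * 0.6 + rng.normal(0.2, 0.25, size=200), 0.0, 1.0)
    fig = generate_pr_curve_plot(y, scores, marker_threshold=0.5)
    _save_plotly(fig, "pr_example.html")
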
274
275 def generate_calibration_plot(
276 y_true_bin: np.ndarray,
277 y_prob: np.ndarray,
278 n_bins: int = 10,
279 title: str = "Calibration Plot",
280 path: Optional[str] = None,
281 ) -> go.Figure:
282 """
283 Binary calibration curve (Plotly).
284 """
285 prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform")
286 fig = go.Figure()
287 fig.add_trace(go.Scatter(
288 x=prob_pred, y=prob_true, mode="lines+markers", name="Model",
289 line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4")
290 ))
291 fig.add_trace(
292 go.Scatter(
293 x=[0, 1], y=[0, 1],
294 mode="lines",
295 line=dict(dash="dash", color="#808080", width=2),
296 name="Perfect"
297 )
298 )
299 fig.update_layout(
300 title=None,
301 xaxis_title="Predicted Probability",
302 yaxis_title="Observed Probability",
303 yaxis=dict(range=[0, 1]),
304 xaxis=dict(range=[0, 1]),
305 template="plotly_white",
306 margin=dict(l=60, r=40, t=50, b=50),
307 )
308 _save_plotly(fig, path)
309 return fig
310
311
312 def generate_threshold_plot(
313 y_true_bin: np.ndarray,
314 y_prob: np.ndarray,
315 title: str = "Threshold Plot",
316 user_threshold: float | None = None,
317 ) -> go.Figure:
318 y_true = np.asarray(y_true_bin, dtype=int).ravel()
319 p = np.asarray(y_prob, dtype=float).ravel()
320 p = np.nan_to_num(p, nan=0.0)
321 p = np.clip(p, 0.0, 1.0)
322
323 def _compute_metrics(thresholds: np.ndarray):
324 """Compute precision, recall, F1, and queue-rate arrays over a grid of thresholds."""
325 prec, rec, f1, qrate = [], [], [], []
326 for t in thresholds:
327 yhat = (p >= t).astype(int)
328 tp = int(((yhat == 1) & (y_true == 1)).sum())
329 fp = int(((yhat == 1) & (y_true == 0)).sum())
330 fn = int(((yhat == 0) & (y_true == 1)).sum())
331
332 pr = tp / (tp + fp) if (tp + fp) else np.nan # undefined when no predicted positives
333 rc = tp / (tp + fn) if (tp + fn) else 0.0
334 f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0
335 q = float(yhat.mean())
336
337 prec.append(pr)
338 rec.append(rc)
339 f1.append(f)
340 qrate.append(q)
341 return np.asarray(prec, dtype=float), np.asarray(rec, dtype=float), np.asarray(f1, dtype=float), np.asarray(qrate, dtype=float)
342
343 # Use uniform threshold grid for plotting (0 to 1 in steps of 0.01)
344 th = np.linspace(0.0, 1.0, 101)
345 prec, rec, f1_arr, qrate = _compute_metrics(th)
346
347 # Compute F1*-optimal threshold using actual score distribution (more precise than grid)
348 cand_th = np.unique(np.concatenate(([0.0, 1.0], p)))
349 # cap to a reasonable size by sampling if extremely large
350 if cand_th.size > 2000:
351 cand_th = np.linspace(0.0, 1.0, 2001)
352 _, _, f1_cand, _ = _compute_metrics(cand_th)
353
354 if np.all(np.isnan(f1_cand)):
355 t_star = 0.5 # fallback when no valid F1 can be computed
356 else:
357 f1_max = np.nanmax(f1_cand)
358 best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0]
359 # pick the middle of the best candidates to avoid biasing toward 0
360 best_idx = int(best_idxs[len(best_idxs) // 2])
361 t_star = float(cand_th[best_idx])
362
363 # Replace NaNs for plotting (set to 0 where precision is undefined)
364 prec_plot = np.nan_to_num(prec, nan=0.0)
365
366 fig = go.Figure()
367
368 # Precision (blue line)
369 fig.add_trace(go.Scatter(
370 x=th, y=prec_plot, mode="lines", name="Precision",
371 line=dict(width=3, color="#1f77b4"),
372 hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>"
373 ))
374
375 # Recall (orange line)
376 fig.add_trace(go.Scatter(
377 x=th, y=rec, mode="lines", name="Recall",
378 line=dict(width=3, color="#ff7f0e"),
379 hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>"
380 ))
381
382 # F1 (green line)
383 fig.add_trace(go.Scatter(
384 x=th, y=f1_arr, mode="lines", name="F1",
385 line=dict(width=3, color="#2ca02c"),
386 hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>"
387 ))
388
389 # Queue Rate (grey dashed line)
390 fig.add_trace(go.Scatter(
391 x=th, y=qrate, mode="lines", name="Queue Rate",
392 line=dict(width=2, color="#808080", dash="dash"),
393 hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>"
394 ))
395
396 # F1*-optimal threshold marker (dashed vertical line)
397 fig.add_vline(
398 x=t_star,
399 line_width=2,
400 line_dash="dash",
401 line_color="black",
402 annotation_text=f"t* = {t_star:.2f}",
403 annotation_position="top"
404 )
405
406 # User threshold (solid red line) if provided
407 if user_threshold is not None:
408 fig.add_vline(
409 x=float(user_threshold),
410 line_width=2,
411 line_color="red",
412 annotation_text=f"threshold = {float(user_threshold):.2f}",
413 annotation_position="top"
414 )
415
416 fig.update_layout(
417 title=None,
418 template="plotly_white",
419 xaxis=dict(
420 title="Discrimination Threshold",
421 range=[0, 1],
422 gridcolor="#e0e0e0",
423 showgrid=True,
424 zeroline=False
425 ),
426 yaxis=dict(
427 title="Score",
428 range=[0, 1],
429 gridcolor="#e0e0e0",
430 showgrid=True,
431 zeroline=False
432 ),
433 legend=dict(
434 orientation="h",
435 yanchor="bottom",
436 y=1.02,
437 xanchor="right",
438 x=1.0
439 ),
440 margin=dict(l=60, r=20, t=40, b=50),
441 plot_bgcolor="white",
442 paper_bgcolor="white",
443 )
444 return fig
445
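
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Demonstrates the threshold plot above on synthetic probabilities, including
# an explicit user threshold so both vertical markers are drawn.
def _example_threshold_plot() -> None:
    """Hedged usage sketch for `generate_threshold_plot`."""
    rng = np.random.default_rng(42)
    y = rng.integers(0, 2, size=500)
    p = np.clip(0.5 * y + rng.normal(0.25, 0.2, size=500), 0.0, 1.0)
    fig = generate_threshold_plot(y, p, user_threshold=0.5)
    _save_plotly(fig, "threshold_plot_example.html")
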
446
447 def generate_per_class_metrics_plot(
448 y_true: Sequence,
449 y_pred: Sequence,
450 metrics: Sequence[str] = ("precision", "recall", "f1_score"),
451 title: str = "Classification Report",
452 path: Optional[str] = None,
453 ) -> go.Figure:
454 """
455 Per-class metrics heatmap (Plotly), similar to sklearn's classification report.
456 Rows = classes, columns = metrics; cell text shows the value (0–1).
457 """
458 # Map display names -> sklearn keys
459 key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"}
460 report = classification_report(
461 y_true, y_pred, output_dict=True, zero_division=0
462 )
463
464 # Order classes sensibly (numeric if possible, else lexical)
465 def _sort_key(x):
466 try:
467 return (0, float(x))
468 except Exception:
469 return (1, str(x))
470
471 # Use all classes seen in y_true or y_pred (so rows don't jump around)
472 uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key)
473 classes = [str(c) for c in uniq]
474
475 # Build Z matrix (rows=classes, cols=metrics)
476 used_metrics = [key_map.get(m, m) for m in metrics]
477 z = []
478 for c in classes:
479 row = report.get(c, {})
480 z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics])
481 z = np.array(z, dtype=float)
482
483 # Pretty cell labels
484 z_text = [[f"{v:.2f}" for v in r] for r in z]
485
486 fig = go.Figure(
487 data=go.Heatmap(
488 z=z,
489 x=list(metrics), # keep display names ("precision", "recall", "f1_score")
490 y=classes, # classes as strings
491 colorscale="Reds",
492 zmin=0.0,
493 zmax=1.0,
494 colorbar=dict(title="Value"),
495 text=z_text,
496 texttemplate="%{text}",
497 hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>",
498 )
499 )
500 fig.update_layout(
501 title=None,
502 xaxis_title="",
503 yaxis_title="Class",
504 template="plotly_white",
505 margin=dict(l=60, r=60, t=70, b=40),
506 )
507
508 _save_plotly(fig, path)
509 return fig
510
511
512 def generate_multiclass_roc_curve_plot(
513 y_true: Sequence,
514 y_prob: np.ndarray,
515 classes: Sequence,
516 title: str = "Multiclass ROC Curve",
517 path: Optional[str] = None,
518 ) -> go.Figure:
519 """
520 One-vs-rest ROC curves for multiclass (Plotly).
521 Handles binary passed as 2-column probs as well.
522 """
523 y_true = np.asarray(y_true)
524 y_prob = np.asarray(y_prob)
525
526 # Normalize to shape (n_samples, n_classes)
527 if y_prob.ndim == 1 or y_prob.shape[1] == 1:
528 y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)])
529
530 y_true_bin = label_binarize(y_true, classes=classes)
531 if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2:
532 y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
533
534 if y_prob.shape[1] != y_true_bin.shape[1]:
535 raise ValueError(
536 f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}."
537 )
538
539 fig = go.Figure()
540 for i, cls in enumerate(classes):
541 fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
542 auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i])
543 fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})"))
544
545 fig.add_trace(
546 go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False)
547 )
548 fig.update_layout(
549 title=None,
550 xaxis_title="False Positive Rate",
551 yaxis_title="True Positive Rate",
552 template="plotly_white",
553 )
554 _save_plotly(fig, path)
555 return fig
556
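
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Three-class example for the one-vs-rest ROC helper above; the probability
# matrix is random and row-normalized purely for demonstration.
def _example_multiclass_roc() -> None:
    """Hedged usage sketch for `generate_multiclass_roc_curve_plot`."""
    rng = np.random.default_rng(1)
    classes = ["a", "b", "c"]
    y_true = rng.choice(classes, size=300)
    raw = rng.random((300, 3))
    y_prob = raw / raw.sum(axis=1, keepdims=True)  # rows sum to 1
    fig = generate_multiclass_roc_curve_plot(y_true, y_prob, classes)
    _save_plotly(fig, "multiclass_roc_example.html")
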
557
558 def generate_multiclass_pr_curve_plot(
559 y_true: Sequence,
560 y_prob: np.ndarray,
561 classes: Optional[Sequence] = None,
562 title: str = "Precision–Recall Curve",
563 path: Optional[str] = None,
564 ) -> go.Figure:
565 """
566 Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR.
567 """
568 y_true = np.asarray(y_true)
569 y_prob = np.asarray(y_prob)
570 fig = go.Figure()
571
572 if classes is None or len(classes) == 2:
573 precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1])
574 ap = average_precision_score(y_true, y_prob[:, 1])
575 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}"))
576 else:
577 for i, cls in enumerate(classes):
578 y_true_bin = (y_true == cls).astype(int)
579 y_prob_cls = y_prob[:, i]
580 precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls)
581 ap = average_precision_score(y_true_bin, y_prob_cls)
582 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})"))
583
584 fig.update_layout(
585 title=None,
586 xaxis_title="Recall",
587 yaxis_title="Precision",
588 yaxis=dict(range=[0, 1]),
589 xaxis=dict(range=[0, 1]),
590 template="plotly_white",
591 )
592 _save_plotly(fig, path)
593 return fig
594
595
596 def generate_metric_comparison_bar(
597 metrics_scores: Mapping[str, Sequence[float]],
598 phases: Sequence[str] = ("train", "val", "test"),
599 title: str = "Metric Comparison Across Phases",
600 path: Optional[str] = None,
601 ) -> go.Figure:
602 """
603 Grouped bar chart comparing metrics across phases (Plotly).
604 metrics_scores: {metric_name: [train, val, test]}
605 """
606 df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"})
607 df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score")
608 fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None)
609 ymax = max(1.0, df_m["Score"].max() * 1.05)
610 fig.update_yaxes(range=[0, ymax])
611 fig.update_layout(template="plotly_white")
612 _save_plotly(fig, path)
613 return fig
614
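
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Shows the expected {metric: [train, val, test]} mapping for the grouped
# bar chart above; the metric values are invented.
def _example_metric_comparison_bar() -> None:
    """Hedged usage sketch for `generate_metric_comparison_bar`."""
    scores = {
        "accuracy": [0.95, 0.91, 0.90],
        "f1": [0.94, 0.90, 0.89],
        "roc_auc": [0.98, 0.95, 0.94],
    }
    fig = generate_metric_comparison_bar(scores, phases=("train", "val", "test"))
    _save_plotly(fig, "metric_comparison_example.html")
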
615 # =========================
616 # Regression Plots
617 # =========================
618
619
620 def generate_scatter_plot(
621 y_true: Sequence[float],
622 y_pred: Sequence[float],
623 title: str = "Predicted vs Actual",
624 path: Optional[str] = None,
625 ) -> go.Figure:
626 """
627 Predicted vs. Actual scatter with y=x reference (Plotly).
628 """
629 y_true = np.asarray(y_true)
630 y_pred = np.asarray(y_pred)
631 vmin = float(min(np.min(y_true), np.min(y_pred)))
632 vmax = float(max(np.max(y_true), np.max(y_pred)))
633
634 fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None)
635 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal"))
636 fig.update_layout(template="plotly_white")
637 _save_plotly(fig, path)
638 return fig
639
640
641 def generate_residual_plot(
642 y_true: Sequence[float],
643 y_pred: Sequence[float],
644 title: str = "Residual Plot",
645 path: Optional[str] = None,
646 ) -> go.Figure:
647 """
648 Residuals vs Predicted (Plotly).
649 """
650 y_true = np.asarray(y_true)
651 y_pred = np.asarray(y_pred)
652 residuals = y_true - y_pred
653
654 fig = px.scatter(x=y_pred, y=residuals, opacity=0.6,
655 labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"},
656 title=None)
657 fig.add_hline(y=0, line_dash="dash")
658 fig.update_layout(template="plotly_white")
659 _save_plotly(fig, path)
660 return fig
661
662
663 def generate_residual_histogram(
664 y_true: Sequence[float],
665 y_pred: Sequence[float],
666 bins: int = 30,
667 title: str = "Residual Histogram",
668 path: Optional[str] = None,
669 ) -> go.Figure:
670 """
671 Residuals histogram (Plotly).
672 """
673 residuals = np.asarray(y_true) - np.asarray(y_pred)
674 fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None)
675 fig.update_layout(yaxis_title="Frequency", template="plotly_white")
676 _save_plotly(fig, path)
677 return fig
678
679
680 def generate_regression_calibration_plot(
681 y_true: Sequence[float],
682 y_pred: Sequence[float],
683 num_bins: int = 10,
684 title: str = "Regression Calibration Plot",
685 path: Optional[str] = None,
686 ) -> go.Figure:
687 """
688 Binned Actual vs Predicted means (Plotly).
689 """
690 y_true = np.asarray(y_true)
691 y_pred = np.asarray(y_pred)
692
693 order = np.argsort(y_pred)
694 y_true_sorted = y_true[order]
695 y_pred_sorted = y_pred[order]
696
697 bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins)
698 bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)]
699 bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)]
700
701 vmin = float(min(np.min(y_pred), np.min(y_true)))
702 vmax = float(max(np.max(y_pred), np.max(y_true)))
703
704 fig = go.Figure()
705 fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers",
706 name="Binned Actual vs Predicted"))
707 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"),
708 name="Ideal"))
709 fig.update_layout(
710 title=None,
711 xaxis_title="Mean Predicted per bin",
712 yaxis_title="Mean Actual per bin",
713 template="plotly_white",
714 )
715 _save_plotly(fig, path)
716 return fig
717
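
# --- Illustrative sketch (not part of the tool's pipeline) ---
# One synthetic regression target reused across the four regression helpers
# above; the file names are placeholders.
def _example_regression_plots() -> None:
    """Hedged usage sketch for the regression plot helpers."""
    rng = np.random.default_rng(7)
    y_true = rng.normal(50.0, 10.0, size=400)
    y_pred = y_true + rng.normal(0.0, 5.0, size=400)  # add prediction noise
    _save_plotly(generate_scatter_plot(y_true, y_pred), "reg_scatter.html")
    _save_plotly(generate_residual_plot(y_true, y_pred), "reg_residuals.html")
    _save_plotly(generate_residual_histogram(y_true, y_pred), "reg_residual_hist.html")
    _save_plotly(generate_regression_calibration_plot(y_true, y_pred), "reg_calibration.html")
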
718 # =========================
719 # Confidence / Diagnostics
720 # =========================
721
722
723 def plot_error_vs_confidence(
724 y_true: Union[Sequence[int], np.ndarray],
725 y_proba: Union[Sequence[float], np.ndarray],
726 n_bins: int = 10,
727 title: str = "Error vs Confidence",
728 path: Optional[str] = None,
729 ) -> go.Figure:
730 """
731 Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly.
732 """
733 y_true = np.asarray(y_true)
734 y_proba = np.asarray(y_proba).reshape(-1)
735 y_pred = (y_proba >= 0.5).astype(int)
736 confidence = np.maximum(y_proba, 1 - y_proba)
737 error = (y_pred != y_true).astype(int)
738
739 bins = np.linspace(0.0, 1.0, n_bins + 1)
740 idx = np.digitize(confidence, bins, right=True)
741
742 centers, err_rates = [], []
743 for i in range(1, len(bins)):
744 mask = (idx == i)
745 if mask.any():
746 centers.append(float(confidence[mask].mean()))
747 err_rates.append(float(error[mask].mean()))
748
749 fig = go.Figure()
750 fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate"))
751 fig.update_layout(
752 title=None,
753 xaxis_title="Confidence (max predicted probability)",
754 yaxis_title="Error Rate",
755 yaxis=dict(range=[0, 1]),
756 template="plotly_white",
757 )
758 _save_plotly(fig, path)
759 return fig
760
761
762 def plot_confidence_histogram(
763 y_proba: np.ndarray,
764 bins: int = 20,
765 title: str = "Confidence Histogram",
766 path: Optional[str] = None,
767 ) -> go.Figure:
768 """
769 Histogram of max predicted probabilities (Plotly).
770 Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C).
771 """
772 y_proba = np.asarray(y_proba)
773 if y_proba.ndim == 1:
774 confidences = np.maximum(y_proba, 1 - y_proba)
775 else:
776 confidences = np.max(y_proba, axis=1)
777
778 fig = px.histogram(
779 x=confidences,
780 nbins=bins,
781 range_x=(0, 1),
782 histnorm="percent",
783 labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"},
784 title=None,
785 )
786 if fig.data:
787 fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>")
788 fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white")
789 _save_plotly(fig, path)
790 return fig
791
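
# --- Illustrative sketch (not part of the tool's pipeline) ---
# The histogram helper above accepts either a 1-D vector of positive-class
# probabilities or an (n_samples, n_classes) matrix; both are shown here on
# random data.
def _example_confidence_histogram() -> None:
    """Hedged usage sketch for `plot_confidence_histogram`."""
    rng = np.random.default_rng(3)
    p_binary = rng.random(500)                      # (n_samples,)
    raw = rng.random((500, 4))
    p_multi = raw / raw.sum(axis=1, keepdims=True)  # (n_samples, 4)
    _save_plotly(plot_confidence_histogram(p_binary), "confidence_binary.html")
    _save_plotly(plot_confidence_histogram(p_multi), "confidence_multiclass.html")
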
792 # =========================
793 # Learning Curve
794 # =========================
795
796
797 def generate_learning_curve_from_predictions(
798 y_true,
799 y_pred=None,
800 y_proba=None,
801 classes=None,
802 metric: str = "accuracy",
803 train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10),
804 n_repeats: int = 5,
805 seed: int = 42,
806 title: str = "Learning Curve",
807 path: str | None = None,
808 return_stats: bool = False,
809 ) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]:
810 rng = np.random.default_rng(seed)
811 y_true = np.asarray(y_true)
812 N = len(y_true)
813
814 if metric == "accuracy" and y_pred is None:
815 raise ValueError("accuracy curve requires y_pred")
816 if metric == "log_loss" and y_proba is None:
817 raise ValueError("log_loss curve requires y_proba")
818
819 if y_proba is not None:
820 y_proba = np.asarray(y_proba)
821 if y_pred is not None:
822 y_pred = np.asarray(y_pred)
823
824 sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist()
825 means, stds = [], []
826 for n in sizes:
827 vals = []
828 for _ in range(n_repeats):
829 idx = rng.choice(N, size=n, replace=False)
830 if metric == "accuracy":
831 vals.append(float((y_true[idx] == y_pred[idx]).mean()))
832 else:
833 if y_proba.ndim == 1:
834 p = y_proba[idx]
835 pp = np.column_stack([1 - p, p])
836 else:
837 pp = y_proba[idx]
838 vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes)))
839 means.append(np.mean(vals))
840 stds.append(np.std(vals))
841
842 if return_stats:
843 return sizes, means, stds
844
845 fig = go.Figure()
846 fig.add_trace(go.Scatter(
847 x=sizes, y=means, mode="lines+markers", name="Train",
848 line=dict(width=3, shape="spline"), marker=dict(size=7),
849 error_y=dict(type="data", array=stds, visible=True)
850 ))
851 fig.update_layout(
852 title=None,
853 template="plotly_white",
854 xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"),
855 yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"),
856 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
857 margin=dict(l=50, r=20, t=60, b=50),
858 )
859 if path:
860 _save_plotly(fig, path)
861 return fig
862
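
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Shows both modes of the prediction-based learning curve above: returning a
# figure directly, and returning (sizes, means, stds) for custom overlays as
# `build_train_html_and_plots` does below. Labels and predictions are synthetic.
def _example_learning_curve_from_predictions() -> None:
    """Hedged usage sketch for `generate_learning_curve_from_predictions`."""
    rng = np.random.default_rng(5)
    y_true = rng.integers(0, 2, size=1000)
    # flip ~15% of the labels to fake an imperfect classifier
    flip = rng.random(1000) < 0.15
    y_pred = np.where(flip, 1 - y_true, y_true)

    fig = generate_learning_curve_from_predictions(y_true, y_pred=y_pred, metric="accuracy")
    _save_plotly(fig, "learning_curve_accuracy.html")

    sizes, means, stds = generate_learning_curve_from_predictions(
        y_true, y_pred=y_pred, metric="accuracy", return_stats=True
    )
    assert len(sizes) == len(means) == len(stds)
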
863
864 def build_train_html_and_plots(
865 predictor,
866 problem_type: str,
867 df_train: pd.DataFrame,
868 label_column: str,
869 tmpdir: str,
870 df_val: Optional[pd.DataFrame] = None,
871 seed: int = 42,
872 perf_table_html: str | None = None,
873 threshold: Optional[float] = None,
874 section_title: str = "Training Diagnostics",
875 ) -> str:
876 y_true = df_train[label_column].values
877 y_true_val = df_val[label_column].values if df_val is not None else None
878 # predictions on TRAIN
879 pred_labels, pred_proba = None, None
880 try:
881 pred_labels = predictor.predict(df_train)
882 except Exception:
883 pass
884 try:
885 proba_raw = predictor.predict_proba(df_train)
886 pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw)
887 except Exception:
888 pred_proba = None
889
890 # predictions on VAL (if provided)
891 pred_labels_val, pred_proba_val = None, None
892 if df_val is not None:
893 try:
894 pred_labels_val = predictor.predict(df_val)
895 except Exception:
896 pred_labels_val = None
897 try:
898 proba_raw_val = predictor.predict_proba(df_val)
899 pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val)
900 except Exception:
901 pred_proba_val = None
902
903 pos_scores_train: Optional[np.ndarray] = None
904 pos_scores_val: Optional[np.ndarray] = None
905 if problem_type == "binary":
906 if pred_proba is not None:
907 pos_scores_train = (
908 pred_proba.reshape(-1)
909 if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1)
910 else pred_proba[:, -1]
911 )
912 if pred_proba_val is not None:
913 pos_scores_val = (
914 pred_proba_val.reshape(-1)
915 if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1)
916 else pred_proba_val[:, -1]
917 )
918
919 # Collect plots then append in desired order
920 perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None
921 acc_plot = loss_plot = None
922 cm_train = pc_train = cm_val = pc_val = None
923 threshold_val_plot = None
924 roc_combined = pr_combined = cal_combined = None
925 mc_roc_val = None
926 conf_train = conf_val = None
927 bar_train = bar_val = None
928
929 # 1) Learning Curve — Accuracy
930 if problem_type in ("binary", "multiclass"):
931 acc_fig = go.Figure()
932 added_acc = False
933 if pred_labels is not None:
934 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions(
935 y_true=y_true,
936 y_pred=np.asarray(pred_labels),
937 metric="accuracy",
938 title="Learning Curves — Label Accuracy",
939 seed=seed,
940 return_stats=True,
941 )
942 acc_fig.add_trace(go.Scatter(
943 x=train_sizes, y=train_means, mode="lines+markers", name="Train",
944 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7),
945 error_y=dict(type="data", array=train_stds, visible=True),
946 ))
947 added_acc = True
948 if pred_labels_val is not None and y_true_val is not None:
949 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions(
950 y_true=y_true_val,
951 y_pred=np.asarray(pred_labels_val),
952 metric="accuracy",
953 title="Learning Curves — Label Accuracy",
954 seed=seed,
955 return_stats=True,
956 )
957 acc_fig.add_trace(go.Scatter(
958 x=val_sizes, y=val_means, mode="lines+markers", name="Validation",
959 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7),
960 error_y=dict(type="data", array=val_stds, visible=True),
961 ))
962 added_acc = True
963 if added_acc:
964 acc_fig.update_layout(
965 title=None,
966 template="plotly_white",
967 xaxis=dict(title="samples", gridcolor="#eee"),
968 yaxis=dict(title="accuracy", gridcolor="#eee"),
969 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
970 margin=dict(l=50, r=20, t=60, b=50),
971 )
972 acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy")
973
974 # 2) Learning Curve — Loss
975 if problem_type in ("binary", "multiclass"):
976 classes = np.unique(y_true)
977 loss_fig = go.Figure()
978 added_loss = False
979 if pred_proba is not None:
980 pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba
981 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions(
982 y_true=y_true,
983 y_proba=pp,
984 classes=classes,
985 metric="log_loss",
986 title="Learning Curves — Label Loss",
987 seed=seed,
988 return_stats=True,
989 )
990 loss_fig.add_trace(go.Scatter(
991 x=train_sizes, y=train_means, mode="lines+markers", name="Train",
992 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7),
993 error_y=dict(type="data", array=train_stds, visible=True),
994 ))
995 added_loss = True
996 if pred_proba_val is not None and y_true_val is not None:
997 pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val
998 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions(
999 y_true=y_true_val,
1000 y_proba=pp_val,
1001 classes=classes,
1002 metric="log_loss",
1003 title="Learning Curves — Label Loss",
1004 seed=seed,
1005 return_stats=True,
1006 )
1007 loss_fig.add_trace(go.Scatter(
1008 x=val_sizes, y=val_means, mode="lines+markers", name="Validation",
1009 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7),
1010 error_y=dict(type="data", array=val_stds, visible=True),
1011 ))
1012 added_loss = True
1013 if added_loss:
1014 loss_fig.update_layout(
1015 title=None,
1016 template="plotly_white",
1017 xaxis=dict(title="epoch", gridcolor="#eee"),
1018 yaxis=dict(title="loss", gridcolor="#eee"),
1019 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
1020 margin=dict(l=50, r=20, t=60, b=50),
1021 )
1022 loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss")
1023
1024 # Confusion matrices & per-class metrics
1025 cm_train = pc_train = cm_val = pc_val = None
1026
1027 # Probability diagnostics (binary)
1028 if problem_type == "binary":
1029 # Combined Calibration (Train/Val)
1030 cal_fig = go.Figure()
1031 added_cal = False
1032 if pos_scores_train is not None:
1033 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
1034 prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform")
1035 cal_fig.add_trace(go.Scatter(
1036 x=prob_pred, y=prob_true, mode="lines+markers",
1037 name="Train",
1038 line=dict(color="#1f77b4", width=3),
1039 marker=dict(size=7, color="#1f77b4"),
1040 ))
1041 added_cal = True
1042 if pos_scores_val is not None and y_true_val is not None:
1043 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
1044 prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform")
1045 cal_fig.add_trace(go.Scatter(
1046 x=prob_pred_v, y=prob_true_v, mode="lines+markers",
1047 name="Validation",
1048 line=dict(color="#ff7f0e", width=3),
1049 marker=dict(size=7, color="#ff7f0e"),
1050 ))
1051 added_cal = True
1052 if added_cal:
1053 cal_fig.add_trace(go.Scatter(
1054 x=[0, 1], y=[0, 1],
1055 mode="lines",
1056 line=dict(dash="dash", color="#808080", width=2),
1057 name="Perfect",
1058 showlegend=True,
1059 ))
1060 cal_fig.update_layout(
1061 title=None,
1062 xaxis_title="Predicted Probability",
1063 yaxis_title="Observed Probability",
1064 xaxis=dict(range=[0, 1]),
1065 yaxis=dict(range=[0, 1]),
1066 template="plotly_white",
1067 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
1068 margin=dict(l=60, r=40, t=50, b=50),
1069 )
1070 cal_combined = plot_with_table_style_title(cal_fig, "Calibration Curve (Train vs Validation)")
1071
1072 # Combined ROC (Train/Val)
1073 roc_fig = go.Figure()
1074 added_roc = False
1075 if pos_scores_train is not None:
1076 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
1077 fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train)
1078 roc_fig.add_trace(go.Scatter(
1079 x=fpr_tr, y=tpr_tr, mode="lines",
1080 name="Train",
1081 line=dict(color="#1f77b4", width=3),
1082 ))
1083 if threshold is not None and np.isfinite(thr_tr).any():
1084 finite = np.isfinite(thr_tr)
1085 idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold))))
1086 idx = int(np.nonzero(finite)[0][idx_local])
1087 roc_fig.add_trace(go.Scatter(
1088 x=[fpr_tr[idx]], y=[tpr_tr[idx]],
1089 mode="markers",
1090 name="Train @ threshold",
1091 marker=dict(size=12, color="#1f77b4", symbol="x")
1092 ))
1093 added_roc = True
1094 if pos_scores_val is not None and y_true_val is not None:
1095 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
1096 fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val)
1097 roc_fig.add_trace(go.Scatter(
1098 x=fpr_v, y=tpr_v, mode="lines",
1099 name="Validation",
1100 line=dict(color="#ff7f0e", width=3),
1101 ))
1102 if threshold is not None and np.isfinite(thr_v).any():
1103 finite = np.isfinite(thr_v)
1104 idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold))))
1105 idx = int(np.nonzero(finite)[0][idx_local])
1106 roc_fig.add_trace(go.Scatter(
1107 x=[fpr_v[idx]], y=[tpr_v[idx]],
1108 mode="markers",
1109 name="Val @ threshold",
1110 marker=dict(size=12, color="#ff7f0e", symbol="x")
1111 ))
1112 added_roc = True
1113 if added_roc:
1114 roc_fig.add_trace(go.Scatter(
1115 x=[0, 1], y=[0, 1], mode="lines",
1116 line=dict(dash="dash", width=2, color="#808080"),
1117 showlegend=False
1118 ))
1119 roc_fig.update_layout(
1120 title=None,
1121 xaxis_title="False Positive Rate",
1122 yaxis_title="True Positive Rate",
1123 template="plotly_white",
1124 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
1125 margin=dict(l=60, r=20, t=60, b=60),
1126 )
1127 roc_combined = plot_with_table_style_title(roc_fig, "ROC Curve (Train vs Validation)")
1128
1129 # Combined PR (Train/Val)
1130 pr_fig = go.Figure()
1131 added_pr = False
1132 if pos_scores_train is not None:
1133 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
1134 prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train)
1135 pr_auc_tr = auc(rec_tr, prec_tr)
1136 pr_fig.add_trace(go.Scatter(
1137 x=rec_tr, y=prec_tr, mode="lines",
1138 name=f"Train (AUC={pr_auc_tr:.3f})",
1139 line=dict(color="#1f77b4", width=3),
1140 ))
1141 if threshold is not None and len(thr_tr):
1142 j = int(np.argmin(np.abs(thr_tr - float(threshold))))
1143 j = int(np.clip(j, 0, len(thr_tr) - 1))
1144 pr_fig.add_trace(go.Scatter(
1145 x=[rec_tr[j]], y=[prec_tr[j]],
1146 mode="markers",
1147 name="Train @ threshold",
1148 marker=dict(size=12, color="#1f77b4", symbol="x")
1149 ))
1150 added_pr = True
1151 if pos_scores_val is not None and y_true_val is not None:
1152 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
1153 prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val)
1154 pr_auc_v = auc(rec_v, prec_v)
1155 pr_fig.add_trace(go.Scatter(
1156 x=rec_v, y=prec_v, mode="lines",
1157 name=f"Validation (AUC={pr_auc_v:.3f})",
1158 line=dict(color="#ff7f0e", width=3),
1159 ))
1160 if threshold is not None and len(thr_v):
1161 j = int(np.argmin(np.abs(thr_v - float(threshold))))
1162 j = int(np.clip(j, 0, len(thr_v) - 1))
1163 pr_fig.add_trace(go.Scatter(
1164 x=[rec_v[j]], y=[prec_v[j]],
1165 mode="markers",
1166 name="Val @ threshold",
1167 marker=dict(size=12, color="#ff7f0e", symbol="x")
1168 ))
1169 added_pr = True
1170 if added_pr:
1171 pr_fig.update_layout(
1172 title=None,
1173 xaxis_title="Recall",
1174 yaxis_title="Precision",
1175 template="plotly_white",
1176 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
1177 margin=dict(l=60, r=20, t=60, b=60),
1178 )
1179 pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)")
1180
1181 if pos_scores_val is not None and y_true_val is not None:
1182 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
1183 fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val, title="Threshold Plot (Validation)",
1184 user_threshold=threshold)
1185 threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)")
1186
1187 # Multiclass OVR ROC (validation)
1188 if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None:
1189 classes_val = np.unique(y_true_val)
1190 fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val, title="One-vs-Rest ROC (Validation)")
1191 mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)")
1192
1193 # Prediction Confidence Histogram (train/val)
1194 conf_train = conf_val = None
1195
1196 # Per-class accuracy bars
1197 if problem_type in ("binary", "multiclass") and pred_labels is not None:
1198 classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist()
1199 acc_vals = []
1200 for c in classes_for_bar:
1201 mask = y_true == c
1202 acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if mask.any() else 0.0)
1203 bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4"))
1204 bar_fig.update_layout(
1205 title=None,
1206 template="plotly_white",
1207 xaxis=dict(title="Label", gridcolor="#eee"),
1208 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]),
1209 margin=dict(l=50, r=20, t=60, b=50),
1210 )
1211 bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy")
1212 if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None:
1213 classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist()
1214 acc_vals_val = []
1215 for c in classes_for_bar_val:
1216 mask = y_true_val == c
1217 acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0)
1218 bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, marker_color="#ff7f0e"))
1219 bar_fig_val.update_layout(
1220 title=None,
1221 template="plotly_white",
1222 xaxis=dict(title="Label", gridcolor="#eee"),
1223 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]),
1224 margin=dict(l=50, r=20, t=60, b=50),
1225 )
1226 bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy")
1227
1228 # Assemble in requested order
1229 pieces: list[str] = []
1230 if perf_card:
1231 pieces.append(perf_card)
1232 for block in (threshold_val_plot, roc_combined, pr_combined):
1233 if block:
1234 pieces.append(block)
1235 # Remaining plots (keep existing order)
1236 for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, bar_train, bar_val):
1237 if block:
1238 pieces.append(block)
1239 # Learning curves should appear last in the tab
1240 for block in (acc_plot, loss_plot):
1241 if block:
1242 pieces.append(block)
1243
1244 if not pieces:
1245 return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>"
1246
1247 return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces)
1248
1249
1250 def generate_learning_curve(
1251 estimator,
1252 X,
1253 y,
1254 scoring: str = "r2",
1255 cv_folds: int = 5,
1256 n_jobs: int = -1,
1257 train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10),
1258 title: str = "Learning Curve",
1259 path: Optional[str] = None,
1260 ) -> go.Figure:
1261 """
1262 Learning curve using sklearn.learning_curve, visualized with Plotly.
1263 """
1264 sizes, train_scores, test_scores = skl_learning_curve(
1265 estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes
1266 )
1267 train_mean = train_scores.mean(axis=1)
1268 train_std = train_scores.std(axis=1)
1269 test_mean = test_scores.mean(axis=1)
1270 test_std = test_scores.std(axis=1)
1271
1272 fig = go.Figure()
1273 fig.add_trace(go.Scatter(
1274 x=sizes, y=train_mean, mode="lines+markers", name="Training score",
1275 error_y=dict(type="data", array=train_std, visible=True)
1276 ))
1277 fig.add_trace(go.Scatter(
1278 x=sizes, y=test_mean, mode="lines+markers", name="CV score",
1279 error_y=dict(type="data", array=test_std, visible=True)
1280 ))
1281 fig.update_layout(
1282 title=None,
1283 xaxis_title="Training examples",
1284 yaxis_title=scoring,
1285 template="plotly_white",
1286 )
1287 _save_plotly(fig, path)
1288 return fig
1289
1290 # =========================
1291 # SHAP (Matplotlib-based)
1292 # =========================
1293
1294
1295 def generate_shap_summary_plot(
1296 shap_values, features: pd.DataFrame, title: str = "SHAP Summary Plot", path: Optional[str] = None
1297 ) -> None:
1298 """
1299 SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited;
1300 keep matplotlib for clarity and stability.
1301 """
1302 plt.figure(figsize=(10, 8))
1303 shap.summary_plot(shap_values, features, show=False)
1304 plt.title(title)
1305 _save_matplotlib(path)
1306
1307
1308 def generate_shap_force_plot(
1309 explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None
1310 ) -> None:
1311 """
1312 SHAP force plot (Matplotlib).
1313 """
1314 shap_values = explainer(instance)
1315 plt.figure(figsize=(10, 4))
1316 shap.plots.force(shap_values[0], show=False)
1317 plt.title(title)
1318 _save_matplotlib(path)
1319
1320
1321 def generate_shap_waterfall_plot(
1322 explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None
1323 ) -> None:
1324 """
1325 SHAP waterfall plot (Matplotlib).
1326 """
1327 shap_values = explainer(instance)
1328 plt.figure(figsize=(10, 6))
1329 shap.plots.waterfall(shap_values[0], show=False)
1330 plt.title(title)
1331 _save_matplotlib(path)
1332
1333
1334 def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str:
1335 """
1336 Return 'binary', 'multiclass', or 'regression'.
1337 Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques.
1338 """
1339 # AutoGluon predictors usually expose .problem_type; be defensive.
1340 pt = getattr(predictor, "problem_type", None)
1341 if isinstance(pt, str):
1342 pt_l = pt.lower()
1343 if "regression" in pt_l:
1344 return "regression"
1345 if "binary" in pt_l:
1346 return "binary"
1347 if "multiclass" in pt_l:
1348 return "multiclass"
1349
1350 y = df_train_full[label_column]
1351 if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10:
1352 return "regression"
1353 return "binary" if y.nunique() == 2 else "multiclass"
1354
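
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Exercises the fallback inference path of `infer_problem_type` on small
# hypothetical frames, using a bare object so no predictor metadata is available.
def _example_infer_problem_type() -> None:
    """Hedged usage sketch for `infer_problem_type`."""
    dummy_predictor = object()  # no .problem_type attribute -> dtype-based fallback
    df_binary = pd.DataFrame({"label": [0, 1, 0, 1], "x": [1.0, 2.0, 3.0, 4.0]})
    df_reg = pd.DataFrame({"label": np.linspace(0.0, 1.0, 50), "x": range(50)})
    assert infer_problem_type(dummy_predictor, df_binary, "label") == "binary"
    assert infer_problem_type(dummy_predictor, df_reg, "label") == "regression"
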
1355
1356 def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]:
1357 """Make evaluate() outputs JSON/csv friendly floats."""
1358 out = {}
1359 for k, v in d.items():
1360 try:
1361 out[k] = float(v)
1362 except Exception:
1363 # keep only real-valued scalars
1364 pass
1365 return out
1366
1367
1368 def evaluate_all(
1369 predictor,
1370 df_train: pd.DataFrame,
1371 df_val: pd.DataFrame,
1372 df_test: pd.DataFrame,
1373 label_column: str,
1374 problem_type: str,
1375 ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
1376 """
1377 Run predictor.evaluate on train/val/test and normalize the result dicts to floats.
1378 MultiModalPredictor does not accept the `silent` kwarg, so call defensively.
1379 """
1380 def _evaluate(df):
1381 try:
1382 return predictor.evaluate(df, silent=True)
1383 except TypeError:
1384 return predictor.evaluate(df)
1385
1386 train_scores = _safe_floatify(_evaluate(df_train))
1387 val_scores = _safe_floatify(_evaluate(df_val))
1388 test_scores = _safe_floatify(_evaluate(df_test))
1389 return train_scores, val_scores, test_scores
1390
1391
1392 def build_summary_html(
1393 predictor,
1394 df_train: pd.DataFrame,
1395 df_val: Optional[pd.DataFrame],
1396 df_test: Optional[pd.DataFrame],
1397 label_column: str,
1398 extra_run_rows: Optional[list[tuple[str, str]]] = None,
1399 class_balance_html: Optional[str] = None,
1400 perf_table_html: Optional[str] = None,
1401 ) -> str:
1402 sections = []
1403
1404 # Dataset Overview (first section in the tab)
1405 if class_balance_html:
1406 sections.append(f"""
1407 <section class="section">
1408 <h2 class="section-title">Dataset Overview</h2>
1409 <div class="card">
1410 {class_balance_html}
1411 </div>
1412 </section>
1413 """.strip())
1414
1415 # Performance Summary
1416 if perf_table_html:
1417 sections.append(f"""
1418 <section class="section">
1419 <h2 class="section-title">Model Performance Summary</h2>
1420 <div class="card">
1421 {perf_table_html}
1422 </div>
1423 </section>
1424 """.strip())
1425
1426 # Model Configuration
1427
1428 # Drop 'Predictor type' and 'Framework' rows from the run metadata; keep the rest as-is
1429 base_rows: list[tuple[str, str]] = []
1430 if extra_run_rows:
1431 # Remove any rows with keys 'Predictor type' or 'Framework'
1432 base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")])
1433
1434 def _fmt(v):
1435 if v is None or v == "":
1436 return "—"
1437 return _escape(str(v))
1438
1439 rows_html = "\n".join(
1440 f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>"
1441 for k, v in base_rows
1442 )
1443
1444 sections.append(f"""
1445 <section class="section">
1446 <h2 class="section-title">Model Configuration</h2>
1447 <div class="card">
1448 <table class="kv-table">
1449 <thead><tr><th>Key</th><th>Value</th></tr></thead>
1450 <tbody>
1451 {rows_html}
1452 </tbody>
1453 </table>
1454 </div>
1455 </section>
1456 """.strip())
1457
1458 return "\n".join(sections).strip()
1459
1460
1461 def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str:
1462 """Build a visualization of feature importance."""
1463 try:
1464 # Try to get feature importance from predictor
1465 fi = None
1466 if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance):
1467 try:
1468 fi = predictor.feature_importance(df_train)
1469 except Exception as e:
1470 return f"<p>Could not compute feature importance: {e}</p>"
1471
1472 if fi is None or (isinstance(fi, pd.DataFrame) and fi.empty):
1473 return "<p>Feature importance not available for this model.</p>"
1474
1475 # Format as a sortable table
1476 rows = []
1477 if isinstance(fi, pd.DataFrame):
1478 fi = fi.sort_values("importance", ascending=False)
1479 # AutoGluon-style feature_importance frames are indexed by feature name
1480 for feat, row in fi.iterrows():
1481 imp = float(row["importance"])
1482 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>")
1483 else:
1484 # Handle other formats (dict, etc)
1485 for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True):
1486 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>")
1487
1488 if not rows:
1489 return "<p>No feature importance values available.</p>"
1490
1491 table_html = f"""
1492 <table class="performance-summary">
1493 <thead>
1494 <tr>
1495 <th class="sortable">Feature</th>
1496 <th class="sortable">Importance</th>
1497 </tr>
1498 </thead>
1499 <tbody>
1500 {"".join(rows)}
1501 </tbody>
1502 </table>
1503 """
1504 return table_html
1505
1506 except Exception as e:
1507 return f"<p>Error building feature importance visualization: {e}</p>"
1508
1509
1510 def build_test_html_and_plots(
1511 predictor,
1512 problem_type: str,
1513 df_test: pd.DataFrame,
1514 label_column: str,
1515 tmpdir: str,
1516 threshold: Optional[float] = None,
1517 ) -> Tuple[str, List[str]]:
1518 """
1519 Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs.
1520 Returns: (html_template_with_{}, list_of_plot_divs)
1521 """
1522 plots: List[str] = []
1523
1524 y_true = df_test[label_column].values
1525 classes = np.unique(y_true)
1526
1527 # Try proba/labels where meaningful
1528 pred_labels = None
1529 pred_proba = None
1530 try:
1531 pred_labels = predictor.predict(df_test)
1532 except Exception:
1533 pass
1534 try:
1535 # MultiModalPredictor exposes predict_proba for classification problems.
1536 pred_proba = predictor.predict_proba(df_test)
1537 except Exception:
1538 pred_proba = None
1539
1540 proba_arr = None
1541 if pred_proba is not None:
1542 if isinstance(pred_proba, pd.Series):
1543 proba_arr = pred_proba.to_numpy().reshape(-1, 1)
1544 elif isinstance(pred_proba, pd.DataFrame):
1545 proba_arr = pred_proba.to_numpy()
1546 else:
1547 proba_arr = np.asarray(pred_proba)
1548
1549 # Thresholded labels for binary
1550 if problem_type == "binary" and threshold is not None and proba_arr is not None:
1551 pos_label, neg_label = classes.max(), classes.min()
1552 pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1]
1553 pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label)
1554
1555 # Confusion matrix / per-class now reflect thresholded labels
1556 if problem_type in ("binary", "multiclass") and pred_labels is not None:
1557 cm_title = "Confusion Matrix"
1558 if threshold is not None and problem_type == "binary":
1559 thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".")
1560 cm_title = f"Confusion Matrix (Threshold = {thr_str})"
1561 fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title)
1562 plots.append(plot_with_table_style_title(fig_cm, cm_title))
1563
1564 fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics")
1565 plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics"))
1566
1567 # ROC/PR where possible — choose positive-class scores safely
1568 pos_label = classes.max() # or set explicitly, e.g., 1 or "yes"
1569
1570 if isinstance(pred_proba, pd.DataFrame):
1571 proba_arr = pred_proba.to_numpy()
1572 if pos_label in pred_proba.columns:
1573 pos_idx = list(pred_proba.columns).index(pos_label)
1574 else:
1575 pos_idx = -1 # fallback to last column
1576 elif isinstance(pred_proba, pd.Series):
1577 proba_arr = pred_proba.to_numpy().reshape(-1, 1)
1578 pos_idx = 0
1579 else:
1580 proba_arr = np.asarray(pred_proba) if pred_proba is not None else None
1581 pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0
1582
1583 if proba_arr is not None:
1584 y_bin = (y_true == pos_label).astype(int)
1585 pos_scores = (
1586 proba_arr.reshape(-1)
1587 if proba_arr.ndim == 1 or proba_arr.shape[1] == 1
1588 else proba_arr[:, pos_idx]
1589 )
1590
1591 fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold)
1592 plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))
1593
1594 fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold)
1595 plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))
1596
1597 # Additional diagnostics aligned with ImageLearner style
1598 if problem_type == "binary":
1599 conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)")
1600 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)"))
1601 else:
1602 conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)")
1603 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)"))
1604
1605 if problem_type == "multiclass" and proba_arr is not None and proba_arr.ndim >= 2:
1606 fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)")
1607 plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)"))
1608
1609 # Regression visuals
1610 if problem_type == "regression":
1611 if pred_labels is None:
1612 pred_labels = predictor.predict(df_test)
1613 fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual")
1614 plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual"))
1615
1616 fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot")
1617 plots.append(plot_with_table_style_title(fig_res, "Residual Plot"))
1618
1619 fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram")
1620 plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram"))
1621
1622 fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration")
1623 plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration"))
1624
1625 # Small HTML template with placeholder for metric rows the caller fills in
1626 test_html_template = """
1627 <h2>Test Performance Summary</h2>
1628 <table class="performance-summary">
1629 <thead><tr><th>Metric</th><th>Test</th></tr></thead>
1630 <tbody>{}</tbody>
1631 </table>
1632 """
1633 return test_html_template, plots
1634
1635
1636 def build_feature_html(
1637 predictor,
1638 df_train: pd.DataFrame,
1639 label_column: str,
1640 include_modalities: bool = True, # ← NEW
1641 include_class_balance: bool = True, # ← NEW
1642 ) -> str:
1643 sections = []
1644
1645 # (Typical feature importance content…)
1646 fi_html = build_feature_importance_html(predictor, df_train, label_column)
1647 sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>")
1648
1649 # Previously: Modalities & Inputs and/or Class Balance may have been here.
1650 # Only render them if flags are True.
1651 if include_modalities:
1652 from report_utils import build_modalities_html
1653 modalities_html = build_modalities_html(predictor, df_train, label_column)
1654 sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>")
1655
1656 if include_class_balance:
1657 from report_utils import build_class_balance_html
1658 cb_html = build_class_balance_html(df_train, label_column)
1659 sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>")
1660
1661 return "\n".join(sections)
1662
1663
1664 def assemble_full_html_report(
1665 summary_html: str,
1666 train_html: str,
1667 test_html: str,
1668 plots: List[str],
1669 feature_html: str,
1670 ) -> str:
1671 """
1672 Wrap the four tabs using utils.build_tabbed_html and return full HTML.
1673 """
1674 # Append plots under the Test tab (already wrapped with titles)
1675 test_full = test_html + "".join(plots)
1676
1677 tabs = build_tabbed_html(summary_html, train_html, test_full, feature_html, explainer_html=None)
1678
1679 html_out = get_html_template()
1680
1681 # 🔧 Ensure Plotly JS is available (we render plots with include_plotlyjs=False)
1682 html_out += '\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n'
1683
1684 # Optional: centering tweaks
1685 html_out += """
1686 <style>
1687 .plotly-center { display: flex; justify-content: center; }
1688 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
1689 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
1690 </style>
1691 """
1692 # Help modal HTML/JS
1693 html_out += get_metrics_help_modal()
1694
1695 html_out += tabs
1696 html_out += get_html_closing()
1697 return html_out
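
# --- Illustrative sketch (not part of the tool's pipeline) ---
# Rough shape of how the tab builders above feed `assemble_full_html_report`.
# The predictor, DataFrames, and working directory are hypothetical; in the
# real tool these come from the Galaxy wrapper that imports this module.
def _example_assemble_report(predictor, df_train, df_val, df_test, label_column, tmpdir):
    """Hedged end-to-end sketch of the report assembly flow."""
    problem_type = infer_problem_type(predictor, df_train, label_column)
    summary_html = build_summary_html(predictor, df_train, df_val, df_test, label_column)
    train_html = build_train_html_and_plots(
        predictor, problem_type, df_train, label_column, tmpdir, df_val=df_val
    )
    test_template, test_plots = build_test_html_and_plots(
        predictor, problem_type, df_test, label_column, tmpdir
    )
    test_html = test_template.format("")  # caller normally fills in the metric rows
    feature_html = build_feature_html(predictor, df_train, label_column)
    return assemble_full_html_report(summary_html, train_html, test_html, test_plots, feature_html)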