Mercurial > repos > goeckslab > image_learner
comparison plotly_plots.py @ 8:85e6f4b2ad18 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
| author | goeckslab |
|---|---|
| date | Thu, 14 Aug 2025 14:53:10 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| 7:801a8b6973fb | 8:85e6f4b2ad18 |
|---|---|
| 1 import json | |
| 2 from typing import Dict, List, Optional | |
| 3 | |
| 4 import numpy as np | |
| 5 import plotly.graph_objects as go | |
| 6 import plotly.io as pio | |
| 7 | |
| 8 | |
def build_classification_plots(
    test_stats_path: str,
    training_stats_path: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Read Ludwig's test_statistics.json and build interactive Plotly panels.

    Panels are produced in this order:
      - Confusion Matrix (always)
      - Classification Report heatmap (only when ``per_class_stats`` exists)

    Parameters
    ----------
    test_stats_path:
        Path to the test-statistics JSON file. Statistics are read from its
        ``"label"`` entry.
    training_stats_path:
        Currently unused; accepted so existing callers keep working.

    Returns
    -------
    A list of dicts, each with:
        {
            "title": <plot title>,
            "html": <HTML fragment for embedding>
        }
    Only the first fragment loads plotly.js (from the CDN). Later fragments
    rely on it, so they must be embedded on the same page after the first.

    Raises
    ------
    KeyError
        If the JSON lacks a ``"label"`` entry or that entry lacks
        ``"confusion_matrix"``.
    """
    # --- Load test stats ---
    with open(test_stats_path, "r", encoding="utf-8") as f:
        test_stats = json.load(f)
    label_stats = test_stats["label"]

    # Common sizing: 40px per class, but never smaller than 600px square.
    cell = 40
    n_classes = len(label_stats["confusion_matrix"])
    side_px = max(cell * n_classes + 200, 600)
    common_cfg = {"displayModeBar": True, "scrollZoom": True}

    plots: List[Dict[str, str]] = []

    # 1) Confusion Matrix
    cm = np.array(label_stats["confusion_matrix"], dtype=int)
    labels = _resolve_labels(label_stats.get("labels"), cm.shape[0])
    fig_cm = _confusion_matrix_figure(cm, labels, side_px)
    plots.append({
        "title": "Confusion Matrix",
        "html": pio.to_html(
            fig_cm,
            full_html=False,
            # The first fragment pulls plotly.js from the CDN; later ones reuse it.
            include_plotlyjs="cdn",
            config=common_cfg,
        ),
    })

    # 2) Classification Report Heatmap
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        fig_cr = _classification_report_figure(pcs, side_px)
        plots.append({
            "title": "Classification Report",
            "html": pio.to_html(
                fig_cr,
                full_html=False,
                include_plotlyjs=False,
                config=common_cfg,
            ),
        })

    return plots


def _resolve_labels(raw_labels: Optional[List], n_classes: int) -> List[str]:
    """Return ``n_classes`` axis labels as strings, falling back to "0".."n-1".

    The fallback is used when no label list is given, or when its length
    disagrees with the confusion matrix. Without it, annotating cells would
    raise IndexError.
    """
    if raw_labels is None or len(raw_labels) != n_classes:
        return [str(i) for i in range(n_classes)]
    return [str(label) for label in raw_labels]


def _confusion_matrix_figure(cm: np.ndarray, labels: List[str], side_px: int) -> "go.Figure":
    """Build the annotated confusion-matrix heatmap (rows=observed, cols=predicted).

    Each cell shows its raw count (bold) above its share of the grand total.
    """
    total = cm.sum()
    fig = go.Figure(
        go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale="Blues",
            showscale=True,
            colorbar=dict(title="Count"),
            xgap=2,
            ygap=2,
        )
    )

    # Cells darker than half the maximum count get white text for contrast.
    thresh = (cm.max() if cm.size else 0) / 2
    annotations = []
    for i, j in np.ndindex(cm.shape):
        count = int(cm[i, j])
        pct = (count / total * 100) if total > 0 else 0
        color = "white" if count > thresh else "black"
        anchor = dict(x=labels[j], y=labels[i], showarrow=False, xanchor="center")
        annotations.append(dict(
            anchor,
            text=f"<b>{count}</b>",
            font=dict(color=color, size=14),
            yanchor="bottom",
            yshift=2,
        ))
        annotations.append(dict(
            anchor,
            text=f"{pct:.1f}%",
            font=dict(color=color, size=13),
            yanchor="top",
            yshift=-2,
        ))

    fig.update_layout(
        title=dict(text="Confusion Matrix", x=0.5),
        xaxis_title="Predicted",
        yaxis_title="Observed",
        # Force categorical axes: Plotly otherwise auto-converts numeric-looking
        # label strings ("0", "1", "10") to a linear axis, which misplaces
        # cells and can drop non-numeric labels from mixed label sets.
        xaxis_type="category",
        yaxis_type="category",
        yaxis_autorange="reversed",
        width=side_px,
        height=side_px,
        margin=dict(t=100, l=80, r=80, b=80),
        # Set all annotations at once rather than one add_annotation per cell.
        annotations=annotations,
    )
    return fig


def _classification_report_figure(pcs: Dict, side_px: int) -> "go.Figure":
    """Build a class x (precision, recall, f1_score) heatmap from per-class stats.

    A missing metric counts as 0. A JSON ``null`` metric is drawn as an empty
    cell labelled "N/A".
    """
    metrics = ["precision", "recall", "f1_score"]
    classes = list(pcs)
    z, txt = [], []
    for cls in classes:
        row, trow = [], []
        for metric in metrics:
            val = pcs[cls].get(metric, 0)
            if val is None:
                # Formatting None with ":.2f" would raise TypeError.
                row.append(float("nan"))
                trow.append("N/A")
            else:
                row.append(val)
                trow.append(f"{val:.2f}")
        z.append(row)
        txt.append(trow)

    fig = go.Figure(
        go.Heatmap(
            z=z,
            x=metrics,
            y=[str(c) for c in classes],
            text=txt,
            texttemplate="%{text}",
            colorscale="Reds",
            showscale=True,
            colorbar=dict(title="Value"),
        )
    )
    fig.update_layout(
        title=dict(text="Classification Report", x=0.5),
        xaxis_title="",
        yaxis_title="Class",
        # Class names may look numeric; keep them as discrete categories.
        yaxis_type="category",
        width=side_px,
        height=side_px,
        margin=dict(t=80, l=80, r=80, b=80),
    )
    return fig
