goeckslab / image_learner: comparison of plotly_plots.py @ 8:85e6f4b2ad18 (draft, default, tip)
commit message: planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
author: goeckslab
date: Thu, 14 Aug 2025 14:53:10 +0000
comparing revision 7:801a8b6973fb with 8:85e6f4b2ad18

import json
from typing import Dict, List, Optional

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio


def build_classification_plots(
    test_stats_path: str,
    training_stats_path: Optional[str] = None,
) -> List[Dict[str, str]]:
13 """ | |
14 Read Ludwig’s test_statistics.json and build three interactive Plotly panels: | |
15 - Confusion Matrix | |
16 - ROC-AUC | |
17 - Classification Report Heatmap | |
18 | |
19 Returns a list of dicts, each with: | |
20 { | |
21 "title": <plot title>, | |
22 "html": <HTML fragment for embedding> | |
23 } | |
24 """ | |
    # --- Load test stats ---
    with open(test_stats_path, "r") as f:
        test_stats = json.load(f)
    label_stats = test_stats["label"]
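    # Expected shape (inferred from the fields read below):
    #   {"label": {"confusion_matrix": [[...], ...],
    #              "labels": [...],  # optional
    #              "per_class_stats": {cls: {"precision": ..., "recall": ..., "f1_score": ...}}}}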

    # common sizing
    cell = 40
    n_classes = len(label_stats["confusion_matrix"])
    side_px = max(cell * n_classes + 200, 600)
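    # shared Plotly config: keep the mode bar visible and enable scroll-wheel zoom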
    common_cfg = {"displayModeBar": True, "scrollZoom": True}

    plots: List[Dict[str, str]] = []

    # 1) Confusion Matrix
    cm = np.array(label_stats["confusion_matrix"], dtype=int)
    labels = label_stats.get("labels", [str(i) for i in range(cm.shape[0])])
    total = cm.sum()

    fig_cm = go.Figure(
        go.Heatmap(
            z=cm,
            x=labels,
            y=labels,
            colorscale="Blues",
            showscale=True,
            colorbar=dict(title="Count"),
        )
    )
    fig_cm.update_traces(xgap=2, ygap=2)
    fig_cm.update_layout(
        title=dict(text="Confusion Matrix", x=0.5),
        xaxis_title="Predicted",
        yaxis_title="Observed",
        yaxis_autorange="reversed",
        width=side_px,
        height=side_px,
        margin=dict(t=100, l=80, r=80, b=80),
    )

    # annotate counts and percentages
    mval = cm.max() if cm.size else 0
    thresh = mval / 2
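    # cells above half the max count get white text for contrast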
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            v = cm[i, j]
            pct = (v / total * 100) if total > 0 else 0
            color = "white" if v > thresh else "black"
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"<b>{v}</b>",
                showarrow=False,
                font=dict(color=color, size=14),
                xanchor="center",
                yanchor="bottom",
                yshift=2,
            )
            fig_cm.add_annotation(
                x=labels[j],
                y=labels[i],
                text=f"{pct:.1f}%",
                showarrow=False,
                font=dict(color=color, size=13),
                xanchor="center",
                yanchor="top",
                yshift=-2,
            )

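    # the first fragment loads plotly.js from the CDN; the later panel reuses it (include_plotlyjs=False)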
    plots.append({
        "title": "Confusion Matrix",
        "html": pio.to_html(
            fig_cm,
            full_html=False,
            include_plotlyjs="cdn",
            config=common_cfg
        )
    })

    # 2) Classification Report Heatmap
    pcs = label_stats.get("per_class_stats", {})
    if pcs:
        classes = list(pcs.keys())
        metrics = ["precision", "recall", "f1_score"]
        z, txt = [], []
        for c in classes:
            row, trow = [], []
            for m in metrics:
                val = pcs[c].get(m, 0)
                row.append(val)
                trow.append(f"{val:.2f}")
            z.append(row)
            txt.append(trow)

        fig_cr = go.Figure(
            go.Heatmap(
                z=z,
                x=metrics,
                y=[str(c) for c in classes],
                text=txt,
                texttemplate="%{text}",
                colorscale="Reds",
                showscale=True,
                colorbar=dict(title="Value"),
            )
        )
        fig_cr.update_layout(
            title="Classification Report",
            xaxis_title="",
            yaxis_title="Class",
            width=side_px,
            height=side_px,
            margin=dict(t=80, l=80, r=80, b=80),
        )
        plots.append({
            "title": "Classification Report",
            "html": pio.to_html(
                fig_cr,
                full_html=False,
                include_plotlyjs=False,
                config=common_cfg
            )
        })

    return plots
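
A minimal usage sketch (not part of the repository file): it calls build_classification_plots on a test_statistics.json path and stitches the returned fragments into one standalone page. The module name, paths, and wrapper markup are illustrative assumptions.

# Hypothetical driver; the import path and file names below are placeholders.
from plotly_plots import build_classification_plots

panels = build_classification_plots("results/test_statistics.json")

# Each dict holds a title and an embeddable fragment; the first fragment
# already loads plotly.js from the CDN, so a bare HTML wrapper is enough.
body = "\n".join(f"<h2>{p['title']}</h2>\n{p['html']}" for p in panels)
with open("report.html", "w") as f:
    f.write(f"<html><body>{body}</body></html>")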