Mercurial > repos > goeckslab > pycaret_predict
comparison utils.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | a32ff7201629 |
children |
comparison
equal
deleted
inserted
replaced
7:f4cb41f458fd | 8:1aed7d47c5ec |
---|---|
1 import base64 | 1 import base64 |
2 import logging | 2 import logging |
3 from typing import Optional | |
3 | 4 |
4 import numpy as np | 5 import numpy as np |
5 | 6 |
6 logging.basicConfig(level=logging.DEBUG) | 7 logging.basicConfig(level=logging.DEBUG) |
7 LOG = logging.getLogger(__name__) | 8 LOG = logging.getLogger(__name__) |
8 | 9 |
9 | 10 |
10 def get_html_template(): | 11 def get_html_template() -> str: |
11 return """ | 12 return """ |
12 <html> | 13 <html> |
13 <head> | 14 <head> |
14 <meta charset="UTF-8"> | 15 <meta charset="UTF-8"> |
15 <title>Model Training Report</title> | 16 <title>Model Training Report</title> |
18 font-family: Arial, sans-serif; | 19 font-family: Arial, sans-serif; |
19 margin: 0; | 20 margin: 0; |
20 padding: 20px; | 21 padding: 20px; |
21 background-color: #f4f4f4; | 22 background-color: #f4f4f4; |
22 } | 23 } |
24 /* allow horizontal scrolling if content overflows */ | |
23 .container { | 25 .container { |
24 max-width: 800px; | 26 max-width: 800px; |
25 margin: auto; | 27 margin: auto; |
26 background: white; | 28 background: white; |
27 padding: 20px; | 29 padding: 20px; |
28 box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | 30 box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); |
29 } | 31 overflow-x: auto; |
32 } | |
33 | |
30 h1 { | 34 h1 { |
31 text-align: center; | 35 text-align: center; |
32 color: #333; | 36 color: #333; |
33 } | 37 } |
34 h2 { | 38 h2 { |
35 border-bottom: 2px solid #4CAF50; | 39 border-bottom: 2px solid #4CAF50; |
36 color: #4CAF50; | 40 color: #4CAF50; |
37 padding-bottom: 5px; | 41 padding-bottom: 5px; |
38 } | 42 } |
43 | |
44 /* wrapper for tables to allow individual horizontal scroll */ | |
45 .table-wrapper { | |
46 overflow-x: auto; | |
47 margin: 1rem 0; | |
48 } | |
49 | |
50 /* revert table styling to full borders */ | |
39 table { | 51 table { |
40 width: 100%; | 52 width: 100%; |
41 border-collapse: collapse; | 53 border-collapse: collapse; |
42 margin: 20px 0; | 54 margin: 20px 0; |
43 } | 55 } |
50 } | 62 } |
51 th { | 63 th { |
52 background-color: #4CAF50; | 64 background-color: #4CAF50; |
53 color: white; | 65 color: white; |
54 } | 66 } |
67 | |
55 .plot { | 68 .plot { |
56 text-align: center; | 69 text-align: center; |
57 margin: 20px 0; | 70 margin: 20px 0; |
58 } | 71 } |
59 .plot img { | 72 .plot img { |
60 max-width: 100%; | 73 max-width: 100%; |
61 height: auto; | 74 height: auto; |
62 } | 75 } |
76 | |
63 .tabs { | 77 .tabs { |
64 display: flex; | 78 display: flex; |
65 margin-bottom: 20px; | 79 align-items: center; |
80 border-bottom: 2px solid #ccc; | |
81 margin-bottom: 1rem; | |
82 } | |
83 .tab { | |
84 padding: 10px 20px; | |
66 cursor: pointer; | 85 cursor: pointer; |
67 justify-content: space-around; | 86 border: 1px solid #ccc; |
68 } | 87 border-bottom: none; |
69 .tab { | 88 background: #f9f9f9; |
70 padding: 10px; | 89 margin-right: 5px; |
71 background-color: #4CAF50; | 90 border-top-left-radius: 8px; |
72 color: white; | 91 border-top-right-radius: 8px; |
73 border-radius: 5px 5px 0 0; | 92 } |
74 flex-grow: 1; | 93 .tab.active { |
75 text-align: center; | 94 background: white; |
76 margin: 0 5px; | 95 font-weight: bold; |
77 } | 96 } |
78 .tab.active-tab { | 97 |
79 background-color: #333; | |
80 } | |
81 .tab-content { | 98 .tab-content { |
82 display: none; | 99 display: none; |
83 padding: 20px; | 100 padding: 20px; |
84 border: 1px solid #ddd; | 101 border: 1px solid #ccc; |
85 border-top: none; | 102 border-top: none; |
86 background-color: white; | 103 background: white; |
87 } | 104 } |
88 .tab-content.active-content { | 105 .tab-content.active { |
89 display: block; | 106 display: block; |
90 } | 107 } |
91 </style> | 108 |
109 .help-btn { | |
110 margin-left: auto; | |
111 padding: 6px 12px; | |
112 font-size: 0.9rem; | |
113 border: 1px solid #4CAF50; | |
114 border-radius: 4px; | |
115 background: #4CAF50; | |
116 color: white; | |
117 cursor: pointer; | |
118 } | |
119 | |
120 /* sortable table header arrows */ | |
121 table.sortable th { | |
122 position: relative; | |
123 padding-right: 20px; /* room for the arrow */ | |
124 cursor: pointer; | |
125 } | |
126 table.sortable th::after { | |
127 content: '↕'; | |
128 position: absolute; | |
129 right: 8px; | |
130 opacity: 0.4; | |
131 transition: opacity 0.2s; | |
132 } | |
133 table.sortable th:hover::after { | |
134 opacity: 0.7; | |
135 } | |
136 table.sortable th.sorted-asc::after { | |
137 content: '↑'; | |
138 opacity: 1; | |
139 } | |
140 table.sortable th.sorted-desc::after { | |
141 content: '↓'; | |
142 opacity: 1; | |
143 } | |
144 </style> | |
92 </head> | 145 </head> |
93 <body> | 146 <body> |
94 <div class="container"> | 147 <div class="container"> |
95 """ | 148 """ |
96 | 149 |
97 | 150 |
98 def get_html_closing(): | 151 def get_html_closing() -> str: |
99 return """ | 152 return """ |
100 </div> | 153 </div> |
101 <script> | 154 <script> |
102 function openTab(evt, tabName) {{ | 155 document.addEventListener('DOMContentLoaded', () => { |
103 var i, tabcontent, tablinks; | 156 document.querySelectorAll('table.sortable').forEach(table => { |
104 tabcontent = document.getElementsByClassName("tab-content"); | 157 const getCellValue = (row, idx) => |
105 for (i = 0; i < tabcontent.length; i++) {{ | 158 row.children[idx].innerText.trim() || ''; |
106 tabcontent[i].style.display = "none"; | 159 |
107 }} | 160 const comparer = (idx, asc) => (a, b) => { |
108 tablinks = document.getElementsByClassName("tab"); | 161 const v1 = getCellValue(asc ? a : b, idx); |
109 for (i = 0; i < tablinks.length; i++) {{ | 162 const v2 = getCellValue(asc ? b : a, idx); |
110 tablinks[i].className = | 163 const n1 = parseFloat(v1), n2 = parseFloat(v2); |
111 tablinks[i].className.replace(" active-tab", ""); | 164 if (!isNaN(n1) && !isNaN(n2)) return n1 - n2; |
112 }} | 165 return v1.localeCompare(v2); |
113 document.getElementById(tabName).style.display = "block"; | 166 }; |
114 evt.currentTarget.className += " active-tab"; | 167 |
115 }} | 168 table.querySelectorAll('th').forEach((th, idx) => { |
116 document.addEventListener("DOMContentLoaded", function() {{ | 169 let asc = true; |
117 document.querySelector(".tab").click(); | 170 th.addEventListener('click', () => { |
118 }}); | 171 // sort rows |
119 </script> | 172 const tbody = table.tBodies[0]; |
173 Array.from(tbody.rows) | |
174 .sort(comparer(idx, asc)) | |
175 .forEach(row => tbody.appendChild(row)); | |
176 // update arrow classes | |
177 table.querySelectorAll('th').forEach(h => { | |
178 h.classList.remove('sorted-asc','sorted-desc'); | |
179 }); | |
180 th.classList.add(asc ? 'sorted-asc' : 'sorted-desc'); | |
181 asc = !asc; | |
182 }); | |
183 }); | |
184 }); | |
185 }); | |
186 </script> | |
120 </body> | 187 </body> |
121 </html> | 188 </html> |
122 """ | 189 """ |
123 | 190 |
124 | 191 |
192 def build_tabbed_html( | |
193 summary_html: str, | |
194 test_html: str, | |
195 feature_html: str, | |
196 explainer_html: Optional[str] = None, | |
197 ) -> str: | |
198 """ | |
199 Render the tabbed sections and an always-visible Help button. | |
200 """ | |
201 # CSS | |
202 css = get_html_template().split("<body>")[1].rsplit("</style>", 1)[0] + "</style>" | |
203 | |
204 # Tabs header | |
205 tabs = [ | |
206 '<div class="tabs">', | |
207 '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary & Config</div>', | |
208 '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>', | |
209 '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>', | |
210 ] | |
211 if explainer_html: | |
212 tabs.append( | |
213 '<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>' | |
214 ) | |
215 tabs.append('<button id="openMetricsHelp" class="help-btn">Help</button>') | |
216 tabs.append("</div>") | |
217 tabs_section = "\n".join(tabs) | |
218 | |
219 # Content | |
220 contents = [ | |
221 f'<div id="summary" class="tab-content active">{summary_html}</div>', | |
222 f'<div id="test" class="tab-content">{test_html}</div>', | |
223 f'<div id="feature" class="tab-content">{feature_html}</div>', | |
224 ] | |
225 if explainer_html: | |
226 contents.append( | |
227 f'<div id="explainer" class="tab-content">{explainer_html}</div>' | |
228 ) | |
229 content_section = "\n".join(contents) | |
230 | |
231 # JS | |
232 js = """ | |
233 <script> | |
234 function showTab(id) { | |
235 document.querySelectorAll('.tab-content').forEach(el=>el.classList.remove('active')); | |
236 document.querySelectorAll('.tab').forEach(el=>el.classList.remove('active')); | |
237 document.getElementById(id).classList.add('active'); | |
238 document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active'); | |
239 } | |
240 </script> | |
241 """ | |
242 | |
243 return css + "\n" + tabs_section + "\n" + content_section + "\n" + js | |
244 | |
245 | |
125 def customize_figure_layout(fig, margin_dict=None): | 246 def customize_figure_layout(fig, margin_dict=None): |
126 """ | |
127 Update the layout of a Plotly figure to reduce margins. | |
128 | |
129 Parameters: | |
130 fig (plotly.graph_objects.Figure): The Plotly figure to customize. | |
131 margin_dict (dict, optional): A dictionary specifying margin sizes. | |
132 Example: {'l': 10, 'r': 10, 't': 10, 'b': 10} | |
133 | |
134 Returns: | |
135 plotly.graph_objects.Figure: The updated Plotly figure. | |
136 """ | |
137 if margin_dict is None: | 247 if margin_dict is None: |
138 # Set default smaller margins | 248 margin_dict = {"l": 40, "r": 40, "t": 40, "b": 40} |
139 margin_dict = {'l': 40, 'r': 40, 't': 40, 'b': 40} | |
140 | |
141 fig.update_layout(margin=margin_dict) | 249 fig.update_layout(margin=margin_dict) |
142 return fig | 250 return fig |
143 | 251 |
144 | 252 |
145 def add_plot_to_html(fig, include_plotlyjs=True): | 253 def add_plot_to_html(fig, include_plotlyjs=True) -> str: |
146 custom_margin = {'l': 40, 'r': 40, 't': 60, 'b': 60} | 254 custom_margin = {"l": 40, "r": 40, "t": 60, "b": 60} |
147 fig = customize_figure_layout(fig, margin_dict=custom_margin) | 255 fig = customize_figure_layout(fig, margin_dict=custom_margin) |
148 return fig.to_html(full_html=False, | 256 return fig.to_html( |
149 default_height=350, | 257 full_html=False, |
150 include_plotlyjs="cdn" if include_plotlyjs else False) | 258 default_height=350, |
151 | 259 include_plotlyjs="cdn" if include_plotlyjs else False, |
152 | 260 ) |
153 def add_hr_to_html(): | 261 |
262 | |
263 def add_hr_to_html() -> str: | |
154 return "<hr>" | 264 return "<hr>" |
155 | 265 |
156 | 266 |
157 def encode_image_to_base64(image_path): | 267 def encode_image_to_base64(image_path: str) -> str: |
158 """Convert an image file to a base64 encoded string.""" | |
159 with open(image_path, "rb") as img_file: | 268 with open(image_path, "rb") as img_file: |
160 return base64.b64encode(img_file.read()).decode("utf-8") | 269 return base64.b64encode(img_file.read()).decode("utf-8") |
161 | 270 |
162 | 271 |
163 def predict_proba(self, X): | 272 def predict_proba(self, X): |
164 pred = self.predict(X) | 273 pred = self.predict(X) |
165 return np.array([1 - pred, pred]).T | 274 return np.vstack((1 - pred, pred)).T |