Mercurial > repos > goeckslab > pycaret_predict
comparison utils.py @ 8:1aed7d47c5ec draft
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
| author | goeckslab |
|---|---|
| date | Fri, 25 Jul 2025 19:02:32 +0000 |
| parents | a32ff7201629 |
| children | e2a6fed32d54 |
comparison
equal
deleted
inserted
replaced
| 7:f4cb41f458fd | 8:1aed7d47c5ec |
|---|---|
| 1 import base64 | 1 import base64 |
| 2 import logging | 2 import logging |
| 3 from typing import Optional | |
| 3 | 4 |
| 4 import numpy as np | 5 import numpy as np |
| 5 | 6 |
| 6 logging.basicConfig(level=logging.DEBUG) | 7 logging.basicConfig(level=logging.DEBUG) |
| 7 LOG = logging.getLogger(__name__) | 8 LOG = logging.getLogger(__name__) |
| 8 | 9 |
| 9 | 10 |
| 10 def get_html_template(): | 11 def get_html_template() -> str: |
| 11 return """ | 12 return """ |
| 12 <html> | 13 <html> |
| 13 <head> | 14 <head> |
| 14 <meta charset="UTF-8"> | 15 <meta charset="UTF-8"> |
| 15 <title>Model Training Report</title> | 16 <title>Model Training Report</title> |
| 18 font-family: Arial, sans-serif; | 19 font-family: Arial, sans-serif; |
| 19 margin: 0; | 20 margin: 0; |
| 20 padding: 20px; | 21 padding: 20px; |
| 21 background-color: #f4f4f4; | 22 background-color: #f4f4f4; |
| 22 } | 23 } |
| 24 /* allow horizontal scrolling if content overflows */ | |
| 23 .container { | 25 .container { |
| 24 max-width: 800px; | 26 max-width: 800px; |
| 25 margin: auto; | 27 margin: auto; |
| 26 background: white; | 28 background: white; |
| 27 padding: 20px; | 29 padding: 20px; |
| 28 box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | 30 box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); |
| 29 } | 31 overflow-x: auto; |
| 32 } | |
| 33 | |
| 30 h1 { | 34 h1 { |
| 31 text-align: center; | 35 text-align: center; |
| 32 color: #333; | 36 color: #333; |
| 33 } | 37 } |
| 34 h2 { | 38 h2 { |
| 35 border-bottom: 2px solid #4CAF50; | 39 border-bottom: 2px solid #4CAF50; |
| 36 color: #4CAF50; | 40 color: #4CAF50; |
| 37 padding-bottom: 5px; | 41 padding-bottom: 5px; |
| 38 } | 42 } |
| 43 | |
| 44 /* wrapper for tables to allow individual horizontal scroll */ | |
| 45 .table-wrapper { | |
| 46 overflow-x: auto; | |
| 47 margin: 1rem 0; | |
| 48 } | |
| 49 | |
| 50 /* revert table styling to full borders */ | |
| 39 table { | 51 table { |
| 40 width: 100%; | 52 width: 100%; |
| 41 border-collapse: collapse; | 53 border-collapse: collapse; |
| 42 margin: 20px 0; | 54 margin: 20px 0; |
| 43 } | 55 } |
| 50 } | 62 } |
| 51 th { | 63 th { |
| 52 background-color: #4CAF50; | 64 background-color: #4CAF50; |
| 53 color: white; | 65 color: white; |
| 54 } | 66 } |
| 67 | |
| 55 .plot { | 68 .plot { |
| 56 text-align: center; | 69 text-align: center; |
| 57 margin: 20px 0; | 70 margin: 20px 0; |
| 58 } | 71 } |
| 59 .plot img { | 72 .plot img { |
| 60 max-width: 100%; | 73 max-width: 100%; |
| 61 height: auto; | 74 height: auto; |
| 62 } | 75 } |
| 76 | |
| 63 .tabs { | 77 .tabs { |
| 64 display: flex; | 78 display: flex; |
| 65 margin-bottom: 20px; | 79 align-items: center; |
| 80 border-bottom: 2px solid #ccc; | |
| 81 margin-bottom: 1rem; | |
| 82 } | |
| 83 .tab { | |
| 84 padding: 10px 20px; | |
| 66 cursor: pointer; | 85 cursor: pointer; |
| 67 justify-content: space-around; | 86 border: 1px solid #ccc; |
| 68 } | 87 border-bottom: none; |
| 69 .tab { | 88 background: #f9f9f9; |
| 70 padding: 10px; | 89 margin-right: 5px; |
| 71 background-color: #4CAF50; | 90 border-top-left-radius: 8px; |
| 72 color: white; | 91 border-top-right-radius: 8px; |
| 73 border-radius: 5px 5px 0 0; | 92 } |
| 74 flex-grow: 1; | 93 .tab.active { |
| 75 text-align: center; | 94 background: white; |
| 76 margin: 0 5px; | 95 font-weight: bold; |
| 77 } | 96 } |
| 78 .tab.active-tab { | 97 |
| 79 background-color: #333; | |
| 80 } | |
| 81 .tab-content { | 98 .tab-content { |
| 82 display: none; | 99 display: none; |
| 83 padding: 20px; | 100 padding: 20px; |
| 84 border: 1px solid #ddd; | 101 border: 1px solid #ccc; |
| 85 border-top: none; | 102 border-top: none; |
| 86 background-color: white; | 103 background: white; |
| 87 } | 104 } |
| 88 .tab-content.active-content { | 105 .tab-content.active { |
| 89 display: block; | 106 display: block; |
| 90 } | 107 } |
| 91 </style> | 108 |
| 109 .help-btn { | |
| 110 margin-left: auto; | |
| 111 padding: 6px 12px; | |
| 112 font-size: 0.9rem; | |
| 113 border: 1px solid #4CAF50; | |
| 114 border-radius: 4px; | |
| 115 background: #4CAF50; | |
| 116 color: white; | |
| 117 cursor: pointer; | |
| 118 } | |
| 119 | |
| 120 /* sortable table header arrows */ | |
| 121 table.sortable th { | |
| 122 position: relative; | |
| 123 padding-right: 20px; /* room for the arrow */ | |
| 124 cursor: pointer; | |
| 125 } | |
| 126 table.sortable th::after { | |
| 127 content: '↕'; | |
| 128 position: absolute; | |
| 129 right: 8px; | |
| 130 opacity: 0.4; | |
| 131 transition: opacity 0.2s; | |
| 132 } | |
| 133 table.sortable th:hover::after { | |
| 134 opacity: 0.7; | |
| 135 } | |
| 136 table.sortable th.sorted-asc::after { | |
| 137 content: '↑'; | |
| 138 opacity: 1; | |
| 139 } | |
| 140 table.sortable th.sorted-desc::after { | |
| 141 content: '↓'; | |
| 142 opacity: 1; | |
| 143 } | |
| 144 </style> | |
| 92 </head> | 145 </head> |
| 93 <body> | 146 <body> |
| 94 <div class="container"> | 147 <div class="container"> |
| 95 """ | 148 """ |
| 96 | 149 |
| 97 | 150 |
| 98 def get_html_closing(): | 151 def get_html_closing() -> str: |
| 99 return """ | 152 return """ |
| 100 </div> | 153 </div> |
| 101 <script> | 154 <script> |
| 102 function openTab(evt, tabName) {{ | 155 document.addEventListener('DOMContentLoaded', () => { |
| 103 var i, tabcontent, tablinks; | 156 document.querySelectorAll('table.sortable').forEach(table => { |
| 104 tabcontent = document.getElementsByClassName("tab-content"); | 157 const getCellValue = (row, idx) => |
| 105 for (i = 0; i < tabcontent.length; i++) {{ | 158 row.children[idx].innerText.trim() || ''; |
| 106 tabcontent[i].style.display = "none"; | 159 |
| 107 }} | 160 const comparer = (idx, asc) => (a, b) => { |
| 108 tablinks = document.getElementsByClassName("tab"); | 161 const v1 = getCellValue(asc ? a : b, idx); |
| 109 for (i = 0; i < tablinks.length; i++) {{ | 162 const v2 = getCellValue(asc ? b : a, idx); |
| 110 tablinks[i].className = | 163 const n1 = parseFloat(v1), n2 = parseFloat(v2); |
| 111 tablinks[i].className.replace(" active-tab", ""); | 164 if (!isNaN(n1) && !isNaN(n2)) return n1 - n2; |
| 112 }} | 165 return v1.localeCompare(v2); |
| 113 document.getElementById(tabName).style.display = "block"; | 166 }; |
| 114 evt.currentTarget.className += " active-tab"; | 167 |
| 115 }} | 168 table.querySelectorAll('th').forEach((th, idx) => { |
| 116 document.addEventListener("DOMContentLoaded", function() {{ | 169 let asc = true; |
| 117 document.querySelector(".tab").click(); | 170 th.addEventListener('click', () => { |
| 118 }}); | 171 // sort rows |
| 119 </script> | 172 const tbody = table.tBodies[0]; |
| 173 Array.from(tbody.rows) | |
| 174 .sort(comparer(idx, asc)) | |
| 175 .forEach(row => tbody.appendChild(row)); | |
| 176 // update arrow classes | |
| 177 table.querySelectorAll('th').forEach(h => { | |
| 178 h.classList.remove('sorted-asc','sorted-desc'); | |
| 179 }); | |
| 180 th.classList.add(asc ? 'sorted-asc' : 'sorted-desc'); | |
| 181 asc = !asc; | |
| 182 }); | |
| 183 }); | |
| 184 }); | |
| 185 }); | |
| 186 </script> | |
| 120 </body> | 187 </body> |
| 121 </html> | 188 </html> |
| 122 """ | 189 """ |
| 123 | 190 |
| 124 | 191 |
| 192 def build_tabbed_html( | |
| 193 summary_html: str, | |
| 194 test_html: str, | |
| 195 feature_html: str, | |
| 196 explainer_html: Optional[str] = None, | |
| 197 ) -> str: | |
| 198 """ | |
| 199 Render the tabbed sections and an always-visible Help button. | |
| 200 """ | |
| 201 # CSS | |
| 202 css = get_html_template().split("<body>")[1].rsplit("</style>", 1)[0] + "</style>" | |
| 203 | |
| 204 # Tabs header | |
| 205 tabs = [ | |
| 206 '<div class="tabs">', | |
| 207 '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary & Config</div>', | |
| 208 '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>', | |
| 209 '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>', | |
| 210 ] | |
| 211 if explainer_html: | |
| 212 tabs.append( | |
| 213 '<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>' | |
| 214 ) | |
| 215 tabs.append('<button id="openMetricsHelp" class="help-btn">Help</button>') | |
| 216 tabs.append("</div>") | |
| 217 tabs_section = "\n".join(tabs) | |
| 218 | |
| 219 # Content | |
| 220 contents = [ | |
| 221 f'<div id="summary" class="tab-content active">{summary_html}</div>', | |
| 222 f'<div id="test" class="tab-content">{test_html}</div>', | |
| 223 f'<div id="feature" class="tab-content">{feature_html}</div>', | |
| 224 ] | |
| 225 if explainer_html: | |
| 226 contents.append( | |
| 227 f'<div id="explainer" class="tab-content">{explainer_html}</div>' | |
| 228 ) | |
| 229 content_section = "\n".join(contents) | |
| 230 | |
| 231 # JS | |
| 232 js = """ | |
| 233 <script> | |
| 234 function showTab(id) { | |
| 235 document.querySelectorAll('.tab-content').forEach(el=>el.classList.remove('active')); | |
| 236 document.querySelectorAll('.tab').forEach(el=>el.classList.remove('active')); | |
| 237 document.getElementById(id).classList.add('active'); | |
| 238 document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active'); | |
| 239 } | |
| 240 </script> | |
| 241 """ | |
| 242 | |
| 243 return css + "\n" + tabs_section + "\n" + content_section + "\n" + js | |
| 244 | |
| 245 | |
| 125 def customize_figure_layout(fig, margin_dict=None): | 246 def customize_figure_layout(fig, margin_dict=None): |
| 126 """ | |
| 127 Update the layout of a Plotly figure to reduce margins. | |
| 128 | |
| 129 Parameters: | |
| 130 fig (plotly.graph_objects.Figure): The Plotly figure to customize. | |
| 131 margin_dict (dict, optional): A dictionary specifying margin sizes. | |
| 132 Example: {'l': 10, 'r': 10, 't': 10, 'b': 10} | |
| 133 | |
| 134 Returns: | |
| 135 plotly.graph_objects.Figure: The updated Plotly figure. | |
| 136 """ | |
| 137 if margin_dict is None: | 247 if margin_dict is None: |
| 138 # Set default smaller margins | 248 margin_dict = {"l": 40, "r": 40, "t": 40, "b": 40} |
| 139 margin_dict = {'l': 40, 'r': 40, 't': 40, 'b': 40} | |
| 140 | |
| 141 fig.update_layout(margin=margin_dict) | 249 fig.update_layout(margin=margin_dict) |
| 142 return fig | 250 return fig |
| 143 | 251 |
| 144 | 252 |
| 145 def add_plot_to_html(fig, include_plotlyjs=True): | 253 def add_plot_to_html(fig, include_plotlyjs=True) -> str: |
| 146 custom_margin = {'l': 40, 'r': 40, 't': 60, 'b': 60} | 254 custom_margin = {"l": 40, "r": 40, "t": 60, "b": 60} |
| 147 fig = customize_figure_layout(fig, margin_dict=custom_margin) | 255 fig = customize_figure_layout(fig, margin_dict=custom_margin) |
| 148 return fig.to_html(full_html=False, | 256 return fig.to_html( |
| 149 default_height=350, | 257 full_html=False, |
| 150 include_plotlyjs="cdn" if include_plotlyjs else False) | 258 default_height=350, |
| 151 | 259 include_plotlyjs="cdn" if include_plotlyjs else False, |
| 152 | 260 ) |
| 153 def add_hr_to_html(): | 261 |
| 262 | |
| 263 def add_hr_to_html() -> str: | |
| 154 return "<hr>" | 264 return "<hr>" |
| 155 | 265 |
| 156 | 266 |
| 157 def encode_image_to_base64(image_path): | 267 def encode_image_to_base64(image_path: str) -> str: |
| 158 """Convert an image file to a base64 encoded string.""" | |
| 159 with open(image_path, "rb") as img_file: | 268 with open(image_path, "rb") as img_file: |
| 160 return base64.b64encode(img_file.read()).decode("utf-8") | 269 return base64.b64encode(img_file.read()).decode("utf-8") |
| 161 | 270 |
| 162 | 271 |
| 163 def predict_proba(self, X): | 272 def predict_proba(self, X): |
| 164 pred = self.predict(X) | 273 pred = self.predict(X) |
| 165 return np.array([1 - pred, pred]).T | 274 return np.vstack((1 - pred, pred)).T |
