Mercurial > repos > goeckslab > image_learner
diff image_learner_cli.py @ 10:b0d893d04d4c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1594d503179f28987720594eb49b48a15486f073
author | goeckslab |
---|---|
date | Mon, 08 Sep 2025 22:38:35 +0000 |
parents | 9e912fce264c |
children |
line wrap: on
line diff
--- a/image_learner_cli.py Wed Aug 27 21:02:48 2025 +0000 +++ b/image_learner_cli.py Mon Sep 08 22:38:35 2025 +0000 @@ -69,7 +69,6 @@ ] rows = [] - for key in display_keys: val = config.get(key, None) if key == "threshold": @@ -136,15 +135,15 @@ val_str = val else: val_str = val if val is not None else "N/A" - if val_str == "N/A" and key not in [ - "task_type" - ]: # Skip if N/A for non-essential + if val_str == "N/A" and key not in ["task_type"]: continue rows.append( f"<tr>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " + f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" f"{key.replace('_', ' ').title()}</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " + f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" f"{val_str}</td>" f"</tr>" ) @@ -153,13 +152,17 @@ types = [str(a.get("type", "")) for a in aug_cfg] aug_val = ", ".join(types) rows.append( - f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" + f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " + f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " + f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>" ) if split_info: rows.append( - f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" - f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" + f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " + f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>" + f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " + f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>" ) html = f""" <h2 style="text-align: center;">Model and Training Summary</h2> @@ -946,6 +949,66 @@ test_viz_dir = base_viz_dir / "test" html = get_html_template() + + # Extra CSS & JS: center Plotly and enable CSV download for predictions table + html += """ +<style> + /* Center Plotly figures (both wrapper and native classes) */ + .plotly-center { display: flex; justify-content: center; } + .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } + .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } + + /* Download button for predictions table */ + .download-btn { + padding: 8px 12px; + border: 1px solid #4CAF50; + background: #4CAF50; + color: white; + border-radius: 6px; + cursor: pointer; + } + .download-btn:hover { filter: brightness(0.95); } + .preds-controls { + display: flex; + justify-content: flex-end; + gap: 8px; + margin: 8px 0; + } +</style> +<script> + function tableToCSV(table){ + const rows = Array.from(table.querySelectorAll('tr')); + return rows.map(row => + Array.from(row.querySelectorAll('th,td')).map(cell => { + let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim(); + if (text.includes('"') || text.includes(',')) { + text = '"' + text.replace(/"/g,'""') + '"'; + } + return text; + }).join(',') + ).join('\\n'); + } + document.addEventListener('DOMContentLoaded', function(){ + const btn = document.getElementById('downloadPredsCsv'); + if(btn){ + btn.addEventListener('click', function(){ + const tbl = document.querySelector('.predictions-table'); + if(!tbl){ alert('Predictions table not found.'); return; } + const csv = tableToCSV(tbl); + const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'}); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'ground_truth_vs_predictions.csv'; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + }); + } + }); +</script> +""" html += f"<h1>{title}</h1>" metrics_html = "" @@ -983,31 +1046,38 @@ except Exception as e: logger.warning(f"Could not load config for HTML report: {e}") + # ---------- image rendering with exclusions ---------- def render_img_section( - title: str, dir_path: Path, output_type: str = None + title: str, + dir_path: Path, + output_type: str = None, + exclude_names: Optional[set] = None, ) -> str: if not dir_path.exists(): return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" - # collect every PNG + + exclude_names = exclude_names or set() + imgs = list(dir_path.glob("*.png")) - # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- + + default_exclude = {"confusion_matrix.png", "roc_curves.png"} + imgs = [ img for img in imgs - if not ( - img.name == "confusion_matrix.png" - or img.name.startswith("confusion_matrix__label_top") - or img.name == "roc_curves.png" - ) + if img.name not in default_exclude + and img.name not in exclude_names + and not img.name.startswith("confusion_matrix__label_top") ] + if not imgs: return f"<h2>{title}</h2><p><em>No plots found.</em></p>" + if output_type == "binary": order = [ "roc_curves_from_prediction_statistics.png", "compare_performance_label.png", "confusion_matrix_entropy__label_top2.png", - # ...you can tweak ordering as needed ] img_names = {img.name: img for img in imgs} ordered = [img_names[n] for n in order if n in img_names] @@ -1019,14 +1089,13 @@ "compare_classifiers_multiclass_multimetric__label_top10.png", "compare_classifiers_multiclass_multimetric__label_worst10.png", } + valid_imgs = [img for img in imgs if img.name not in unwanted] display_order = [ "roc_curves.png", "compare_performance_label.png", "compare_classifiers_performance_from_prob.png", "confusion_matrix_entropy__label_top10.png", ] - # filter and order - valid_imgs = [img for img in imgs if img.name not in unwanted] img_map = {img.name: img for img in valid_imgs} ordered = [img_map[n] for n in display_order if n in img_map] others = sorted( @@ -1034,27 +1103,36 @@ ) imgs = ordered + others else: - # regression: just sort whatever's left imgs = sorted(imgs) - # render each remaining PNG - html = "" + + html_section = "" for img in imgs: b64 = encode_image_to_base64(str(img)) img_title = img.stem.replace("_", " ").title() - html += ( + html_section += ( f"<h2 style='text-align: center;'>{img_title}</h2>" f'<div class="plot" style="margin-bottom:20px;text-align:center;">' f'<img src="data:image/png;base64,{b64}" ' f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' f"</div>" ) - return html + return html_section tab1_content = config_html + metrics_html + tab2_content = train_val_metrics_html + render_img_section( - "Training and Validation Visualizations", train_viz_dir + "Training and Validation Visualizations", + train_viz_dir, + output_type, + exclude_names={ + "compare_classifiers_performance_from_prob.png", + "roc_curves_from_prediction_statistics.png", + "precision_recall_curves_from_prediction_statistics.png", + "precision_recall_curve.png", + }, ) - # --- Predictions vs Ground Truth table --- + + # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- preds_section = "" parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME if output_type == "regression" and parquet_path.exists(): @@ -1081,13 +1159,19 @@ preds_html = df_table.to_html(index=False, classes="predictions-table") preds_section = ( "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" - "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" + "<div class='preds-controls'>" + "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" + "</div>" + "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>" + preds_html + "</div>" ) except Exception as e: logger.warning(f"Could not build Predictions vs GT table: {e}") + tab3_content = test_metrics_html + preds_section + + # Classification-only interactive Plotly panels (centered) if output_type in ("binary", "category"): training_stats_path = exp_dir / "training_statistics.json" interactive_plots = build_classification_plots( @@ -1095,31 +1179,16 @@ str(training_stats_path), ) for plot in interactive_plots: - # 2) inject the static "roc_curves_from_prediction_statistics.png" - if plot["title"] == "ROC-AUC": - static_img = ( - test_viz_dir / "roc_curves_from_prediction_statistics.png" - ) - if static_img.exists(): - b64 = encode_image_to_base64(str(static_img)) - tab3_content += ( - "<h2 style='text-align: center;'>" - "Roc Curves From Prediction Statistics" - "</h2>" - f'<div class="plot" style="margin-bottom:20px;text-align:center;">' - f'<img src="data:image/png;base64,{b64}" ' - f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' - "</div>" - ) - # always render the plotly panels exactly as before tab3_content += ( f"<h2 style='text-align: center;'>{plot['title']}</h2>" - + plot["html"] + f"<div class='plotly-center'>{plot['html']}</div>" ) - tab3_content += render_img_section( - "Test Visualizations", test_viz_dir, output_type - ) - # assemble the tabs and help modal + + # Add static TEST PNGs (with default dedupe/exclusions) + tab3_content += render_img_section( + "Test Visualizations", test_viz_dir, output_type + ) + tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) modal_html = get_metrics_help_modal() html += tabbed_html + modal_html + get_html_closing()