Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 10:b0d893d04d4c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1594d503179f28987720594eb49b48a15486f073
author | goeckslab |
---|---|
date | Mon, 08 Sep 2025 22:38:35 +0000 |
parents | 9e912fce264c |
children |
comparison
equal
deleted
inserted
replaced
9:9e912fce264c | 10:b0d893d04d4c |
---|---|
67 "early_stop", | 67 "early_stop", |
68 "threshold", | 68 "threshold", |
69 ] | 69 ] |
70 | 70 |
71 rows = [] | 71 rows = [] |
72 | |
73 for key in display_keys: | 72 for key in display_keys: |
74 val = config.get(key, None) | 73 val = config.get(key, None) |
75 if key == "threshold": | 74 if key == "threshold": |
76 if output_type != "binary": | 75 if output_type != "binary": |
77 continue | 76 continue |
134 ) | 133 ) |
135 else: | 134 else: |
136 val_str = val | 135 val_str = val |
137 else: | 136 else: |
138 val_str = val if val is not None else "N/A" | 137 val_str = val if val is not None else "N/A" |
139 if val_str == "N/A" and key not in [ | 138 if val_str == "N/A" and key not in ["task_type"]: |
140 "task_type" | |
141 ]: # Skip if N/A for non-essential | |
142 continue | 139 continue |
143 rows.append( | 140 rows.append( |
144 f"<tr>" | 141 f"<tr>" |
145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 142 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " |
143 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" | |
146 f"{key.replace('_', ' ').title()}</td>" | 144 f"{key.replace('_', ' ').title()}</td>" |
147 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " |
146 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" | |
148 f"{val_str}</td>" | 147 f"{val_str}</td>" |
149 f"</tr>" | 148 f"</tr>" |
150 ) | 149 ) |
151 aug_cfg = config.get("augmentation") | 150 aug_cfg = config.get("augmentation") |
152 if aug_cfg: | 151 if aug_cfg: |
153 types = [str(a.get("type", "")) for a in aug_cfg] | 152 types = [str(a.get("type", "")) for a in aug_cfg] |
154 aug_val = ", ".join(types) | 153 aug_val = ", ".join(types) |
155 rows.append( | 154 rows.append( |
156 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" | 155 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " |
157 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" | 156 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>" |
157 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
158 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>" | |
158 ) | 159 ) |
159 if split_info: | 160 if split_info: |
160 rows.append( | 161 rows.append( |
161 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" | 162 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " |
162 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" | 163 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>" |
164 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
165 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>" | |
163 ) | 166 ) |
164 html = f""" | 167 html = f""" |
165 <h2 style="text-align: center;">Model and Training Summary</h2> | 168 <h2 style="text-align: center;">Model and Training Summary</h2> |
166 <div style="display: flex; justify-content: center;"> | 169 <div style="display: flex; justify-content: center;"> |
167 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> | 170 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> |
944 base_viz_dir = exp_dir / "visualizations" | 947 base_viz_dir = exp_dir / "visualizations" |
945 train_viz_dir = base_viz_dir / "train" | 948 train_viz_dir = base_viz_dir / "train" |
946 test_viz_dir = base_viz_dir / "test" | 949 test_viz_dir = base_viz_dir / "test" |
947 | 950 |
948 html = get_html_template() | 951 html = get_html_template() |
952 | |
953 # Extra CSS & JS: center Plotly and enable CSV download for predictions table | |
954 html += """ | |
955 <style> | |
956 /* Center Plotly figures (both wrapper and native classes) */ | |
957 .plotly-center { display: flex; justify-content: center; } | |
958 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } | |
959 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } | |
960 | |
961 /* Download button for predictions table */ | |
962 .download-btn { | |
963 padding: 8px 12px; | |
964 border: 1px solid #4CAF50; | |
965 background: #4CAF50; | |
966 color: white; | |
967 border-radius: 6px; | |
968 cursor: pointer; | |
969 } | |
970 .download-btn:hover { filter: brightness(0.95); } | |
971 .preds-controls { | |
972 display: flex; | |
973 justify-content: flex-end; | |
974 gap: 8px; | |
975 margin: 8px 0; | |
976 } | |
977 </style> | |
978 <script> | |
979 function tableToCSV(table){ | |
980 const rows = Array.from(table.querySelectorAll('tr')); | |
981 return rows.map(row => | |
982 Array.from(row.querySelectorAll('th,td')).map(cell => { | |
983 let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim(); | |
984 if (text.includes('"') || text.includes(',')) { | |
985 text = '"' + text.replace(/"/g,'""') + '"'; | |
986 } | |
987 return text; | |
988 }).join(',') | |
989 ).join('\\n'); | |
990 } | |
991 document.addEventListener('DOMContentLoaded', function(){ | |
992 const btn = document.getElementById('downloadPredsCsv'); | |
993 if(btn){ | |
994 btn.addEventListener('click', function(){ | |
995 const tbl = document.querySelector('.predictions-table'); | |
996 if(!tbl){ alert('Predictions table not found.'); return; } | |
997 const csv = tableToCSV(tbl); | |
998 const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'}); | |
999 const url = URL.createObjectURL(blob); | |
1000 const a = document.createElement('a'); | |
1001 a.href = url; | |
1002 a.download = 'ground_truth_vs_predictions.csv'; | |
1003 document.body.appendChild(a); | |
1004 a.click(); | |
1005 document.body.removeChild(a); | |
1006 URL.revokeObjectURL(url); | |
1007 }); | |
1008 } | |
1009 }); | |
1010 </script> | |
1011 """ | |
949 html += f"<h1>{title}</h1>" | 1012 html += f"<h1>{title}</h1>" |
950 | 1013 |
951 metrics_html = "" | 1014 metrics_html = "" |
952 train_val_metrics_html = "" | 1015 train_val_metrics_html = "" |
953 test_metrics_html = "" | 1016 test_metrics_html = "" |
981 config, split_info, training_progress, output_type | 1044 config, split_info, training_progress, output_type |
982 ) | 1045 ) |
983 except Exception as e: | 1046 except Exception as e: |
984 logger.warning(f"Could not load config for HTML report: {e}") | 1047 logger.warning(f"Could not load config for HTML report: {e}") |
985 | 1048 |
1049 # ---------- image rendering with exclusions ---------- | |
986 def render_img_section( | 1050 def render_img_section( |
987 title: str, dir_path: Path, output_type: str = None | 1051 title: str, |
1052 dir_path: Path, | |
1053 output_type: str = None, | |
1054 exclude_names: Optional[set] = None, | |
988 ) -> str: | 1055 ) -> str: |
989 if not dir_path.exists(): | 1056 if not dir_path.exists(): |
990 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 1057 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" |
991 # collect every PNG | 1058 |
1059 exclude_names = exclude_names or set() | |
1060 | |
992 imgs = list(dir_path.glob("*.png")) | 1061 imgs = list(dir_path.glob("*.png")) |
993 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- | 1062 |
1063 default_exclude = {"confusion_matrix.png", "roc_curves.png"} | |
1064 | |
994 imgs = [ | 1065 imgs = [ |
995 img | 1066 img |
996 for img in imgs | 1067 for img in imgs |
997 if not ( | 1068 if img.name not in default_exclude |
998 img.name == "confusion_matrix.png" | 1069 and img.name not in exclude_names |
999 or img.name.startswith("confusion_matrix__label_top") | 1070 and not img.name.startswith("confusion_matrix__label_top") |
1000 or img.name == "roc_curves.png" | |
1001 ) | |
1002 ] | 1071 ] |
1072 | |
1003 if not imgs: | 1073 if not imgs: |
1004 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 1074 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" |
1075 | |
1005 if output_type == "binary": | 1076 if output_type == "binary": |
1006 order = [ | 1077 order = [ |
1007 "roc_curves_from_prediction_statistics.png", | 1078 "roc_curves_from_prediction_statistics.png", |
1008 "compare_performance_label.png", | 1079 "compare_performance_label.png", |
1009 "confusion_matrix_entropy__label_top2.png", | 1080 "confusion_matrix_entropy__label_top2.png", |
1010 # ...you can tweak ordering as needed | |
1011 ] | 1081 ] |
1012 img_names = {img.name: img for img in imgs} | 1082 img_names = {img.name: img for img in imgs} |
1013 ordered = [img_names[n] for n in order if n in img_names] | 1083 ordered = [img_names[n] for n in order if n in img_names] |
1014 others = sorted(img for img in imgs if img.name not in order) | 1084 others = sorted(img for img in imgs if img.name not in order) |
1015 imgs = ordered + others | 1085 imgs = ordered + others |
1017 unwanted = { | 1087 unwanted = { |
1018 "compare_classifiers_multiclass_multimetric__label_best10.png", | 1088 "compare_classifiers_multiclass_multimetric__label_best10.png", |
1019 "compare_classifiers_multiclass_multimetric__label_top10.png", | 1089 "compare_classifiers_multiclass_multimetric__label_top10.png", |
1020 "compare_classifiers_multiclass_multimetric__label_worst10.png", | 1090 "compare_classifiers_multiclass_multimetric__label_worst10.png", |
1021 } | 1091 } |
1092 valid_imgs = [img for img in imgs if img.name not in unwanted] | |
1022 display_order = [ | 1093 display_order = [ |
1023 "roc_curves.png", | 1094 "roc_curves.png", |
1024 "compare_performance_label.png", | 1095 "compare_performance_label.png", |
1025 "compare_classifiers_performance_from_prob.png", | 1096 "compare_classifiers_performance_from_prob.png", |
1026 "confusion_matrix_entropy__label_top10.png", | 1097 "confusion_matrix_entropy__label_top10.png", |
1027 ] | 1098 ] |
1028 # filter and order | |
1029 valid_imgs = [img for img in imgs if img.name not in unwanted] | |
1030 img_map = {img.name: img for img in valid_imgs} | 1099 img_map = {img.name: img for img in valid_imgs} |
1031 ordered = [img_map[n] for n in display_order if n in img_map] | 1100 ordered = [img_map[n] for n in display_order if n in img_map] |
1032 others = sorted( | 1101 others = sorted( |
1033 img for img in valid_imgs if img.name not in display_order | 1102 img for img in valid_imgs if img.name not in display_order |
1034 ) | 1103 ) |
1035 imgs = ordered + others | 1104 imgs = ordered + others |
1036 else: | 1105 else: |
1037 # regression: just sort whatever's left | |
1038 imgs = sorted(imgs) | 1106 imgs = sorted(imgs) |
1039 # render each remaining PNG | 1107 |
1040 html = "" | 1108 html_section = "" |
1041 for img in imgs: | 1109 for img in imgs: |
1042 b64 = encode_image_to_base64(str(img)) | 1110 b64 = encode_image_to_base64(str(img)) |
1043 img_title = img.stem.replace("_", " ").title() | 1111 img_title = img.stem.replace("_", " ").title() |
1044 html += ( | 1112 html_section += ( |
1045 f"<h2 style='text-align: center;'>{img_title}</h2>" | 1113 f"<h2 style='text-align: center;'>{img_title}</h2>" |
1046 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | 1114 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' |
1047 f'<img src="data:image/png;base64,{b64}" ' | 1115 f'<img src="data:image/png;base64,{b64}" ' |
1048 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 1116 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' |
1049 f"</div>" | 1117 f"</div>" |
1050 ) | 1118 ) |
1051 return html | 1119 return html_section |
1052 | 1120 |
1053 tab1_content = config_html + metrics_html | 1121 tab1_content = config_html + metrics_html |
1122 | |
1054 tab2_content = train_val_metrics_html + render_img_section( | 1123 tab2_content = train_val_metrics_html + render_img_section( |
1055 "Training and Validation Visualizations", train_viz_dir | 1124 "Training and Validation Visualizations", |
1056 ) | 1125 train_viz_dir, |
1057 # --- Predictions vs Ground Truth table --- | 1126 output_type, |
1127 exclude_names={ | |
1128 "compare_classifiers_performance_from_prob.png", | |
1129 "roc_curves_from_prediction_statistics.png", | |
1130 "precision_recall_curves_from_prediction_statistics.png", | |
1131 "precision_recall_curve.png", | |
1132 }, | |
1133 ) | |
1134 | |
1135 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- | |
1058 preds_section = "" | 1136 preds_section = "" |
1059 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | 1137 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
1060 if output_type == "regression" and parquet_path.exists(): | 1138 if output_type == "regression" and parquet_path.exists(): |
1061 try: | 1139 try: |
1062 # 1) load predictions from Parquet | 1140 # 1) load predictions from Parquet |
1079 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | 1157 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] |
1080 # 4) render as HTML | 1158 # 4) render as HTML |
1081 preds_html = df_table.to_html(index=False, classes="predictions-table") | 1159 preds_html = df_table.to_html(index=False, classes="predictions-table") |
1082 preds_section = ( | 1160 preds_section = ( |
1083 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" | 1161 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" |
1084 "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" | 1162 "<div class='preds-controls'>" |
1163 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" | |
1164 "</div>" | |
1165 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>" | |
1085 + preds_html | 1166 + preds_html |
1086 + "</div>" | 1167 + "</div>" |
1087 ) | 1168 ) |
1088 except Exception as e: | 1169 except Exception as e: |
1089 logger.warning(f"Could not build Predictions vs GT table: {e}") | 1170 logger.warning(f"Could not build Predictions vs GT table: {e}") |
1171 | |
1090 tab3_content = test_metrics_html + preds_section | 1172 tab3_content = test_metrics_html + preds_section |
1173 | |
1174 # Classification-only interactive Plotly panels (centered) | |
1091 if output_type in ("binary", "category"): | 1175 if output_type in ("binary", "category"): |
1092 training_stats_path = exp_dir / "training_statistics.json" | 1176 training_stats_path = exp_dir / "training_statistics.json" |
1093 interactive_plots = build_classification_plots( | 1177 interactive_plots = build_classification_plots( |
1094 str(test_stats_path), | 1178 str(test_stats_path), |
1095 str(training_stats_path), | 1179 str(training_stats_path), |
1096 ) | 1180 ) |
1097 for plot in interactive_plots: | 1181 for plot in interactive_plots: |
1098 # 2) inject the static "roc_curves_from_prediction_statistics.png" | |
1099 if plot["title"] == "ROC-AUC": | |
1100 static_img = ( | |
1101 test_viz_dir / "roc_curves_from_prediction_statistics.png" | |
1102 ) | |
1103 if static_img.exists(): | |
1104 b64 = encode_image_to_base64(str(static_img)) | |
1105 tab3_content += ( | |
1106 "<h2 style='text-align: center;'>" | |
1107 "Roc Curves From Prediction Statistics" | |
1108 "</h2>" | |
1109 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | |
1110 f'<img src="data:image/png;base64,{b64}" ' | |
1111 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | |
1112 "</div>" | |
1113 ) | |
1114 # always render the plotly panels exactly as before | |
1115 tab3_content += ( | 1182 tab3_content += ( |
1116 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1183 f"<h2 style='text-align: center;'>{plot['title']}</h2>" |
1117 + plot["html"] | 1184 f"<div class='plotly-center'>{plot['html']}</div>" |
1118 ) | 1185 ) |
1119 tab3_content += render_img_section( | 1186 |
1120 "Test Visualizations", test_viz_dir, output_type | 1187 # Add static TEST PNGs (with default dedupe/exclusions) |
1121 ) | 1188 tab3_content += render_img_section( |
1122 # assemble the tabs and help modal | 1189 "Test Visualizations", test_viz_dir, output_type |
1190 ) | |
1191 | |
1123 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1192 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
1124 modal_html = get_metrics_help_modal() | 1193 modal_html = get_metrics_help_modal() |
1125 html += tabbed_html + modal_html + get_html_closing() | 1194 html += tabbed_html + modal_html + get_html_closing() |
1126 | 1195 |
1127 try: | 1196 try: |