Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 10:b0d893d04d4c draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1594d503179f28987720594eb49b48a15486f073
| author | goeckslab | 
|---|---|
| date | Mon, 08 Sep 2025 22:38:35 +0000 | 
| parents | 9e912fce264c | 
| children | c5150cceab47 | 
   comparison
  equal
  deleted
  inserted
  replaced
| 9:9e912fce264c | 10:b0d893d04d4c | 
|---|---|
| 67 "early_stop", | 67 "early_stop", | 
| 68 "threshold", | 68 "threshold", | 
| 69 ] | 69 ] | 
| 70 | 70 | 
| 71 rows = [] | 71 rows = [] | 
| 72 | |
| 73 for key in display_keys: | 72 for key in display_keys: | 
| 74 val = config.get(key, None) | 73 val = config.get(key, None) | 
| 75 if key == "threshold": | 74 if key == "threshold": | 
| 76 if output_type != "binary": | 75 if output_type != "binary": | 
| 77 continue | 76 continue | 
| 134 ) | 133 ) | 
| 135 else: | 134 else: | 
| 136 val_str = val | 135 val_str = val | 
| 137 else: | 136 else: | 
| 138 val_str = val if val is not None else "N/A" | 137 val_str = val if val is not None else "N/A" | 
| 139 if val_str == "N/A" and key not in [ | 138 if val_str == "N/A" and key not in ["task_type"]: | 
| 140 "task_type" | |
| 141 ]: # Skip if N/A for non-essential | |
| 142 continue | 139 continue | 
| 143 rows.append( | 140 rows.append( | 
| 144 f"<tr>" | 141 f"<tr>" | 
| 145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 142 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " | 
| 143 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" | |
| 146 f"{key.replace('_', ' ').title()}</td>" | 144 f"{key.replace('_', ' ').title()}</td>" | 
| 147 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | 
| 146 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>" | |
| 148 f"{val_str}</td>" | 147 f"{val_str}</td>" | 
| 149 f"</tr>" | 148 f"</tr>" | 
| 150 ) | 149 ) | 
| 151 aug_cfg = config.get("augmentation") | 150 aug_cfg = config.get("augmentation") | 
| 152 if aug_cfg: | 151 if aug_cfg: | 
| 153 types = [str(a.get("type", "")) for a in aug_cfg] | 152 types = [str(a.get("type", "")) for a in aug_cfg] | 
| 154 aug_val = ", ".join(types) | 153 aug_val = ", ".join(types) | 
| 155 rows.append( | 154 rows.append( | 
| 156 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" | 155 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " | 
| 157 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" | 156 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>" | 
| 157 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
| 158 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>" | |
| 158 ) | 159 ) | 
| 159 if split_info: | 160 if split_info: | 
| 160 rows.append( | 161 rows.append( | 
| 161 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" | 162 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; " | 
| 162 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" | 163 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>" | 
| 164 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; " | |
| 165 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>" | |
| 163 ) | 166 ) | 
| 164 html = f""" | 167 html = f""" | 
| 165 <h2 style="text-align: center;">Model and Training Summary</h2> | 168 <h2 style="text-align: center;">Model and Training Summary</h2> | 
| 166 <div style="display: flex; justify-content: center;"> | 169 <div style="display: flex; justify-content: center;"> | 
| 167 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> | 170 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> | 
| 944 base_viz_dir = exp_dir / "visualizations" | 947 base_viz_dir = exp_dir / "visualizations" | 
| 945 train_viz_dir = base_viz_dir / "train" | 948 train_viz_dir = base_viz_dir / "train" | 
| 946 test_viz_dir = base_viz_dir / "test" | 949 test_viz_dir = base_viz_dir / "test" | 
| 947 | 950 | 
| 948 html = get_html_template() | 951 html = get_html_template() | 
| 952 | |
| 953 # Extra CSS & JS: center Plotly and enable CSV download for predictions table | |
| 954 html += """ | |
| 955 <style> | |
| 956 /* Center Plotly figures (both wrapper and native classes) */ | |
| 957 .plotly-center { display: flex; justify-content: center; } | |
| 958 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } | |
| 959 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } | |
| 960 | |
| 961 /* Download button for predictions table */ | |
| 962 .download-btn { | |
| 963 padding: 8px 12px; | |
| 964 border: 1px solid #4CAF50; | |
| 965 background: #4CAF50; | |
| 966 color: white; | |
| 967 border-radius: 6px; | |
| 968 cursor: pointer; | |
| 969 } | |
| 970 .download-btn:hover { filter: brightness(0.95); } | |
| 971 .preds-controls { | |
| 972 display: flex; | |
| 973 justify-content: flex-end; | |
| 974 gap: 8px; | |
| 975 margin: 8px 0; | |
| 976 } | |
| 977 </style> | |
| 978 <script> | |
| 979 function tableToCSV(table){ | |
| 980 const rows = Array.from(table.querySelectorAll('tr')); | |
| 981 return rows.map(row => | |
| 982 Array.from(row.querySelectorAll('th,td')).map(cell => { | |
| 983 let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim(); | |
| 984 if (text.includes('"') || text.includes(',')) { | |
| 985 text = '"' + text.replace(/"/g,'""') + '"'; | |
| 986 } | |
| 987 return text; | |
| 988 }).join(',') | |
| 989 ).join('\\n'); | |
| 990 } | |
| 991 document.addEventListener('DOMContentLoaded', function(){ | |
| 992 const btn = document.getElementById('downloadPredsCsv'); | |
| 993 if(btn){ | |
| 994 btn.addEventListener('click', function(){ | |
| 995 const tbl = document.querySelector('.predictions-table'); | |
| 996 if(!tbl){ alert('Predictions table not found.'); return; } | |
| 997 const csv = tableToCSV(tbl); | |
| 998 const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'}); | |
| 999 const url = URL.createObjectURL(blob); | |
| 1000 const a = document.createElement('a'); | |
| 1001 a.href = url; | |
| 1002 a.download = 'ground_truth_vs_predictions.csv'; | |
| 1003 document.body.appendChild(a); | |
| 1004 a.click(); | |
| 1005 document.body.removeChild(a); | |
| 1006 URL.revokeObjectURL(url); | |
| 1007 }); | |
| 1008 } | |
| 1009 }); | |
| 1010 </script> | |
| 1011 """ | |
| 949 html += f"<h1>{title}</h1>" | 1012 html += f"<h1>{title}</h1>" | 
| 950 | 1013 | 
| 951 metrics_html = "" | 1014 metrics_html = "" | 
| 952 train_val_metrics_html = "" | 1015 train_val_metrics_html = "" | 
| 953 test_metrics_html = "" | 1016 test_metrics_html = "" | 
| 981 config, split_info, training_progress, output_type | 1044 config, split_info, training_progress, output_type | 
| 982 ) | 1045 ) | 
| 983 except Exception as e: | 1046 except Exception as e: | 
| 984 logger.warning(f"Could not load config for HTML report: {e}") | 1047 logger.warning(f"Could not load config for HTML report: {e}") | 
| 985 | 1048 | 
| 1049 # ---------- image rendering with exclusions ---------- | |
| 986 def render_img_section( | 1050 def render_img_section( | 
| 987 title: str, dir_path: Path, output_type: str = None | 1051 title: str, | 
| 1052 dir_path: Path, | |
| 1053 output_type: str = None, | |
| 1054 exclude_names: Optional[set] = None, | |
| 988 ) -> str: | 1055 ) -> str: | 
| 989 if not dir_path.exists(): | 1056 if not dir_path.exists(): | 
| 990 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 1057 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" | 
| 991 # collect every PNG | 1058 | 
| 1059 exclude_names = exclude_names or set() | |
| 1060 | |
| 992 imgs = list(dir_path.glob("*.png")) | 1061 imgs = list(dir_path.glob("*.png")) | 
| 993 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- | 1062 | 
| 1063 default_exclude = {"confusion_matrix.png", "roc_curves.png"} | |
| 1064 | |
| 994 imgs = [ | 1065 imgs = [ | 
| 995 img | 1066 img | 
| 996 for img in imgs | 1067 for img in imgs | 
| 997 if not ( | 1068 if img.name not in default_exclude | 
| 998 img.name == "confusion_matrix.png" | 1069 and img.name not in exclude_names | 
| 999 or img.name.startswith("confusion_matrix__label_top") | 1070 and not img.name.startswith("confusion_matrix__label_top") | 
| 1000 or img.name == "roc_curves.png" | |
| 1001 ) | |
| 1002 ] | 1071 ] | 
| 1072 | |
| 1003 if not imgs: | 1073 if not imgs: | 
| 1004 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 1074 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" | 
| 1075 | |
| 1005 if output_type == "binary": | 1076 if output_type == "binary": | 
| 1006 order = [ | 1077 order = [ | 
| 1007 "roc_curves_from_prediction_statistics.png", | 1078 "roc_curves_from_prediction_statistics.png", | 
| 1008 "compare_performance_label.png", | 1079 "compare_performance_label.png", | 
| 1009 "confusion_matrix_entropy__label_top2.png", | 1080 "confusion_matrix_entropy__label_top2.png", | 
| 1010 # ...you can tweak ordering as needed | |
| 1011 ] | 1081 ] | 
| 1012 img_names = {img.name: img for img in imgs} | 1082 img_names = {img.name: img for img in imgs} | 
| 1013 ordered = [img_names[n] for n in order if n in img_names] | 1083 ordered = [img_names[n] for n in order if n in img_names] | 
| 1014 others = sorted(img for img in imgs if img.name not in order) | 1084 others = sorted(img for img in imgs if img.name not in order) | 
| 1015 imgs = ordered + others | 1085 imgs = ordered + others | 
| 1017 unwanted = { | 1087 unwanted = { | 
| 1018 "compare_classifiers_multiclass_multimetric__label_best10.png", | 1088 "compare_classifiers_multiclass_multimetric__label_best10.png", | 
| 1019 "compare_classifiers_multiclass_multimetric__label_top10.png", | 1089 "compare_classifiers_multiclass_multimetric__label_top10.png", | 
| 1020 "compare_classifiers_multiclass_multimetric__label_worst10.png", | 1090 "compare_classifiers_multiclass_multimetric__label_worst10.png", | 
| 1021 } | 1091 } | 
| 1092 valid_imgs = [img for img in imgs if img.name not in unwanted] | |
| 1022 display_order = [ | 1093 display_order = [ | 
| 1023 "roc_curves.png", | 1094 "roc_curves.png", | 
| 1024 "compare_performance_label.png", | 1095 "compare_performance_label.png", | 
| 1025 "compare_classifiers_performance_from_prob.png", | 1096 "compare_classifiers_performance_from_prob.png", | 
| 1026 "confusion_matrix_entropy__label_top10.png", | 1097 "confusion_matrix_entropy__label_top10.png", | 
| 1027 ] | 1098 ] | 
| 1028 # filter and order | |
| 1029 valid_imgs = [img for img in imgs if img.name not in unwanted] | |
| 1030 img_map = {img.name: img for img in valid_imgs} | 1099 img_map = {img.name: img for img in valid_imgs} | 
| 1031 ordered = [img_map[n] for n in display_order if n in img_map] | 1100 ordered = [img_map[n] for n in display_order if n in img_map] | 
| 1032 others = sorted( | 1101 others = sorted( | 
| 1033 img for img in valid_imgs if img.name not in display_order | 1102 img for img in valid_imgs if img.name not in display_order | 
| 1034 ) | 1103 ) | 
| 1035 imgs = ordered + others | 1104 imgs = ordered + others | 
| 1036 else: | 1105 else: | 
| 1037 # regression: just sort whatever's left | |
| 1038 imgs = sorted(imgs) | 1106 imgs = sorted(imgs) | 
| 1039 # render each remaining PNG | 1107 | 
| 1040 html = "" | 1108 html_section = "" | 
| 1041 for img in imgs: | 1109 for img in imgs: | 
| 1042 b64 = encode_image_to_base64(str(img)) | 1110 b64 = encode_image_to_base64(str(img)) | 
| 1043 img_title = img.stem.replace("_", " ").title() | 1111 img_title = img.stem.replace("_", " ").title() | 
| 1044 html += ( | 1112 html_section += ( | 
| 1045 f"<h2 style='text-align: center;'>{img_title}</h2>" | 1113 f"<h2 style='text-align: center;'>{img_title}</h2>" | 
| 1046 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | 1114 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | 
| 1047 f'<img src="data:image/png;base64,{b64}" ' | 1115 f'<img src="data:image/png;base64,{b64}" ' | 
| 1048 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 1116 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | 
| 1049 f"</div>" | 1117 f"</div>" | 
| 1050 ) | 1118 ) | 
| 1051 return html | 1119 return html_section | 
| 1052 | 1120 | 
| 1053 tab1_content = config_html + metrics_html | 1121 tab1_content = config_html + metrics_html | 
| 1122 | |
| 1054 tab2_content = train_val_metrics_html + render_img_section( | 1123 tab2_content = train_val_metrics_html + render_img_section( | 
| 1055 "Training and Validation Visualizations", train_viz_dir | 1124 "Training and Validation Visualizations", | 
| 1056 ) | 1125 train_viz_dir, | 
| 1057 # --- Predictions vs Ground Truth table --- | 1126 output_type, | 
| 1127 exclude_names={ | |
| 1128 "compare_classifiers_performance_from_prob.png", | |
| 1129 "roc_curves_from_prediction_statistics.png", | |
| 1130 "precision_recall_curves_from_prediction_statistics.png", | |
| 1131 "precision_recall_curve.png", | |
| 1132 }, | |
| 1133 ) | |
| 1134 | |
| 1135 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- | |
| 1058 preds_section = "" | 1136 preds_section = "" | 
| 1059 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | 1137 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | 
| 1060 if output_type == "regression" and parquet_path.exists(): | 1138 if output_type == "regression" and parquet_path.exists(): | 
| 1061 try: | 1139 try: | 
| 1062 # 1) load predictions from Parquet | 1140 # 1) load predictions from Parquet | 
| 1079 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | 1157 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | 
| 1080 # 4) render as HTML | 1158 # 4) render as HTML | 
| 1081 preds_html = df_table.to_html(index=False, classes="predictions-table") | 1159 preds_html = df_table.to_html(index=False, classes="predictions-table") | 
| 1082 preds_section = ( | 1160 preds_section = ( | 
| 1083 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" | 1161 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" | 
| 1084 "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" | 1162 "<div class='preds-controls'>" | 
| 1163 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" | |
| 1164 "</div>" | |
| 1165 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>" | |
| 1085 + preds_html | 1166 + preds_html | 
| 1086 + "</div>" | 1167 + "</div>" | 
| 1087 ) | 1168 ) | 
| 1088 except Exception as e: | 1169 except Exception as e: | 
| 1089 logger.warning(f"Could not build Predictions vs GT table: {e}") | 1170 logger.warning(f"Could not build Predictions vs GT table: {e}") | 
| 1171 | |
| 1090 tab3_content = test_metrics_html + preds_section | 1172 tab3_content = test_metrics_html + preds_section | 
| 1173 | |
| 1174 # Classification-only interactive Plotly panels (centered) | |
| 1091 if output_type in ("binary", "category"): | 1175 if output_type in ("binary", "category"): | 
| 1092 training_stats_path = exp_dir / "training_statistics.json" | 1176 training_stats_path = exp_dir / "training_statistics.json" | 
| 1093 interactive_plots = build_classification_plots( | 1177 interactive_plots = build_classification_plots( | 
| 1094 str(test_stats_path), | 1178 str(test_stats_path), | 
| 1095 str(training_stats_path), | 1179 str(training_stats_path), | 
| 1096 ) | 1180 ) | 
| 1097 for plot in interactive_plots: | 1181 for plot in interactive_plots: | 
| 1098 # 2) inject the static "roc_curves_from_prediction_statistics.png" | |
| 1099 if plot["title"] == "ROC-AUC": | |
| 1100 static_img = ( | |
| 1101 test_viz_dir / "roc_curves_from_prediction_statistics.png" | |
| 1102 ) | |
| 1103 if static_img.exists(): | |
| 1104 b64 = encode_image_to_base64(str(static_img)) | |
| 1105 tab3_content += ( | |
| 1106 "<h2 style='text-align: center;'>" | |
| 1107 "Roc Curves From Prediction Statistics" | |
| 1108 "</h2>" | |
| 1109 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | |
| 1110 f'<img src="data:image/png;base64,{b64}" ' | |
| 1111 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | |
| 1112 "</div>" | |
| 1113 ) | |
| 1114 # always render the plotly panels exactly as before | |
| 1115 tab3_content += ( | 1182 tab3_content += ( | 
| 1116 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 1183 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | 
| 1117 + plot["html"] | 1184 f"<div class='plotly-center'>{plot['html']}</div>" | 
| 1118 ) | 1185 ) | 
| 1119 tab3_content += render_img_section( | 1186 | 
| 1120 "Test Visualizations", test_viz_dir, output_type | 1187 # Add static TEST PNGs (with default dedupe/exclusions) | 
| 1121 ) | 1188 tab3_content += render_img_section( | 
| 1122 # assemble the tabs and help modal | 1189 "Test Visualizations", test_viz_dir, output_type | 
| 1190 ) | |
| 1191 | |
| 1123 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1192 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 
| 1124 modal_html = get_metrics_help_modal() | 1193 modal_html = get_metrics_help_modal() | 
| 1125 html += tabbed_html + modal_html + get_html_closing() | 1194 html += tabbed_html + modal_html + get_html_closing() | 
| 1126 | 1195 | 
| 1127 try: | 1196 try: | 
