comparison image_learner_cli.py @ 10:b0d893d04d4c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1594d503179f28987720594eb49b48a15486f073
author goeckslab
date Mon, 08 Sep 2025 22:38:35 +0000
parents 9e912fce264c
children
comparison
equal deleted inserted replaced
9:9e912fce264c 10:b0d893d04d4c
67 "early_stop", 67 "early_stop",
68 "threshold", 68 "threshold",
69 ] 69 ]
70 70
71 rows = [] 71 rows = []
72
73 for key in display_keys: 72 for key in display_keys:
74 val = config.get(key, None) 73 val = config.get(key, None)
75 if key == "threshold": 74 if key == "threshold":
76 if output_type != "binary": 75 if output_type != "binary":
77 continue 76 continue
134 ) 133 )
135 else: 134 else:
136 val_str = val 135 val_str = val
137 else: 136 else:
138 val_str = val if val is not None else "N/A" 137 val_str = val if val is not None else "N/A"
139 if val_str == "N/A" and key not in [ 138 if val_str == "N/A" and key not in ["task_type"]:
140 "task_type"
141 ]: # Skip if N/A for non-essential
142 continue 139 continue
143 rows.append( 140 rows.append(
144 f"<tr>" 141 f"<tr>"
145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" 142 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
143 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>"
146 f"{key.replace('_', ' ').title()}</td>" 144 f"{key.replace('_', ' ').title()}</td>"
147 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" 145 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
146 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>"
148 f"{val_str}</td>" 147 f"{val_str}</td>"
149 f"</tr>" 148 f"</tr>"
150 ) 149 )
151 aug_cfg = config.get("augmentation") 150 aug_cfg = config.get("augmentation")
152 if aug_cfg: 151 if aug_cfg:
153 types = [str(a.get("type", "")) for a in aug_cfg] 152 types = [str(a.get("type", "")) for a in aug_cfg]
154 aug_val = ", ".join(types) 153 aug_val = ", ".join(types)
155 rows.append( 154 rows.append(
156 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" 155 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
157 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" 156 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Augmentation</td>"
157 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
158 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{aug_val}</td></tr>"
158 ) 159 )
159 if split_info: 160 if split_info:
160 rows.append( 161 rows.append(
161 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" 162 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left; "
162 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" 163 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>Data Split</td>"
164 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center; "
165 f"white-space: normal; word-break: break-word; overflow-wrap: anywhere;'>{split_info}</td></tr>"
163 ) 166 )
164 html = f""" 167 html = f"""
165 <h2 style="text-align: center;">Model and Training Summary</h2> 168 <h2 style="text-align: center;">Model and Training Summary</h2>
166 <div style="display: flex; justify-content: center;"> 169 <div style="display: flex; justify-content: center;">
167 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> 170 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;">
944 base_viz_dir = exp_dir / "visualizations" 947 base_viz_dir = exp_dir / "visualizations"
945 train_viz_dir = base_viz_dir / "train" 948 train_viz_dir = base_viz_dir / "train"
946 test_viz_dir = base_viz_dir / "test" 949 test_viz_dir = base_viz_dir / "test"
947 950
948 html = get_html_template() 951 html = get_html_template()
952
953 # Extra CSS & JS: center Plotly and enable CSV download for predictions table
954 html += """
955 <style>
956 /* Center Plotly figures (both wrapper and native classes) */
957 .plotly-center { display: flex; justify-content: center; }
958 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
959 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
960
961 /* Download button for predictions table */
962 .download-btn {
963 padding: 8px 12px;
964 border: 1px solid #4CAF50;
965 background: #4CAF50;
966 color: white;
967 border-radius: 6px;
968 cursor: pointer;
969 }
970 .download-btn:hover { filter: brightness(0.95); }
971 .preds-controls {
972 display: flex;
973 justify-content: flex-end;
974 gap: 8px;
975 margin: 8px 0;
976 }
977 </style>
978 <script>
979 function tableToCSV(table){
980 const rows = Array.from(table.querySelectorAll('tr'));
981 return rows.map(row =>
982 Array.from(row.querySelectorAll('th,td')).map(cell => {
983 let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim();
984 if (text.includes('"') || text.includes(',')) {
985 text = '"' + text.replace(/"/g,'""') + '"';
986 }
987 return text;
988 }).join(',')
989 ).join('\\n');
990 }
991 document.addEventListener('DOMContentLoaded', function(){
992 const btn = document.getElementById('downloadPredsCsv');
993 if(btn){
994 btn.addEventListener('click', function(){
995 const tbl = document.querySelector('.predictions-table');
996 if(!tbl){ alert('Predictions table not found.'); return; }
997 const csv = tableToCSV(tbl);
998 const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'});
999 const url = URL.createObjectURL(blob);
1000 const a = document.createElement('a');
1001 a.href = url;
1002 a.download = 'ground_truth_vs_predictions.csv';
1003 document.body.appendChild(a);
1004 a.click();
1005 document.body.removeChild(a);
1006 URL.revokeObjectURL(url);
1007 });
1008 }
1009 });
1010 </script>
1011 """
949 html += f"<h1>{title}</h1>" 1012 html += f"<h1>{title}</h1>"
950 1013
951 metrics_html = "" 1014 metrics_html = ""
952 train_val_metrics_html = "" 1015 train_val_metrics_html = ""
953 test_metrics_html = "" 1016 test_metrics_html = ""
981 config, split_info, training_progress, output_type 1044 config, split_info, training_progress, output_type
982 ) 1045 )
983 except Exception as e: 1046 except Exception as e:
984 logger.warning(f"Could not load config for HTML report: {e}") 1047 logger.warning(f"Could not load config for HTML report: {e}")
985 1048
1049 # ---------- image rendering with exclusions ----------
986 def render_img_section( 1050 def render_img_section(
987 title: str, dir_path: Path, output_type: str = None 1051 title: str,
1052 dir_path: Path,
1053 output_type: str = None,
1054 exclude_names: Optional[set] = None,
988 ) -> str: 1055 ) -> str:
989 if not dir_path.exists(): 1056 if not dir_path.exists():
990 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" 1057 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
991 # collect every PNG 1058
1059 exclude_names = exclude_names or set()
1060
992 imgs = list(dir_path.glob("*.png")) 1061 imgs = list(dir_path.glob("*.png"))
993 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- 1062
1063 default_exclude = {"confusion_matrix.png", "roc_curves.png"}
1064
994 imgs = [ 1065 imgs = [
995 img 1066 img
996 for img in imgs 1067 for img in imgs
997 if not ( 1068 if img.name not in default_exclude
998 img.name == "confusion_matrix.png" 1069 and img.name not in exclude_names
999 or img.name.startswith("confusion_matrix__label_top") 1070 and not img.name.startswith("confusion_matrix__label_top")
1000 or img.name == "roc_curves.png"
1001 )
1002 ] 1071 ]
1072
1003 if not imgs: 1073 if not imgs:
1004 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" 1074 return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
1075
1005 if output_type == "binary": 1076 if output_type == "binary":
1006 order = [ 1077 order = [
1007 "roc_curves_from_prediction_statistics.png", 1078 "roc_curves_from_prediction_statistics.png",
1008 "compare_performance_label.png", 1079 "compare_performance_label.png",
1009 "confusion_matrix_entropy__label_top2.png", 1080 "confusion_matrix_entropy__label_top2.png",
1010 # ...you can tweak ordering as needed
1011 ] 1081 ]
1012 img_names = {img.name: img for img in imgs} 1082 img_names = {img.name: img for img in imgs}
1013 ordered = [img_names[n] for n in order if n in img_names] 1083 ordered = [img_names[n] for n in order if n in img_names]
1014 others = sorted(img for img in imgs if img.name not in order) 1084 others = sorted(img for img in imgs if img.name not in order)
1015 imgs = ordered + others 1085 imgs = ordered + others
1017 unwanted = { 1087 unwanted = {
1018 "compare_classifiers_multiclass_multimetric__label_best10.png", 1088 "compare_classifiers_multiclass_multimetric__label_best10.png",
1019 "compare_classifiers_multiclass_multimetric__label_top10.png", 1089 "compare_classifiers_multiclass_multimetric__label_top10.png",
1020 "compare_classifiers_multiclass_multimetric__label_worst10.png", 1090 "compare_classifiers_multiclass_multimetric__label_worst10.png",
1021 } 1091 }
1092 valid_imgs = [img for img in imgs if img.name not in unwanted]
1022 display_order = [ 1093 display_order = [
1023 "roc_curves.png", 1094 "roc_curves.png",
1024 "compare_performance_label.png", 1095 "compare_performance_label.png",
1025 "compare_classifiers_performance_from_prob.png", 1096 "compare_classifiers_performance_from_prob.png",
1026 "confusion_matrix_entropy__label_top10.png", 1097 "confusion_matrix_entropy__label_top10.png",
1027 ] 1098 ]
1028 # filter and order
1029 valid_imgs = [img for img in imgs if img.name not in unwanted]
1030 img_map = {img.name: img for img in valid_imgs} 1099 img_map = {img.name: img for img in valid_imgs}
1031 ordered = [img_map[n] for n in display_order if n in img_map] 1100 ordered = [img_map[n] for n in display_order if n in img_map]
1032 others = sorted( 1101 others = sorted(
1033 img for img in valid_imgs if img.name not in display_order 1102 img for img in valid_imgs if img.name not in display_order
1034 ) 1103 )
1035 imgs = ordered + others 1104 imgs = ordered + others
1036 else: 1105 else:
1037 # regression: just sort whatever's left
1038 imgs = sorted(imgs) 1106 imgs = sorted(imgs)
1039 # render each remaining PNG 1107
1040 html = "" 1108 html_section = ""
1041 for img in imgs: 1109 for img in imgs:
1042 b64 = encode_image_to_base64(str(img)) 1110 b64 = encode_image_to_base64(str(img))
1043 img_title = img.stem.replace("_", " ").title() 1111 img_title = img.stem.replace("_", " ").title()
1044 html += ( 1112 html_section += (
1045 f"<h2 style='text-align: center;'>{img_title}</h2>" 1113 f"<h2 style='text-align: center;'>{img_title}</h2>"
1046 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' 1114 f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
1047 f'<img src="data:image/png;base64,{b64}" ' 1115 f'<img src="data:image/png;base64,{b64}" '
1048 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' 1116 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
1049 f"</div>" 1117 f"</div>"
1050 ) 1118 )
1051 return html 1119 return html_section
1052 1120
1053 tab1_content = config_html + metrics_html 1121 tab1_content = config_html + metrics_html
1122
1054 tab2_content = train_val_metrics_html + render_img_section( 1123 tab2_content = train_val_metrics_html + render_img_section(
1055 "Training and Validation Visualizations", train_viz_dir 1124 "Training and Validation Visualizations",
1056 ) 1125 train_viz_dir,
1057 # --- Predictions vs Ground Truth table --- 1126 output_type,
1127 exclude_names={
1128 "compare_classifiers_performance_from_prob.png",
1129 "roc_curves_from_prediction_statistics.png",
1130 "precision_recall_curves_from_prediction_statistics.png",
1131 "precision_recall_curve.png",
1132 },
1133 )
1134
1135 # --- Predictions vs Ground Truth table (REGRESSION ONLY) ---
1058 preds_section = "" 1136 preds_section = ""
1059 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME 1137 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
1060 if output_type == "regression" and parquet_path.exists(): 1138 if output_type == "regression" and parquet_path.exists():
1061 try: 1139 try:
1062 # 1) load predictions from Parquet 1140 # 1) load predictions from Parquet
1079 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] 1157 df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
1080 # 4) render as HTML 1158 # 4) render as HTML
1081 preds_html = df_table.to_html(index=False, classes="predictions-table") 1159 preds_html = df_table.to_html(index=False, classes="predictions-table")
1082 preds_section = ( 1160 preds_section = (
1083 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" 1161 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
1084 "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" 1162 "<div class='preds-controls'>"
1163 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>"
1164 "</div>"
1165 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:900px; margin-bottom:20px;'>"
1085 + preds_html 1166 + preds_html
1086 + "</div>" 1167 + "</div>"
1087 ) 1168 )
1088 except Exception as e: 1169 except Exception as e:
1089 logger.warning(f"Could not build Predictions vs GT table: {e}") 1170 logger.warning(f"Could not build Predictions vs GT table: {e}")
1171
1090 tab3_content = test_metrics_html + preds_section 1172 tab3_content = test_metrics_html + preds_section
1173
1174 # Classification-only interactive Plotly panels (centered)
1091 if output_type in ("binary", "category"): 1175 if output_type in ("binary", "category"):
1092 training_stats_path = exp_dir / "training_statistics.json" 1176 training_stats_path = exp_dir / "training_statistics.json"
1093 interactive_plots = build_classification_plots( 1177 interactive_plots = build_classification_plots(
1094 str(test_stats_path), 1178 str(test_stats_path),
1095 str(training_stats_path), 1179 str(training_stats_path),
1096 ) 1180 )
1097 for plot in interactive_plots: 1181 for plot in interactive_plots:
1098 # 2) inject the static "roc_curves_from_prediction_statistics.png"
1099 if plot["title"] == "ROC-AUC":
1100 static_img = (
1101 test_viz_dir / "roc_curves_from_prediction_statistics.png"
1102 )
1103 if static_img.exists():
1104 b64 = encode_image_to_base64(str(static_img))
1105 tab3_content += (
1106 "<h2 style='text-align: center;'>"
1107 "Roc Curves From Prediction Statistics"
1108 "</h2>"
1109 f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
1110 f'<img src="data:image/png;base64,{b64}" '
1111 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
1112 "</div>"
1113 )
1114 # always render the plotly panels exactly as before
1115 tab3_content += ( 1182 tab3_content += (
1116 f"<h2 style='text-align: center;'>{plot['title']}</h2>" 1183 f"<h2 style='text-align: center;'>{plot['title']}</h2>"
1117 + plot["html"] 1184 f"<div class='plotly-center'>{plot['html']}</div>"
1118 ) 1185 )
1119 tab3_content += render_img_section( 1186
1120 "Test Visualizations", test_viz_dir, output_type 1187 # Add static TEST PNGs (with default dedupe/exclusions)
1121 ) 1188 tab3_content += render_img_section(
1122 # assemble the tabs and help modal 1189 "Test Visualizations", test_viz_dir, output_type
1190 )
1191
1123 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) 1192 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
1124 modal_html = get_metrics_help_modal() 1193 modal_html = get_metrics_help_modal()
1125 html += tabbed_html + modal_html + get_html_closing() 1194 html += tabbed_html + modal_html + get_html_closing()
1126 1195
1127 try: 1196 try: