Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 8:85e6f4b2ad18 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
| author | goeckslab |
|---|---|
| date | Thu, 14 Aug 2025 14:53:10 +0000 |
| parents | 801a8b6973fb |
| children |
comparison
equal
deleted
inserted
replaced
| 7:801a8b6973fb | 8:85e6f4b2ad18 |
|---|---|
| 29 TEST_STATISTICS_FILE_NAME, | 29 TEST_STATISTICS_FILE_NAME, |
| 30 TRAIN_SET_METADATA_FILE_NAME, | 30 TRAIN_SET_METADATA_FILE_NAME, |
| 31 ) | 31 ) |
| 32 from ludwig.utils.data_utils import get_split_path | 32 from ludwig.utils.data_utils import get_split_path |
| 33 from ludwig.visualize import get_visualizations_registry | 33 from ludwig.visualize import get_visualizations_registry |
| 34 from plotly_plots import build_classification_plots | |
| 34 from sklearn.model_selection import train_test_split | 35 from sklearn.model_selection import train_test_split |
| 35 from utils import ( | 36 from utils import ( |
| 36 build_tabbed_html, | 37 build_tabbed_html, |
| 37 encode_image_to_base64, | 38 encode_image_to_base64, |
| 38 get_html_closing, | 39 get_html_closing, |
| 50 | 51 |
| 51 def format_config_table_html( | 52 def format_config_table_html( |
| 52 config: dict, | 53 config: dict, |
| 53 split_info: Optional[str] = None, | 54 split_info: Optional[str] = None, |
| 54 training_progress: dict = None, | 55 training_progress: dict = None, |
| 56 output_type: Optional[str] = None, | |
| 55 ) -> str: | 57 ) -> str: |
| 56 display_keys = [ | 58 display_keys = [ |
| 57 "task_type", | 59 "task_type", |
| 58 "model_name", | 60 "model_name", |
| 59 "epochs", | 61 "epochs", |
| 61 "fine_tune", | 63 "fine_tune", |
| 62 "use_pretrained", | 64 "use_pretrained", |
| 63 "learning_rate", | 65 "learning_rate", |
| 64 "random_seed", | 66 "random_seed", |
| 65 "early_stop", | 67 "early_stop", |
| 68 "threshold", | |
| 66 ] | 69 ] |
| 67 | |
| 68 rows = [] | 70 rows = [] |
| 69 | |
| 70 for key in display_keys: | 71 for key in display_keys: |
| 71 val = config.get(key, "N/A") | 72 val = config.get(key, None) |
| 72 if key == "task_type": | 73 if key == "threshold": |
| 73 val = val.title() if isinstance(val, str) else val | 74 if output_type != "binary": |
| 74 if key == "batch_size": | 75 continue |
| 75 if val is not None: | 76 val = val if val is not None else 0.5 |
| 76 val = int(val) | 77 val_str = f"{val:.2f}" |
| 78 if val == 0.5: | |
| 79 val_str += " (default)" | |
| 80 else: | |
| 81 if key == "task_type": | |
| 82 val_str = val.title() if isinstance(val, str) else "N/A" | |
| 83 elif key == "batch_size": | |
| 84 if val is not None: | |
| 85 val_str = int(val) | |
| 86 else: | |
| 87 if training_progress: | |
| 88 resolved_val = training_progress.get("batch_size") | |
| 89 val_str = ( | |
| 90 "Auto-selected batch size by Ludwig:<br>" | |
| 91 f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" | |
| 92 ) | |
| 93 else: | |
| 94 val_str = "auto" | |
| 95 elif key == "learning_rate": | |
| 96 if val is not None and val != "auto": | |
| 97 val_str = f"{val:.6f}" | |
| 98 else: | |
| 99 if training_progress: | |
| 100 resolved_val = training_progress.get("learning_rate") | |
| 101 val_str = ( | |
| 102 "Auto-selected learning rate by Ludwig:<br>" | |
| 103 f"<span style='font-size: 0.85em;'>" | |
| 104 f"{resolved_val if resolved_val else 'auto'}</span><br>" | |
| 105 "<span style='font-size: 0.85em;'>" | |
| 106 "Based on model architecture and training setup " | |
| 107 "(e.g., fine-tuning).<br>" | |
| 108 "</span>" | |
| 109 ) | |
| 110 else: | |
| 111 val_str = ( | |
| 112 "Auto-selected by Ludwig<br>" | |
| 113 "<span style='font-size: 0.85em;'>" | |
| 114 "Automatically tuned based on architecture and dataset.<br>" | |
| 115 "See <a href='https://ludwig.ai/latest/configuration/trainer/" | |
| 116 "#trainer-parameters' target='_blank'>" | |
| 117 "Ludwig Trainer Parameters</a> for details." | |
| 118 "</span>" | |
| 119 ) | |
| 120 elif key == "epochs": | |
| 121 if val is None: | |
| 122 val_str = "N/A" | |
| 123 else: | |
| 124 if ( | |
| 125 training_progress | |
| 126 and "epoch" in training_progress | |
| 127 and val > training_progress["epoch"] | |
| 128 ): | |
| 129 val_str = ( | |
| 130 f"Because of early stopping: the training " | |
| 131 f"stopped at epoch {training_progress['epoch']}" | |
| 132 ) | |
| 133 else: | |
| 134 val_str = val | |
| 77 else: | 135 else: |
| 78 if training_progress: | 136 val_str = val if val is not None else "N/A" |
| 79 val = "Auto-selected batch size by Ludwig:<br>" | 137 if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential |
| 80 resolved_val = training_progress.get("batch_size") | 138 continue |
| 81 val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>" | |
| 82 else: | |
| 83 val = "auto" | |
| 84 if key == "learning_rate": | |
| 85 resolved_val = None | |
| 86 if val is None or val == "auto": | |
| 87 if training_progress: | |
| 88 resolved_val = training_progress.get("learning_rate") | |
| 89 val = ( | |
| 90 "Auto-selected learning rate by Ludwig:<br>" | |
| 91 f"<span style='font-size: 0.85em;'>" | |
| 92 f"{resolved_val if resolved_val else val}</span><br>" | |
| 93 "<span style='font-size: 0.85em;'>" | |
| 94 "Based on model architecture and training setup " | |
| 95 "(e.g., fine-tuning).<br>" | |
| 96 "See <a href='https://ludwig.ai/latest/configuration/trainer/" | |
| 97 "#trainer-parameters' target='_blank'>" | |
| 98 "Ludwig Trainer Parameters</a> for details." | |
| 99 "</span>" | |
| 100 ) | |
| 101 else: | |
| 102 val = ( | |
| 103 "Auto-selected by Ludwig<br>" | |
| 104 "<span style='font-size: 0.85em;'>" | |
| 105 "Automatically tuned based on architecture and dataset.<br>" | |
| 106 "See <a href='https://ludwig.ai/latest/configuration/trainer/" | |
| 107 "#trainer-parameters' target='_blank'>" | |
| 108 "Ludwig Trainer Parameters</a> for details." | |
| 109 "</span>" | |
| 110 ) | |
| 111 else: | |
| 112 val = f"{val:.6f}" | |
| 113 if key == "epochs": | |
| 114 if ( | |
| 115 training_progress | |
| 116 and "epoch" in training_progress | |
| 117 and val > training_progress["epoch"] | |
| 118 ): | |
| 119 val = ( | |
| 120 f"Because of early stopping: the training " | |
| 121 f"stopped at epoch {training_progress['epoch']}" | |
| 122 ) | |
| 123 | |
| 124 if val is None: | |
| 125 continue | |
| 126 rows.append( | 139 rows.append( |
| 127 f"<tr>" | 140 f"<tr>" |
| 128 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 141 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" |
| 129 f"{key.replace('_', ' ').title()}</td>" | 142 f"{key.replace('_', ' ').title()}</td>" |
| 130 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 143 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" |
| 131 f"{val}</td>" | 144 f"{val_str}</td>" |
| 132 f"</tr>" | 145 f"</tr>" |
| 133 ) | 146 ) |
| 134 | |
| 135 aug_cfg = config.get("augmentation") | 147 aug_cfg = config.get("augmentation") |
| 136 if aug_cfg: | 148 if aug_cfg: |
| 137 types = [str(a.get("type", "")) for a in aug_cfg] | 149 types = [str(a.get("type", "")) for a in aug_cfg] |
| 138 aug_val = ", ".join(types) | 150 aug_val = ", ".join(types) |
| 139 rows.append( | 151 rows.append( |
| 140 "<tr>" | 152 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" |
| 141 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" | 153 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>" |
| 142 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 154 ) |
| 143 f"{aug_val}</td>" | |
| 144 "</tr>" | |
| 145 ) | |
| 146 | |
| 147 if split_info: | 155 if split_info: |
| 148 rows.append( | 156 rows.append( |
| 149 f"<tr>" | 157 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>" |
| 150 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" | 158 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>" |
| 151 f"Data Split</td>" | 159 ) |
| 152 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" | 160 html = f""" |
| 153 f"{split_info}</td>" | 161 <h2 style="text-align: center;">Model and Training Summary</h2> |
| 154 f"</tr>" | 162 <div style="display: flex; justify-content: center;"> |
| 155 ) | 163 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;"> |
| 156 | 164 <thead><tr> |
| 157 return ( | 165 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th> |
| 158 "<h2 style='text-align: center;'>Training Setup</h2>" | 166 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> |
| 159 "<div style='display: flex; justify-content: center;'>" | 167 </tr></thead> |
| 160 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" | 168 <tbody> |
| 161 "<thead><tr>" | 169 {''.join(rows)} |
| 162 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>" | 170 </tbody> |
| 163 "Parameter</th>" | 171 </table> |
| 164 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>" | 172 </div><br> |
| 165 "Value</th>" | 173 <p style="text-align: center; font-size: 0.9em;"> |
| 166 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" | 174 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>. |
| 167 "<p style='text-align: center; font-size: 0.9em;'>" | 175 <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer"> |
| 168 "Model trained using Ludwig.<br>" | 176 Ludwig documentation provides detailed information about default model and training parameters |
| 169 "If want to learn more about Ludwig default settings," | 177 </a> |
| 170 "please check their <a href='https://ludwig.ai' target='_blank'>" | 178 </p><hr> |
| 171 "website(ludwig.ai)</a>." | 179 """ |
| 172 "</p><hr>" | 180 return html |
| 173 ) | |
| 174 | 181 |
| 175 | 182 |
| 176 def detect_output_type(test_stats): | 183 def detect_output_type(test_stats): |
| 177 """Detects if the output type is 'binary' or 'category' based on test statistics.""" | 184 """Detects if the output type is 'binary' or 'category' based on test statistics.""" |
| 178 label_stats = test_stats.get("label", {}) | 185 label_stats = test_stats.get("label", {}) |
| 242 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), | 249 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), |
| 243 "loss": get_last_value(label_stats, "loss"), | 250 "loss": get_last_value(label_stats, "loss"), |
| 244 "roc_auc": get_last_value(label_stats, "roc_auc"), | 251 "roc_auc": get_last_value(label_stats, "roc_auc"), |
| 245 "hits_at_k": get_last_value(label_stats, "hits_at_k"), | 252 "hits_at_k": get_last_value(label_stats, "hits_at_k"), |
| 246 } | 253 } |
| 247 | |
| 248 # Test metrics: dynamic extraction according to exclusions | 254 # Test metrics: dynamic extraction according to exclusions |
| 249 test_label_stats = test_stats.get("label", {}) | 255 test_label_stats = test_stats.get("label", {}) |
| 250 if not test_label_stats: | 256 if not test_label_stats: |
| 251 logging.warning("No label statistics found for test split") | 257 logging.warning("No label statistics found for test split") |
| 252 else: | 258 else: |
| 253 combined_stats = test_stats.get("combined", {}) | 259 combined_stats = test_stats.get("combined", {}) |
| 254 overall_stats = test_label_stats.get("overall_stats", {}) | 260 overall_stats = test_label_stats.get("overall_stats", {}) |
| 255 | |
| 256 # Define exclusions | 261 # Define exclusions |
| 257 if output_type == "binary": | 262 if output_type == "binary": |
| 258 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} | 263 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} |
| 259 else: | 264 else: |
| 260 exclude = {"per_class_stats", "confusion_matrix"} | 265 exclude = {"per_class_stats", "confusion_matrix"} |
| 261 | |
| 262 # 1. Get all scalar test_label_stats not excluded | 266 # 1. Get all scalar test_label_stats not excluded |
| 263 test_metrics = {} | 267 test_metrics = {} |
| 264 for k, v in test_label_stats.items(): | 268 for k, v in test_label_stats.items(): |
| 265 if k in exclude: | 269 if k in exclude: |
| 266 continue | 270 continue |
| 267 if k == "overall_stats": | 271 if k == "overall_stats": |
| 268 continue | 272 continue |
| 269 if isinstance(v, (int, float, str, bool)): | 273 if isinstance(v, (int, float, str, bool)): |
| 270 test_metrics[k] = v | 274 test_metrics[k] = v |
| 271 | |
| 272 # 2. Add overall_stats (flattened) | 275 # 2. Add overall_stats (flattened) |
| 273 for k, v in overall_stats.items(): | 276 for k, v in overall_stats.items(): |
| 274 test_metrics[k] = v | 277 test_metrics[k] = v |
| 275 | |
| 276 # 3. Optionally include combined/loss if present and not already | 278 # 3. Optionally include combined/loss if present and not already |
| 277 if "loss" in combined_stats and "loss" not in test_metrics: | 279 if "loss" in combined_stats and "loss" not in test_metrics: |
| 278 test_metrics["loss"] = combined_stats["loss"] | 280 test_metrics["loss"] = combined_stats["loss"] |
| 279 | |
| 280 metrics["test"] = test_metrics | 281 metrics["test"] = test_metrics |
| 281 | |
| 282 return metrics | 282 return metrics |
| 283 | 283 |
| 284 | 284 |
| 285 def generate_table_row(cells, styles): | 285 def generate_table_row(cells, styles): |
| 286 """Helper function to generate an HTML table row.""" | 286 """Helper function to generate an HTML table row.""" |
| 287 return ( | 287 return ( |
| 288 "<tr>" | 288 "<tr>" |
| 289 + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells) | 289 + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells) |
| 290 + "</tr>" | 290 + "</tr>" |
| 291 ) | 291 ) |
| 292 | |
| 293 | |
| 294 # ----------------------------------------- | |
| 295 # 2) MODEL PERFORMANCE (Train/Val/Test) TABLE | |
| 296 # ----------------------------------------- | |
| 292 | 297 |
| 293 | 298 |
| 294 def format_stats_table_html(train_stats: dict, test_stats: dict) -> str: | 299 def format_stats_table_html(train_stats: dict, test_stats: dict) -> str: |
| 295 """Formats a combined HTML table for training, validation, and test metrics.""" | 300 """Formats a combined HTML table for training, validation, and test metrics.""" |
| 296 output_type = detect_output_type(test_stats) | 301 output_type = detect_output_type(test_stats) |
| 308 t = all_metrics["training"].get(metric_key) | 313 t = all_metrics["training"].get(metric_key) |
| 309 v = all_metrics["validation"].get(metric_key) | 314 v = all_metrics["validation"].get(metric_key) |
| 310 te = all_metrics["test"].get(metric_key) | 315 te = all_metrics["test"].get(metric_key) |
| 311 if all(x is not None for x in [t, v, te]): | 316 if all(x is not None for x in [t, v, te]): |
| 312 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) | 317 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) |
| 313 | |
| 314 if not rows: | 318 if not rows: |
| 315 return "<table><tr><td>No metric values found.</td></tr></table>" | 319 return "<table><tr><td>No metric values found.</td></tr></table>" |
| 316 | |
| 317 html = ( | 320 html = ( |
| 318 "<h2 style='text-align: center;'>Model Performance Summary</h2>" | 321 "<h2 style='text-align: center;'>Model Performance Summary</h2>" |
| 319 "<div style='display: flex; justify-content: center;'>" | 322 "<div style='display: flex; justify-content: center;'>" |
| 320 "<table style='border-collapse: collapse; table-layout: auto;'>" | 323 "<table class='performance-summary' style='border-collapse: collapse;'>" |
| 321 "<thead><tr>" | 324 "<thead><tr>" |
| 322 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " | 325 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" |
| 323 "white-space: nowrap;'>Metric</th>" | 326 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" |
| 324 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | 327 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" |
| 325 "white-space: nowrap;'>Train</th>" | 328 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" |
| 326 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
| 327 "white-space: nowrap;'>Validation</th>" | |
| 328 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
| 329 "white-space: nowrap;'>Test</th>" | |
| 330 "</tr></thead><tbody>" | 329 "</tr></thead><tbody>" |
| 331 ) | 330 ) |
| 332 for row in rows: | 331 for row in rows: |
| 333 html += generate_table_row( | 332 html += generate_table_row( |
| 334 row, | 333 row, |
| 335 "padding: 10px; border: 1px solid #ccc; text-align: center; " | 334 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" |
| 336 "white-space: nowrap;", | |
| 337 ) | 335 ) |
| 338 html += "</tbody></table></div><br>" | 336 html += "</tbody></table></div><br>" |
| 339 return html | 337 return html |
| 338 | |
| 339 | |
| 340 # ------------------------------------------- | |
| 341 # 3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE | |
| 342 # ------------------------------------------- | |
| 340 | 343 |
| 341 | 344 |
| 342 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: | 345 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: |
| 343 """Formats an HTML table for training and validation metrics.""" | 346 """Formats an HTML table for training and validation metrics.""" |
| 344 output_type = detect_output_type(test_stats) | 347 output_type = detect_output_type(test_stats) |
| 352 ) | 355 ) |
| 353 t = all_metrics["training"].get(metric_key) | 356 t = all_metrics["training"].get(metric_key) |
| 354 v = all_metrics["validation"].get(metric_key) | 357 v = all_metrics["validation"].get(metric_key) |
| 355 if t is not None and v is not None: | 358 if t is not None and v is not None: |
| 356 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) | 359 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) |
| 357 | |
| 358 if not rows: | 360 if not rows: |
| 359 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" | 361 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" |
| 360 | |
| 361 html = ( | 362 html = ( |
| 362 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" | 363 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" |
| 363 "<div style='display: flex; justify-content: center;'>" | 364 "<div style='display: flex; justify-content: center;'>" |
| 364 "<table style='border-collapse: collapse; table-layout: auto;'>" | 365 "<table class='performance-summary' style='border-collapse: collapse;'>" |
| 365 "<thead><tr>" | 366 "<thead><tr>" |
| 366 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " | 367 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" |
| 367 "white-space: nowrap;'>Metric</th>" | 368 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>" |
| 368 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | 369 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>" |
| 369 "white-space: nowrap;'>Train</th>" | |
| 370 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
| 371 "white-space: nowrap;'>Validation</th>" | |
| 372 "</tr></thead><tbody>" | 370 "</tr></thead><tbody>" |
| 373 ) | 371 ) |
| 374 for row in rows: | 372 for row in rows: |
| 375 html += generate_table_row( | 373 html += generate_table_row( |
| 376 row, | 374 row, |
| 377 "padding: 10px; border: 1px solid #ccc; text-align: center; " | 375 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" |
| 378 "white-space: nowrap;", | |
| 379 ) | 376 ) |
| 380 html += "</tbody></table></div><br>" | 377 html += "</tbody></table></div><br>" |
| 381 return html | 378 return html |
| 379 | |
| 380 | |
| 381 # ----------------------------------------- | |
| 382 # 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE | |
| 383 # ----------------------------------------- | |
| 382 | 384 |
| 383 | 385 |
| 384 def format_test_merged_stats_table_html( | 386 def format_test_merged_stats_table_html( |
| 385 test_metrics: Dict[str, Optional[float]], | 387 test_metrics: Dict[str, Optional[float]], |
| 386 ) -> str: | 388 ) -> str: |
| 389 for key in sorted(test_metrics.keys()): | 391 for key in sorted(test_metrics.keys()): |
| 390 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) | 392 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) |
| 391 value = test_metrics[key] | 393 value = test_metrics[key] |
| 392 if value is not None: | 394 if value is not None: |
| 393 rows.append([display_name, f"{value:.4f}"]) | 395 rows.append([display_name, f"{value:.4f}"]) |
| 394 | |
| 395 if not rows: | 396 if not rows: |
| 396 return "<table><tr><td>No test metric values found.</td></tr></table>" | 397 return "<table><tr><td>No test metric values found.</td></tr></table>" |
| 397 | |
| 398 html = ( | 398 html = ( |
| 399 "<h2 style='text-align: center;'>Test Performance Summary</h2>" | 399 "<h2 style='text-align: center;'>Test Performance Summary</h2>" |
| 400 "<div style='display: flex; justify-content: center;'>" | 400 "<div style='display: flex; justify-content: center;'>" |
| 401 "<table style='border-collapse: collapse; table-layout: auto;'>" | 401 "<table class='performance-summary' style='border-collapse: collapse;'>" |
| 402 "<thead><tr>" | 402 "<thead><tr>" |
| 403 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " | 403 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>" |
| 404 "white-space: nowrap;'>Metric</th>" | 404 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>" |
| 405 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " | |
| 406 "white-space: nowrap;'>Test</th>" | |
| 407 "</tr></thead><tbody>" | 405 "</tr></thead><tbody>" |
| 408 ) | 406 ) |
| 409 for row in rows: | 407 for row in rows: |
| 410 html += generate_table_row( | 408 html += generate_table_row( |
| 411 row, | 409 row, |
| 412 "padding: 10px; border: 1px solid #ccc; text-align: center; " | 410 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" |
| 413 "white-space: nowrap;", | |
| 414 ) | 411 ) |
| 415 html += "</tbody></table></div><br>" | 412 html += "</tbody></table></div><br>" |
| 416 return html | 413 return html |
| 417 | 414 |
| 418 | 415 |
| 424 label_column: Optional[str] = None, | 421 label_column: Optional[str] = None, |
| 425 ) -> pd.DataFrame: | 422 ) -> pd.DataFrame: |
| 426 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" | 423 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" |
| 427 out = df.copy() | 424 out = df.copy() |
| 428 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) | 425 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) |
| 429 | |
| 430 idx_train = out.index[out[split_column] == 0].tolist() | 426 idx_train = out.index[out[split_column] == 0].tolist() |
| 431 | |
| 432 if not idx_train: | 427 if not idx_train: |
| 433 logger.info("No rows with split=0; nothing to do.") | 428 logger.info("No rows with split=0; nothing to do.") |
| 434 return out | 429 return out |
| 435 | |
| 436 # Always use stratify if possible | 430 # Always use stratify if possible |
| 437 stratify_arr = None | 431 stratify_arr = None |
| 438 if label_column and label_column in out.columns: | 432 if label_column and label_column in out.columns: |
| 439 label_counts = out.loc[idx_train, label_column].value_counts() | 433 label_counts = out.loc[idx_train, label_column].value_counts() |
| 440 if label_counts.size > 1: | 434 if label_counts.size > 1: |
| 448 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") | 442 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") |
| 449 stratify_arr = out.loc[idx_train, label_column] | 443 stratify_arr = out.loc[idx_train, label_column] |
| 450 logger.info("Using stratified split for validation set") | 444 logger.info("Using stratified split for validation set") |
| 451 else: | 445 else: |
| 452 logger.warning("Only one label class found; cannot stratify") | 446 logger.warning("Only one label class found; cannot stratify") |
| 453 | |
| 454 if validation_size <= 0: | 447 if validation_size <= 0: |
| 455 logger.info("validation_size <= 0; keeping all as train.") | 448 logger.info("validation_size <= 0; keeping all as train.") |
| 456 return out | 449 return out |
| 457 if validation_size >= 1: | 450 if validation_size >= 1: |
| 458 logger.info("validation_size >= 1; moving all train → validation.") | 451 logger.info("validation_size >= 1; moving all train → validation.") |
| 459 out.loc[idx_train, split_column] = 1 | 452 out.loc[idx_train, split_column] = 1 |
| 460 return out | 453 return out |
| 461 | |
| 462 # Always try stratified split first | 454 # Always try stratified split first |
| 463 try: | 455 try: |
| 464 train_idx, val_idx = train_test_split( | 456 train_idx, val_idx = train_test_split( |
| 465 idx_train, | 457 idx_train, |
| 466 test_size=validation_size, | 458 test_size=validation_size, |
| 474 idx_train, | 466 idx_train, |
| 475 test_size=validation_size, | 467 test_size=validation_size, |
| 476 random_state=random_state, | 468 random_state=random_state, |
| 477 stratify=None, | 469 stratify=None, |
| 478 ) | 470 ) |
| 479 | |
| 480 out.loc[train_idx, split_column] = 0 | 471 out.loc[train_idx, split_column] = 0 |
| 481 out.loc[val_idx, split_column] = 1 | 472 out.loc[val_idx, split_column] = 1 |
| 482 out[split_column] = out[split_column].astype(int) | 473 out[split_column] = out[split_column].astype(int) |
| 483 return out | 474 return out |
| 484 | 475 |
| 490 random_state: int = 42, | 481 random_state: int = 42, |
| 491 label_column: Optional[str] = None, | 482 label_column: Optional[str] = None, |
| 492 ) -> pd.DataFrame: | 483 ) -> pd.DataFrame: |
| 493 """Create a stratified random split when no split column exists.""" | 484 """Create a stratified random split when no split column exists.""" |
| 494 out = df.copy() | 485 out = df.copy() |
| 495 | |
| 496 # initialize split column | 486 # initialize split column |
| 497 out[split_column] = 0 | 487 out[split_column] = 0 |
| 498 | |
| 499 if not label_column or label_column not in out.columns: | 488 if not label_column or label_column not in out.columns: |
| 500 logger.warning("No label column found; using random split without stratification") | 489 logger.warning("No label column found; using random split without stratification") |
| 501 # fall back to simple random assignment | 490 # fall back to simple random assignment |
| 502 indices = out.index.tolist() | 491 indices = out.index.tolist() |
| 503 np.random.seed(random_state) | 492 np.random.seed(random_state) |
| 504 np.random.shuffle(indices) | 493 np.random.shuffle(indices) |
| 505 | |
| 506 n_total = len(indices) | 494 n_total = len(indices) |
| 507 n_train = int(n_total * split_probabilities[0]) | 495 n_train = int(n_total * split_probabilities[0]) |
| 508 n_val = int(n_total * split_probabilities[1]) | 496 n_val = int(n_total * split_probabilities[1]) |
| 509 | |
| 510 out.loc[indices[:n_train], split_column] = 0 | 497 out.loc[indices[:n_train], split_column] = 0 |
| 511 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | 498 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
| 512 out.loc[indices[n_train + n_val:], split_column] = 2 | 499 out.loc[indices[n_train + n_val:], split_column] = 2 |
| 513 | |
| 514 return out.astype({split_column: int}) | 500 return out.astype({split_column: int}) |
| 515 | |
| 516 # check if stratification is possible | 501 # check if stratification is possible |
| 517 label_counts = out[label_column].value_counts() | 502 label_counts = out[label_column].value_counts() |
| 518 min_samples_per_class = label_counts.min() | 503 min_samples_per_class = label_counts.min() |
| 519 | |
| 520 # ensure we have enough samples for stratification: | 504 # ensure we have enough samples for stratification: |
| 521 # Each class must have at least as many samples as the number of splits, | 505 # Each class must have at least as many samples as the number of splits, |
| 522 # so that each split can receive at least one sample per class. | 506 # so that each split can receive at least one sample per class. |
| 523 min_samples_required = len(split_probabilities) | 507 min_samples_required = len(split_probabilities) |
| 524 if min_samples_per_class < min_samples_required: | 508 if min_samples_per_class < min_samples_required: |
| 527 ) | 511 ) |
| 528 # fall back to simple random assignment | 512 # fall back to simple random assignment |
| 529 indices = out.index.tolist() | 513 indices = out.index.tolist() |
| 530 np.random.seed(random_state) | 514 np.random.seed(random_state) |
| 531 np.random.shuffle(indices) | 515 np.random.shuffle(indices) |
| 532 | |
| 533 n_total = len(indices) | 516 n_total = len(indices) |
| 534 n_train = int(n_total * split_probabilities[0]) | 517 n_train = int(n_total * split_probabilities[0]) |
| 535 n_val = int(n_total * split_probabilities[1]) | 518 n_val = int(n_total * split_probabilities[1]) |
| 536 | |
| 537 out.loc[indices[:n_train], split_column] = 0 | 519 out.loc[indices[:n_train], split_column] = 0 |
| 538 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | 520 out.loc[indices[n_train:n_train + n_val], split_column] = 1 |
| 539 out.loc[indices[n_train + n_val:], split_column] = 2 | 521 out.loc[indices[n_train + n_val:], split_column] = 2 |
| 540 | |
| 541 return out.astype({split_column: int}) | 522 return out.astype({split_column: int}) |
| 542 | |
| 543 logger.info("Using stratified random split for train/validation/test sets") | 523 logger.info("Using stratified random split for train/validation/test sets") |
| 544 | |
| 545 # first split: separate test set | 524 # first split: separate test set |
| 546 train_val_idx, test_idx = train_test_split( | 525 train_val_idx, test_idx = train_test_split( |
| 547 out.index.tolist(), | 526 out.index.tolist(), |
| 548 test_size=split_probabilities[2], | 527 test_size=split_probabilities[2], |
| 549 random_state=random_state, | 528 random_state=random_state, |
| 550 stratify=out[label_column], | 529 stratify=out[label_column], |
| 551 ) | 530 ) |
| 552 | |
| 553 # second split: separate training and validation from remaining data | 531 # second split: separate training and validation from remaining data |
| 554 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) | 532 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) |
| 555 train_idx, val_idx = train_test_split( | 533 train_idx, val_idx = train_test_split( |
| 556 train_val_idx, | 534 train_val_idx, |
| 557 test_size=val_size_adjusted, | 535 test_size=val_size_adjusted, |
| 558 random_state=random_state, | 536 random_state=random_state, |
| 559 stratify=out.loc[train_val_idx, label_column], | 537 stratify=out.loc[train_val_idx, label_column], |
| 560 ) | 538 ) |
| 561 | |
| 562 # assign split values | 539 # assign split values |
| 563 out.loc[train_idx, split_column] = 0 | 540 out.loc[train_idx, split_column] = 0 |
| 564 out.loc[val_idx, split_column] = 1 | 541 out.loc[val_idx, split_column] = 1 |
| 565 out.loc[test_idx, split_column] = 2 | 542 out.loc[test_idx, split_column] = 2 |
| 566 | |
| 567 logger.info("Successfully applied stratified random split") | 543 logger.info("Successfully applied stratified random split") |
| 568 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") | 544 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") |
| 569 | |
| 570 return out.astype({split_column: int}) | 545 return out.astype({split_column: int}) |
| 571 | 546 |
| 572 | 547 |
| 573 class Backend(Protocol): | 548 class Backend(Protocol): |
| 574 """Interface for a machine learning backend.""" | 549 """Interface for a machine learning backend.""" |
| 575 | |
| 576 def prepare_config( | 550 def prepare_config( |
| 577 self, | 551 self, |
| 578 config_params: Dict[str, Any], | 552 config_params: Dict[str, Any], |
| 579 split_config: Dict[str, Any], | 553 split_config: Dict[str, Any], |
| 580 ) -> str: | 554 ) -> str: |
| 602 ... | 576 ... |
| 603 | 577 |
| 604 | 578 |
| 605 class LudwigDirectBackend: | 579 class LudwigDirectBackend: |
| 606 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" | 580 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
| 607 | |
| 608 def prepare_config( | 581 def prepare_config( |
| 609 self, | 582 self, |
| 610 config_params: Dict[str, Any], | 583 config_params: Dict[str, Any], |
| 611 split_config: Dict[str, Any], | 584 split_config: Dict[str, Any], |
| 612 ) -> str: | 585 ) -> str: |
| 613 logger.info("LudwigDirectBackend: Preparing YAML configuration.") | 586 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
| 614 | |
| 615 model_name = config_params.get("model_name", "resnet18") | 587 model_name = config_params.get("model_name", "resnet18") |
| 616 use_pretrained = config_params.get("use_pretrained", False) | 588 use_pretrained = config_params.get("use_pretrained", False) |
| 617 fine_tune = config_params.get("fine_tune", False) | 589 fine_tune = config_params.get("fine_tune", False) |
| 618 if use_pretrained: | 590 if use_pretrained: |
| 619 trainable = bool(fine_tune) | 591 trainable = bool(fine_tune) |
| 632 "use_pretrained": use_pretrained, | 604 "use_pretrained": use_pretrained, |
| 633 "trainable": trainable, | 605 "trainable": trainable, |
| 634 } | 606 } |
| 635 else: | 607 else: |
| 636 encoder_config = {"type": raw_encoder} | 608 encoder_config = {"type": raw_encoder} |
| 637 | |
| 638 batch_size_cfg = batch_size or "auto" | 609 batch_size_cfg = batch_size or "auto" |
| 639 | |
| 640 label_column_path = config_params.get("label_column_data_path") | 610 label_column_path = config_params.get("label_column_data_path") |
| 641 label_series = None | 611 label_series = None |
| 642 if label_column_path is not None and Path(label_column_path).exists(): | 612 if label_column_path is not None and Path(label_column_path).exists(): |
| 643 try: | 613 try: |
| 644 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] | 614 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] |
| 645 except Exception as e: | 615 except Exception as e: |
| 646 logger.warning(f"Could not read label column for task detection: {e}") | 616 logger.warning(f"Could not read label column for task detection: {e}") |
| 647 | |
| 648 if ( | 617 if ( |
| 649 label_series is not None | 618 label_series is not None |
| 650 and ptypes.is_numeric_dtype(label_series.dtype) | 619 and ptypes.is_numeric_dtype(label_series.dtype) |
| 651 and label_series.nunique() > 10 | 620 and label_series.nunique() > 10 |
| 652 ): | 621 ): |
| 653 task_type = "regression" | 622 task_type = "regression" |
| 654 else: | 623 else: |
| 655 task_type = "classification" | 624 task_type = "classification" |
| 656 | |
| 657 config_params["task_type"] = task_type | 625 config_params["task_type"] = task_type |
| 658 | |
| 659 image_feat: Dict[str, Any] = { | 626 image_feat: Dict[str, Any] = { |
| 660 "name": IMAGE_PATH_COLUMN_NAME, | 627 "name": IMAGE_PATH_COLUMN_NAME, |
| 661 "type": "image", | 628 "type": "image", |
| 662 "encoder": encoder_config, | 629 "encoder": encoder_config, |
| 663 } | 630 } |
| 664 if config_params.get("augmentation") is not None: | 631 if config_params.get("augmentation") is not None: |
| 665 image_feat["augmentation"] = config_params["augmentation"] | 632 image_feat["augmentation"] = config_params["augmentation"] |
| 666 | |
| 667 if task_type == "regression": | 633 if task_type == "regression": |
| 668 output_feat = { | 634 output_feat = { |
| 669 "name": LABEL_COLUMN_NAME, | 635 "name": LABEL_COLUMN_NAME, |
| 670 "type": "number", | 636 "type": "number", |
| 671 "decoder": {"type": "regressor"}, | 637 "decoder": {"type": "regressor"}, |
| 677 "r2", | 643 "r2", |
| 678 ] | 644 ] |
| 679 }, | 645 }, |
| 680 } | 646 } |
| 681 val_metric = config_params.get("validation_metric", "mean_squared_error") | 647 val_metric = config_params.get("validation_metric", "mean_squared_error") |
| 682 | |
| 683 else: | 648 else: |
| 684 num_unique_labels = ( | 649 num_unique_labels = ( |
| 685 label_series.nunique() if label_series is not None else 2 | 650 label_series.nunique() if label_series is not None else 2 |
| 686 ) | 651 ) |
| 687 output_type = "binary" if num_unique_labels == 2 else "category" | 652 output_type = "binary" if num_unique_labels == 2 else "category" |
| 688 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} | 653 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} |
| 654 if output_type == "binary" and config_params.get("threshold") is not None: | |
| 655 output_feat["threshold"] = float(config_params["threshold"]) | |
| 689 val_metric = None | 656 val_metric = None |
| 690 | |
| 691 conf: Dict[str, Any] = { | 657 conf: Dict[str, Any] = { |
| 692 "model_type": "ecd", | 658 "model_type": "ecd", |
| 693 "input_features": [image_feat], | 659 "input_features": [image_feat], |
| 694 "output_features": [output_feat], | 660 "output_features": [output_feat], |
| 695 "combiner": {"type": "concat"}, | 661 "combiner": {"type": "concat"}, |
| 705 "split": split_config, | 671 "split": split_config, |
| 706 "num_processes": num_processes, | 672 "num_processes": num_processes, |
| 707 "in_memory": False, | 673 "in_memory": False, |
| 708 }, | 674 }, |
| 709 } | 675 } |
| 710 | |
| 711 logger.debug("LudwigDirectBackend: Config dict built.") | 676 logger.debug("LudwigDirectBackend: Config dict built.") |
| 712 try: | 677 try: |
| 713 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) | 678 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) |
| 714 logger.info("LudwigDirectBackend: YAML config generated.") | 679 logger.info("LudwigDirectBackend: YAML config generated.") |
| 715 return yaml_str | 680 return yaml_str |
| 727 output_dir: Path, | 692 output_dir: Path, |
| 728 random_seed: int = 42, | 693 random_seed: int = 42, |
| 729 ) -> None: | 694 ) -> None: |
| 730 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" | 695 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" |
| 731 logger.info("LudwigDirectBackend: Starting experiment execution.") | 696 logger.info("LudwigDirectBackend: Starting experiment execution.") |
| 732 | |
| 733 try: | 697 try: |
| 734 from ludwig.experiment import experiment_cli | 698 from ludwig.experiment import experiment_cli |
| 735 except ImportError as e: | 699 except ImportError as e: |
| 736 logger.error( | 700 logger.error( |
| 737 "LudwigDirectBackend: Could not import experiment_cli.", | 701 "LudwigDirectBackend: Could not import experiment_cli.", |
| 738 exc_info=True, | 702 exc_info=True, |
| 739 ) | 703 ) |
| 740 raise RuntimeError("Ludwig import failed.") from e | 704 raise RuntimeError("Ludwig import failed.") from e |
| 741 | |
| 742 output_dir.mkdir(parents=True, exist_ok=True) | 705 output_dir.mkdir(parents=True, exist_ok=True) |
| 743 | |
| 744 try: | 706 try: |
| 745 experiment_cli( | 707 experiment_cli( |
| 746 dataset=str(dataset_path), | 708 dataset=str(dataset_path), |
| 747 config=str(config_path), | 709 config=str(config_path), |
| 748 output_directory=str(output_dir), | 710 output_directory=str(output_dir), |
| 769 output_dir = Path(output_dir) | 731 output_dir = Path(output_dir) |
| 770 exp_dirs = sorted( | 732 exp_dirs = sorted( |
| 771 output_dir.glob("experiment_run*"), | 733 output_dir.glob("experiment_run*"), |
| 772 key=lambda p: p.stat().st_mtime, | 734 key=lambda p: p.stat().st_mtime, |
| 773 ) | 735 ) |
| 774 | |
| 775 if not exp_dirs: | 736 if not exp_dirs: |
| 776 logger.warning(f"No experiment run directories found in {output_dir}") | 737 logger.warning(f"No experiment run directories found in {output_dir}") |
| 777 return None | 738 return None |
| 778 | |
| 779 progress_file = exp_dirs[-1] / "model" / "training_progress.json" | 739 progress_file = exp_dirs[-1] / "model" / "training_progress.json" |
| 780 if not progress_file.exists(): | 740 if not progress_file.exists(): |
| 781 logger.warning(f"No training_progress.json found in {progress_file}") | 741 logger.warning(f"No training_progress.json found in {progress_file}") |
| 782 return None | 742 return None |
| 783 | |
| 784 try: | 743 try: |
| 785 with progress_file.open("r", encoding="utf-8") as f: | 744 with progress_file.open("r", encoding="utf-8") as f: |
| 786 data = json.load(f) | 745 data = json.load(f) |
| 787 return { | 746 return { |
| 788 "learning_rate": data.get("learning_rate"), | 747 "learning_rate": data.get("learning_rate"), |
| 814 logger.error(f"Error converting Parquet to CSV: {e}") | 773 logger.error(f"Error converting Parquet to CSV: {e}") |
| 815 | 774 |
| 816 def generate_plots(self, output_dir: Path) -> None: | 775 def generate_plots(self, output_dir: Path) -> None: |
| 817 """Generate all registered Ludwig visualizations for the latest experiment run.""" | 776 """Generate all registered Ludwig visualizations for the latest experiment run.""" |
| 818 logger.info("Generating all Ludwig visualizations…") | 777 logger.info("Generating all Ludwig visualizations…") |
| 819 | |
| 820 test_plots = { | 778 test_plots = { |
| 821 "compare_performance", | 779 "compare_performance", |
| 822 "compare_classifiers_performance_from_prob", | 780 "compare_classifiers_performance_from_prob", |
| 823 "compare_classifiers_performance_from_pred", | 781 "compare_classifiers_performance_from_pred", |
| 824 "compare_classifiers_performance_changing_k", | 782 "compare_classifiers_performance_changing_k", |
| 838 } | 796 } |
| 839 train_plots = { | 797 train_plots = { |
| 840 "learning_curves", | 798 "learning_curves", |
| 841 "compare_classifiers_performance_subset", | 799 "compare_classifiers_performance_subset", |
| 842 } | 800 } |
| 843 | |
| 844 output_dir = Path(output_dir) | 801 output_dir = Path(output_dir) |
| 845 exp_dirs = sorted( | 802 exp_dirs = sorted( |
| 846 output_dir.glob("experiment_run*"), | 803 output_dir.glob("experiment_run*"), |
| 847 key=lambda p: p.stat().st_mtime, | 804 key=lambda p: p.stat().st_mtime, |
| 848 ) | 805 ) |
| 849 if not exp_dirs: | 806 if not exp_dirs: |
| 850 logger.warning(f"No experiment run dirs found in {output_dir}") | 807 logger.warning(f"No experiment run dirs found in {output_dir}") |
| 851 return | 808 return |
| 852 exp_dir = exp_dirs[-1] | 809 exp_dir = exp_dirs[-1] |
| 853 | |
| 854 viz_dir = exp_dir / "visualizations" | 810 viz_dir = exp_dir / "visualizations" |
| 855 viz_dir.mkdir(exist_ok=True) | 811 viz_dir.mkdir(exist_ok=True) |
| 856 train_viz = viz_dir / "train" | 812 train_viz = viz_dir / "train" |
| 857 test_viz = viz_dir / "test" | 813 test_viz = viz_dir / "test" |
| 858 train_viz.mkdir(parents=True, exist_ok=True) | 814 train_viz.mkdir(parents=True, exist_ok=True) |
| 863 | 819 |
| 864 training_stats = _check(exp_dir / "training_statistics.json") | 820 training_stats = _check(exp_dir / "training_statistics.json") |
| 865 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) | 821 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) |
| 866 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) | 822 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) |
| 867 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) | 823 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) |
| 868 | |
| 869 dataset_path = None | 824 dataset_path = None |
| 870 split_file = None | 825 split_file = None |
| 871 desc = exp_dir / DESCRIPTION_FILE_NAME | 826 desc = exp_dir / DESCRIPTION_FILE_NAME |
| 872 if desc.exists(): | 827 if desc.exists(): |
| 873 with open(desc, "r") as f: | 828 with open(desc, "r") as f: |
| 874 cfg = json.load(f) | 829 cfg = json.load(f) |
| 875 dataset_path = _check(Path(cfg.get("dataset", ""))) | 830 dataset_path = _check(Path(cfg.get("dataset", ""))) |
| 876 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) | 831 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) |
| 877 | |
| 878 output_feature = "" | 832 output_feature = "" |
| 879 if desc.exists(): | 833 if desc.exists(): |
| 880 try: | 834 try: |
| 881 output_feature = cfg["config"]["output_features"][0]["name"] | 835 output_feature = cfg["config"]["output_features"][0]["name"] |
| 882 except Exception: | 836 except Exception: |
| 883 pass | 837 pass |
| 884 if not output_feature and test_stats: | 838 if not output_feature and test_stats: |
| 885 with open(test_stats, "r") as f: | 839 with open(test_stats, "r") as f: |
| 886 stats = json.load(f) | 840 stats = json.load(f) |
| 887 output_feature = next(iter(stats.keys()), "") | 841 output_feature = next(iter(stats.keys()), "") |
| 888 | |
| 889 viz_registry = get_visualizations_registry() | 842 viz_registry = get_visualizations_registry() |
| 890 for viz_name, viz_func in viz_registry.items(): | 843 for viz_name, viz_func in viz_registry.items(): |
| 891 if viz_name in train_plots: | 844 if viz_name in train_plots: |
| 892 viz_dir_plot = train_viz | 845 viz_dir_plot = train_viz |
| 893 elif viz_name in test_plots: | 846 elif viz_name in test_plots: |
| 894 viz_dir_plot = test_viz | 847 viz_dir_plot = test_viz |
| 895 else: | 848 else: |
| 896 continue | 849 continue |
| 897 | |
| 898 try: | 850 try: |
| 899 viz_func( | 851 viz_func( |
| 900 training_statistics=[training_stats] if training_stats else [], | 852 training_statistics=[training_stats] if training_stats else [], |
| 901 test_statistics=[test_stats] if test_stats else [], | 853 test_statistics=[test_stats] if test_stats else [], |
| 902 probabilities=[probs_path] if probs_path else [], | 854 probabilities=[probs_path] if probs_path else [], |
| 912 file_format="png", | 864 file_format="png", |
| 913 ) | 865 ) |
| 914 logger.info(f"✔ Generated {viz_name}") | 866 logger.info(f"✔ Generated {viz_name}") |
| 915 except Exception as e: | 867 except Exception as e: |
| 916 logger.warning(f"✘ Skipped {viz_name}: {e}") | 868 logger.warning(f"✘ Skipped {viz_name}: {e}") |
| 917 | |
| 918 logger.info(f"All visualizations written to {viz_dir}") | 869 logger.info(f"All visualizations written to {viz_dir}") |
| 919 | 870 |
| 920 def generate_html_report( | 871 def generate_html_report( |
| 921 self, | 872 self, |
| 922 title: str, | 873 title: str, |
| 928 cwd = Path.cwd() | 879 cwd = Path.cwd() |
| 929 report_name = title.lower().replace(" ", "_") + "_report.html" | 880 report_name = title.lower().replace(" ", "_") + "_report.html" |
| 930 report_path = cwd / report_name | 881 report_path = cwd / report_name |
| 931 output_dir = Path(output_dir) | 882 output_dir = Path(output_dir) |
| 932 output_type = None | 883 output_type = None |
| 933 | |
| 934 exp_dirs = sorted( | 884 exp_dirs = sorted( |
| 935 output_dir.glob("experiment_run*"), | 885 output_dir.glob("experiment_run*"), |
| 936 key=lambda p: p.stat().st_mtime, | 886 key=lambda p: p.stat().st_mtime, |
| 937 ) | 887 ) |
| 938 if not exp_dirs: | 888 if not exp_dirs: |
| 939 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") | 889 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
| 940 exp_dir = exp_dirs[-1] | 890 exp_dir = exp_dirs[-1] |
| 941 | |
| 942 base_viz_dir = exp_dir / "visualizations" | 891 base_viz_dir = exp_dir / "visualizations" |
| 943 train_viz_dir = base_viz_dir / "train" | 892 train_viz_dir = base_viz_dir / "train" |
| 944 test_viz_dir = base_viz_dir / "test" | 893 test_viz_dir = base_viz_dir / "test" |
| 945 | |
| 946 html = get_html_template() | 894 html = get_html_template() |
| 947 html += f"<h1>{title}</h1>" | 895 html += f"<h1>{title}</h1>" |
| 948 | |
| 949 metrics_html = "" | 896 metrics_html = "" |
| 950 train_val_metrics_html = "" | 897 train_val_metrics_html = "" |
| 951 test_metrics_html = "" | 898 test_metrics_html = "" |
| 952 try: | 899 try: |
| 953 train_stats_path = exp_dir / "training_statistics.json" | 900 train_stats_path = exp_dir / "training_statistics.json" |
| 969 ) | 916 ) |
| 970 except Exception as e: | 917 except Exception as e: |
| 971 logger.warning( | 918 logger.warning( |
| 972 f"Could not load stats for HTML report: {type(e).__name__}: {e}" | 919 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
| 973 ) | 920 ) |
| 974 | |
| 975 config_html = "" | 921 config_html = "" |
| 976 training_progress = self.get_training_process(output_dir) | 922 training_progress = self.get_training_process(output_dir) |
| 977 try: | 923 try: |
| 978 config_html = format_config_table_html( | 924 config_html = format_config_table_html( |
| 979 config, split_info, training_progress | 925 config, split_info, training_progress |
def render_img_section(
    title: str, dir_path: Path, output_type: Optional[str] = None
) -> str:
    """Render an HTML section embedding every relevant PNG under *dir_path*.

    Each plot is inlined as a base64 ``<img>`` with a title derived from the
    file stem. Ludwig's base confusion matrix, its top-N variants, and the
    raw ROC overlay are always excluded (interactive/Plotly versions are
    rendered elsewhere in the report).

    Args:
        title: Section heading used in the "not found" / "no plots" fallbacks.
        dir_path: Directory searched (non-recursively) for ``*.png`` files.
        output_type: ``"binary"`` or ``"category"`` selects a curated display
            order; anything else (e.g. regression) falls back to sorted order.

    Returns:
        An HTML fragment string (never raises on missing/empty directories).
    """
    if not dir_path.exists():
        return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
    # Collect every PNG, excluding Ludwig's base confusion matrix, any
    # top-N confusion-matrix files, and the static ROC overlay.
    imgs = [
        img
        for img in dir_path.glob("*.png")
        if not (
            img.name == "confusion_matrix.png"
            or img.name.startswith("confusion_matrix__label_top")
            or img.name == "roc_curves.png"
        )
    ]
    if not imgs:
        return f"<h2>{title}</h2><p><em>No plots found.</em></p>"

    def _prioritize(candidates, priority):
        # Names listed in *priority* come first (in that order, when present),
        # followed by the remaining images in sorted path order.
        by_name = {img.name: img for img in candidates}
        ordered = [by_name[name] for name in priority if name in by_name]
        rest = sorted(img for img in candidates if img.name not in priority)
        return ordered + rest

    if output_type == "binary":
        imgs = _prioritize(
            imgs,
            [
                "roc_curves_from_prediction_statistics.png",
                "compare_performance_label.png",
                "confusion_matrix_entropy__label_top2.png",
                # ...you can tweak ordering as needed
            ],
        )
    elif output_type == "category":
        # Drop the per-class multimetric breakdowns entirely.
        unwanted = {
            "compare_classifiers_multiclass_multimetric__label_best10.png",
            "compare_classifiers_multiclass_multimetric__label_top10.png",
            "compare_classifiers_multiclass_multimetric__label_worst10.png",
        }
        display_order = [
            # NOTE(review): "roc_curves.png" is excluded by the filter above,
            # so this entry never matches — confirm whether it should be
            # removed or the exclusion relaxed for the category case.
            "roc_curves.png",
            "compare_performance_label.png",
            "compare_classifiers_performance_from_prob.png",
            "confusion_matrix_entropy__label_top10.png",
        ]
        imgs = _prioritize(
            [img for img in imgs if img.name not in unwanted],
            display_order,
        )
    else:
        # Regression (or unknown output type): plain sorted order.
        imgs = sorted(imgs)

    # Render each remaining PNG as a titled, centered, inline image.
    html = ""
    for img in imgs:
        b64 = encode_image_to_base64(str(img))
        img_title = img.stem.replace("_", " ").title()
        html += (
            f"<h2 style='text-align: center;'>{img_title}</h2>"
            f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
            f'<img src="data:image/png;base64,{b64}" '
            f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
            f"</div>"
        )
    return html
| 1060 | 993 |
| 1061 tab1_content = config_html + metrics_html | 994 tab1_content = config_html + metrics_html |
| 1062 | |
| 1063 tab2_content = train_val_metrics_html + render_img_section( | 995 tab2_content = train_val_metrics_html + render_img_section( |
| 1064 "Training & Validation Visualizations", train_viz_dir | 996 "Training and Validation Visualizations", train_viz_dir |
| 1065 ) | 997 ) |
| 1066 | |
| 1067 # --- Predictions vs Ground Truth table --- | 998 # --- Predictions vs Ground Truth table --- |
| 1068 preds_section = "" | 999 preds_section = "" |
| 1069 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME | 1000 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
| 1070 if parquet_path.exists(): | 1001 if output_type == "regression" and parquet_path.exists(): |
| 1071 try: | 1002 try: |
| 1072 # 1) load predictions from Parquet | 1003 # 1) load predictions from Parquet |
| 1073 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) | 1004 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) |
| 1074 # assume the column containing your model's prediction is named "prediction" | 1005 # assume the column containing your model's prediction is named "prediction" |
| 1075 # or contains that substring: | |
| 1076 pred_col = next( | 1006 pred_col = next( |
| 1077 (c for c in df_preds.columns if "prediction" in c.lower()), | 1007 (c for c in df_preds.columns if "prediction" in c.lower()), |
| 1078 None, | 1008 None, |
| 1079 ) | 1009 ) |
| 1080 if pred_col is None: | 1010 if pred_col is None: |
| 1081 raise ValueError("No prediction column found in Parquet output") | 1011 raise ValueError("No prediction column found in Parquet output") |
| 1082 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) | 1012 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) |
| 1083 | |
| 1084 # 2) load ground truth for the test split from prepared CSV | 1013 # 2) load ground truth for the test split from prepared CSV |
| 1085 df_all = pd.read_csv(config["label_column_data_path"]) | 1014 df_all = pd.read_csv(config["label_column_data_path"]) |
| 1086 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ | 1015 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) |
| 1087 LABEL_COLUMN_NAME | 1016 # 3) concatenate side-by-side |
| 1088 ].reset_index(drop=True) | |
| 1089 | |
| 1090 # 3) concatenate side‐by‐side | |
| 1091 df_table = pd.concat([df_gt, df_pred], axis=1) | 1017 df_table = pd.concat([df_gt, df_pred], axis=1) |
| 1092 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] | 1018 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] |
| 1093 | |
| 1094 # 4) render as HTML | 1019 # 4) render as HTML |
| 1095 preds_html = df_table.to_html(index=False, classes="predictions-table") | 1020 preds_html = df_table.to_html(index=False, classes="predictions-table") |
| 1096 preds_section = ( | 1021 preds_section = ( |
| 1097 "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>" | 1022 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" |
| 1098 "<div style='overflow-x:auto; margin-bottom:20px;'>" | 1023 "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>" |
| 1099 + preds_html | 1024 + preds_html |
| 1100 + "</div>" | 1025 + "</div>" |
| 1101 ) | 1026 ) |
| 1102 except Exception as e: | 1027 except Exception as e: |
| 1103 logger.warning(f"Could not build Predictions vs GT table: {e}") | 1028 logger.warning(f"Could not build Predictions vs GT table: {e}") |
| 1104 # Test tab = Metrics + Preds table + Visualizations | 1029 tab3_content = test_metrics_html + preds_section |
| 1105 | 1030 if output_type in ("binary", "category"): |
| 1106 tab3_content = ( | 1031 training_stats_path = exp_dir / "training_statistics.json" |
| 1107 test_metrics_html | 1032 interactive_plots = build_classification_plots( |
| 1108 + preds_section | 1033 str(test_stats_path), |
| 1109 + render_img_section("Test Visualizations", test_viz_dir, output_type) | 1034 str(training_stats_path), |
| 1110 ) | 1035 ) |
| 1111 | 1036 for plot in interactive_plots: |
| 1037 # 2) inject the static "roc_curves_from_prediction_statistics.png" | |
| 1038 if plot["title"] == "ROC-AUC": | |
| 1039 static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" | |
| 1040 if static_img.exists(): | |
| 1041 b64 = encode_image_to_base64(str(static_img)) | |
| 1042 tab3_content += ( | |
| 1043 "<h2 style='text-align: center;'>" | |
| 1044 "Roc Curves From Prediction Statistics" | |
| 1045 "</h2>" | |
| 1046 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' | |
| 1047 f'<img src="data:image/png;base64,{b64}" ' | |
| 1048 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' | |
| 1049 "</div>" | |
| 1050 ) | |
| 1051 # always render the plotly panels exactly as before | |
| 1052 tab3_content += ( | |
| 1053 f"<h2 style='text-align: center;'>{plot['title']}</h2>" | |
| 1054 + plot["html"] | |
| 1055 ) | |
| 1056 tab3_content += render_img_section( | |
| 1057 "Test Visualizations", | |
| 1058 test_viz_dir, | |
| 1059 output_type | |
| 1060 ) | |
| 1112 # assemble the tabs and help modal | 1061 # assemble the tabs and help modal |
| 1113 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1062 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
| 1114 modal_html = get_metrics_help_modal() | 1063 modal_html = get_metrics_help_modal() |
| 1115 html += tabbed_html + modal_html + get_html_closing() | 1064 html += tabbed_html + modal_html + get_html_closing() |
| 1116 | |
| 1117 try: | 1065 try: |
| 1118 with open(report_path, "w") as f: | 1066 with open(report_path, "w") as f: |
| 1119 f.write(html) | 1067 f.write(html) |
| 1120 logger.info(f"HTML report generated at: {report_path}") | 1068 logger.info(f"HTML report generated at: {report_path}") |
| 1121 except Exception as e: | 1069 except Exception as e: |
| 1122 logger.error(f"Failed to write HTML report: {e}") | 1070 logger.error(f"Failed to write HTML report: {e}") |
| 1123 raise | 1071 raise |
| 1124 | |
| 1125 return report_path | 1072 return report_path |
| 1126 | 1073 |
| 1127 | 1074 |
| 1128 class WorkflowOrchestrator: | 1075 class WorkflowOrchestrator: |
| 1129 """Manages the image-classification workflow.""" | 1076 """Manages the image-classification workflow.""" |
| 1130 | |
| 1131 def __init__(self, args: argparse.Namespace, backend: Backend): | 1077 def __init__(self, args: argparse.Namespace, backend: Backend): |
| 1132 self.args = args | 1078 self.args = args |
| 1133 self.backend = backend | 1079 self.backend = backend |
| 1134 self.temp_dir: Optional[Path] = None | 1080 self.temp_dir: Optional[Path] = None |
| 1135 self.image_extract_dir: Optional[Path] = None | 1081 self.image_extract_dir: Optional[Path] = None |
| 1165 | 1111 |
| 1166 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: | 1112 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: |
| 1167 """Load CSV, update image paths, handle splits, and write prepared CSV.""" | 1113 """Load CSV, update image paths, handle splits, and write prepared CSV.""" |
| 1168 if not self.temp_dir or not self.image_extract_dir: | 1114 if not self.temp_dir or not self.image_extract_dir: |
| 1169 raise RuntimeError("Temp dirs not initialized before data prep.") | 1115 raise RuntimeError("Temp dirs not initialized before data prep.") |
| 1170 | |
| 1171 try: | 1116 try: |
| 1172 df = pd.read_csv(self.args.csv_file) | 1117 df = pd.read_csv(self.args.csv_file) |
| 1173 logger.info(f"Loaded CSV: {self.args.csv_file}") | 1118 logger.info(f"Loaded CSV: {self.args.csv_file}") |
| 1174 except Exception: | 1119 except Exception: |
| 1175 logger.error("Error loading CSV file", exc_info=True) | 1120 logger.error("Error loading CSV file", exc_info=True) |
| 1176 raise | 1121 raise |
| 1177 | |
| 1178 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | 1122 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} |
| 1179 missing = required - set(df.columns) | 1123 missing = required - set(df.columns) |
| 1180 if missing: | 1124 if missing: |
| 1181 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | 1125 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") |
| 1182 | |
| 1183 try: | 1126 try: |
| 1184 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | 1127 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( |
| 1185 lambda p: str((self.image_extract_dir / p).resolve()) | 1128 lambda p: str((self.image_extract_dir / p).resolve()) |
| 1186 ) | 1129 ) |
| 1187 except Exception: | 1130 except Exception: |
| 1188 logger.error("Error updating image paths", exc_info=True) | 1131 logger.error("Error updating image paths", exc_info=True) |
| 1189 raise | 1132 raise |
| 1190 | |
| 1191 if SPLIT_COLUMN_NAME in df.columns: | 1133 if SPLIT_COLUMN_NAME in df.columns: |
| 1192 df, split_config, split_info = self._process_fixed_split(df) | 1134 df, split_config, split_info = self._process_fixed_split(df) |
| 1193 else: | 1135 else: |
| 1194 logger.info("No split column; creating stratified random split") | 1136 logger.info("No split column; creating stratified random split") |
| 1195 df = create_stratified_random_split( | 1137 df = create_stratified_random_split( |
| 1206 split_info = ( | 1148 split_info = ( |
| 1207 f"No split column in CSV. Created stratified random split: " | 1149 f"No split column in CSV. Created stratified random split: " |
| 1208 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1150 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
| 1209 f"for train/val/test with balanced label distribution." | 1151 f"for train/val/test with balanced label distribution." |
| 1210 ) | 1152 ) |
| 1211 | |
| 1212 final_csv = self.temp_dir / TEMP_CSV_FILENAME | 1153 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
| 1213 try: | 1154 try: |
| 1214 | |
| 1215 df.to_csv(final_csv, index=False) | 1155 df.to_csv(final_csv, index=False) |
| 1216 logger.info(f"Saved prepared data to {final_csv}") | 1156 logger.info(f"Saved prepared data to {final_csv}") |
| 1217 except Exception: | 1157 except Exception: |
| 1218 logger.error("Error saving prepared CSV", exc_info=True) | 1158 logger.error("Error saving prepared CSV", exc_info=True) |
| 1219 raise | 1159 raise |
| 1220 | |
| 1221 return final_csv, split_config, split_info | 1160 return final_csv, split_config, split_info |
| 1222 | 1161 |
| 1223 def _process_fixed_split( | 1162 def _process_fixed_split( |
| 1224 self, df: pd.DataFrame | 1163 self, df: pd.DataFrame |
| 1225 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: | 1164 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: |
| 1230 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( | 1169 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( |
| 1231 pd.Int64Dtype() | 1170 pd.Int64Dtype() |
| 1232 ) | 1171 ) |
| 1233 if df[SPLIT_COLUMN_NAME].isna().any(): | 1172 if df[SPLIT_COLUMN_NAME].isna().any(): |
| 1234 logger.warning("Split column contains non-numeric/missing values.") | 1173 logger.warning("Split column contains non-numeric/missing values.") |
| 1235 | |
| 1236 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) | 1174 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) |
| 1237 logger.info(f"Unique split values: {unique}") | 1175 logger.info(f"Unique split values: {unique}") |
| 1238 | |
| 1239 if unique == {0, 2}: | 1176 if unique == {0, 2}: |
| 1240 df = split_data_0_2( | 1177 df = split_data_0_2( |
| 1241 df, | 1178 df, |
| 1242 SPLIT_COLUMN_NAME, | 1179 SPLIT_COLUMN_NAME, |
| 1243 validation_size=self.args.validation_size, | 1180 validation_size=self.args.validation_size, |
| 1254 elif unique.issubset({0, 1, 2}): | 1191 elif unique.issubset({0, 1, 2}): |
| 1255 split_info = "Used user-defined split column from CSV." | 1192 split_info = "Used user-defined split column from CSV." |
| 1256 logger.info("Using fixed split as-is.") | 1193 logger.info("Using fixed split as-is.") |
| 1257 else: | 1194 else: |
| 1258 raise ValueError(f"Unexpected split values: {unique}") | 1195 raise ValueError(f"Unexpected split values: {unique}") |
| 1259 | |
| 1260 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info | 1196 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info |
| 1261 | |
| 1262 except Exception: | 1197 except Exception: |
| 1263 logger.error("Error processing fixed split", exc_info=True) | 1198 logger.error("Error processing fixed split", exc_info=True) |
| 1264 raise | 1199 raise |
| 1265 | 1200 |
| 1266 def _cleanup_temp_dirs(self) -> None: | 1201 def _cleanup_temp_dirs(self) -> None: |
| 1272 | 1207 |
| 1273 def run(self) -> None: | 1208 def run(self) -> None: |
| 1274 """Execute the full workflow end-to-end.""" | 1209 """Execute the full workflow end-to-end.""" |
| 1275 logger.info("Starting workflow...") | 1210 logger.info("Starting workflow...") |
| 1276 self.args.output_dir.mkdir(parents=True, exist_ok=True) | 1211 self.args.output_dir.mkdir(parents=True, exist_ok=True) |
| 1277 | |
| 1278 try: | 1212 try: |
| 1279 self._create_temp_dirs() | 1213 self._create_temp_dirs() |
| 1280 self._extract_images() | 1214 self._extract_images() |
| 1281 csv_path, split_cfg, split_info = self._prepare_data() | 1215 csv_path, split_cfg, split_info = self._prepare_data() |
| 1282 | |
| 1283 use_pretrained = self.args.use_pretrained or self.args.fine_tune | 1216 use_pretrained = self.args.use_pretrained or self.args.fine_tune |
| 1284 | |
| 1285 backend_args = { | 1217 backend_args = { |
| 1286 "model_name": self.args.model_name, | 1218 "model_name": self.args.model_name, |
| 1287 "fine_tune": self.args.fine_tune, | 1219 "fine_tune": self.args.fine_tune, |
| 1288 "use_pretrained": use_pretrained, | 1220 "use_pretrained": use_pretrained, |
| 1289 "epochs": self.args.epochs, | 1221 "epochs": self.args.epochs, |
| 1293 "learning_rate": self.args.learning_rate, | 1225 "learning_rate": self.args.learning_rate, |
| 1294 "random_seed": self.args.random_seed, | 1226 "random_seed": self.args.random_seed, |
| 1295 "early_stop": self.args.early_stop, | 1227 "early_stop": self.args.early_stop, |
| 1296 "label_column_data_path": csv_path, | 1228 "label_column_data_path": csv_path, |
| 1297 "augmentation": self.args.augmentation, | 1229 "augmentation": self.args.augmentation, |
| 1230 "threshold": self.args.threshold, | |
| 1298 } | 1231 } |
| 1299 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 1232 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
| 1300 | |
| 1301 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 1233 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
| 1302 config_file.write_text(yaml_str) | 1234 config_file.write_text(yaml_str) |
| 1303 logger.info(f"Wrote backend config: {config_file}") | 1235 logger.info(f"Wrote backend config: {config_file}") |
| 1304 | |
| 1305 self.backend.run_experiment( | 1236 self.backend.run_experiment( |
| 1306 csv_path, | 1237 csv_path, |
| 1307 config_file, | 1238 config_file, |
| 1308 self.args.output_dir, | 1239 self.args.output_dir, |
| 1309 self.args.random_seed, | 1240 self.args.random_seed, |
| 1347 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, | 1278 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, |
| 1348 } | 1279 } |
| 1349 aug_list = [] | 1280 aug_list = [] |
| 1350 for tok in aug_string.split(","): | 1281 for tok in aug_string.split(","): |
| 1351 key = tok.strip() | 1282 key = tok.strip() |
| 1352 if not key: | |
| 1353 continue | |
| 1354 if key not in mapping: | 1283 if key not in mapping: |
| 1355 valid = ", ".join(mapping.keys()) | 1284 valid = ", ".join(mapping.keys()) |
| 1356 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") | 1285 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") |
| 1357 aug_list.append(mapping[key]) | 1286 aug_list.append(mapping[key]) |
| 1358 return aug_list | 1287 return aug_list |
| 1426 help="Where to write outputs", | 1355 help="Where to write outputs", |
| 1427 ) | 1356 ) |
| 1428 parser.add_argument( | 1357 parser.add_argument( |
| 1429 "--validation-size", | 1358 "--validation-size", |
| 1430 type=float, | 1359 type=float, |
| 1431 default=0.1, | 1360 default=0.15, |
| 1432 help="Fraction for validation (0.0–1.0)", | 1361 help="Fraction for validation (0.0–1.0)", |
| 1433 ) | 1362 ) |
| 1434 parser.add_argument( | 1363 parser.add_argument( |
| 1435 "--preprocessing-num-processes", | 1364 "--preprocessing-num-processes", |
| 1436 type=int, | 1365 type=int, |
| 1470 "random_horizontal_flip, random_vertical_flip, random_rotate, " | 1399 "random_horizontal_flip, random_vertical_flip, random_rotate, " |
| 1471 "random_blur, random_brightness, random_contrast. " | 1400 "random_blur, random_brightness, random_contrast. " |
| 1472 "E.g. --augmentation random_horizontal_flip,random_rotate" | 1401 "E.g. --augmentation random_horizontal_flip,random_rotate" |
| 1473 ), | 1402 ), |
| 1474 ) | 1403 ) |
| 1475 | 1404 parser.add_argument( |
| 1405 "--threshold", | |
| 1406 type=float, | |
| 1407 default=None, | |
| 1408 help=( | |
| 1409 "Decision threshold for binary classification (0.0–1.0). " |
| 1410 "Overrides default 0.5." | |
| 1411 ) | |
| 1412 ) | |
| 1476 args = parser.parse_args() | 1413 args = parser.parse_args() |
| 1477 | |
| 1478 if not 0.0 <= args.validation_size <= 1.0: | 1414 if not 0.0 <= args.validation_size <= 1.0: |
| 1479 parser.error("validation-size must be between 0.0 and 1.0") | 1415 parser.error("validation-size must be between 0.0 and 1.0") |
| 1480 if not args.csv_file.is_file(): | 1416 if not args.csv_file.is_file(): |
| 1481 parser.error(f"CSV not found: {args.csv_file}") | 1417 parser.error(f"CSV not found: {args.csv_file}") |
| 1482 if not args.image_zip.is_file(): | 1418 if not args.image_zip.is_file(): |
| 1485 try: | 1421 try: |
| 1486 augmentation_setup = aug_parse(args.augmentation) | 1422 augmentation_setup = aug_parse(args.augmentation) |
| 1487 setattr(args, "augmentation", augmentation_setup) | 1423 setattr(args, "augmentation", augmentation_setup) |
| 1488 except ValueError as e: | 1424 except ValueError as e: |
| 1489 parser.error(str(e)) | 1425 parser.error(str(e)) |
| 1490 | |
| 1491 backend_instance = LudwigDirectBackend() | 1426 backend_instance = LudwigDirectBackend() |
| 1492 orchestrator = WorkflowOrchestrator(args, backend_instance) | 1427 orchestrator = WorkflowOrchestrator(args, backend_instance) |
| 1493 | |
| 1494 exit_code = 0 | 1428 exit_code = 0 |
| 1495 try: | 1429 try: |
| 1496 orchestrator.run() | 1430 orchestrator.run() |
| 1497 logger.info("Main script finished successfully.") | 1431 logger.info("Main script finished successfully.") |
| 1498 except Exception as e: | 1432 except Exception as e: |
| 1503 | 1437 |
| 1504 | 1438 |
| 1505 if __name__ == "__main__": | 1439 if __name__ == "__main__": |
| 1506 try: | 1440 try: |
| 1507 import ludwig | 1441 import ludwig |
| 1508 | |
| 1509 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") | 1442 logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") |
| 1510 except ImportError: | 1443 except ImportError: |
| 1511 logger.error( | 1444 logger.error( |
| 1512 "Ludwig library not found. Please ensure Ludwig is installed " | 1445 "Ludwig library not found. Please ensure Ludwig is installed " |
| 1513 "('pip install ludwig[image]')" | 1446 "('pip install ludwig[image]')" |
| 1514 ) | 1447 ) |
| 1515 sys.exit(1) | 1448 sys.exit(1) |
| 1516 | |
| 1517 main() | 1449 main() |
