comparison image_learner_cli.py @ 8:85e6f4b2ad18 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 8a42eb9b33df7e1df5ad5153b380e20b910a05b6
author goeckslab
date Thu, 14 Aug 2025 14:53:10 +0000
parents 801a8b6973fb
children
comparison
equal deleted inserted replaced
7:801a8b6973fb 8:85e6f4b2ad18
29 TEST_STATISTICS_FILE_NAME, 29 TEST_STATISTICS_FILE_NAME,
30 TRAIN_SET_METADATA_FILE_NAME, 30 TRAIN_SET_METADATA_FILE_NAME,
31 ) 31 )
32 from ludwig.utils.data_utils import get_split_path 32 from ludwig.utils.data_utils import get_split_path
33 from ludwig.visualize import get_visualizations_registry 33 from ludwig.visualize import get_visualizations_registry
34 from plotly_plots import build_classification_plots
34 from sklearn.model_selection import train_test_split 35 from sklearn.model_selection import train_test_split
35 from utils import ( 36 from utils import (
36 build_tabbed_html, 37 build_tabbed_html,
37 encode_image_to_base64, 38 encode_image_to_base64,
38 get_html_closing, 39 get_html_closing,
50 51
51 def format_config_table_html( 52 def format_config_table_html(
52 config: dict, 53 config: dict,
53 split_info: Optional[str] = None, 54 split_info: Optional[str] = None,
54 training_progress: dict = None, 55 training_progress: dict = None,
56 output_type: Optional[str] = None,
55 ) -> str: 57 ) -> str:
56 display_keys = [ 58 display_keys = [
57 "task_type", 59 "task_type",
58 "model_name", 60 "model_name",
59 "epochs", 61 "epochs",
61 "fine_tune", 63 "fine_tune",
62 "use_pretrained", 64 "use_pretrained",
63 "learning_rate", 65 "learning_rate",
64 "random_seed", 66 "random_seed",
65 "early_stop", 67 "early_stop",
68 "threshold",
66 ] 69 ]
67
68 rows = [] 70 rows = []
69
70 for key in display_keys: 71 for key in display_keys:
71 val = config.get(key, "N/A") 72 val = config.get(key, None)
72 if key == "task_type": 73 if key == "threshold":
73 val = val.title() if isinstance(val, str) else val 74 if output_type != "binary":
74 if key == "batch_size": 75 continue
75 if val is not None: 76 val = val if val is not None else 0.5
76 val = int(val) 77 val_str = f"{val:.2f}"
78 if val == 0.5:
79 val_str += " (default)"
80 else:
81 if key == "task_type":
82 val_str = val.title() if isinstance(val, str) else "N/A"
83 elif key == "batch_size":
84 if val is not None:
85 val_str = int(val)
86 else:
87 if training_progress:
88 resolved_val = training_progress.get("batch_size")
89 val_str = (
90 "Auto-selected batch size by Ludwig:<br>"
91 f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
92 )
93 else:
94 val_str = "auto"
95 elif key == "learning_rate":
96 if val is not None and val != "auto":
97 val_str = f"{val:.6f}"
98 else:
99 if training_progress:
100 resolved_val = training_progress.get("learning_rate")
101 val_str = (
102 "Auto-selected learning rate by Ludwig:<br>"
103 f"<span style='font-size: 0.85em;'>"
104 f"{resolved_val if resolved_val else 'auto'}</span><br>"
105 "<span style='font-size: 0.85em;'>"
106 "Based on model architecture and training setup "
107 "(e.g., fine-tuning).<br>"
108 "</span>"
109 )
110 else:
111 val_str = (
112 "Auto-selected by Ludwig<br>"
113 "<span style='font-size: 0.85em;'>"
114 "Automatically tuned based on architecture and dataset.<br>"
115 "See <a href='https://ludwig.ai/latest/configuration/trainer/"
116 "#trainer-parameters' target='_blank'>"
117 "Ludwig Trainer Parameters</a> for details."
118 "</span>"
119 )
120 elif key == "epochs":
121 if val is None:
122 val_str = "N/A"
123 else:
124 if (
125 training_progress
126 and "epoch" in training_progress
127 and val > training_progress["epoch"]
128 ):
129 val_str = (
130 f"Because of early stopping: the training "
131 f"stopped at epoch {training_progress['epoch']}"
132 )
133 else:
134 val_str = val
77 else: 135 else:
78 if training_progress: 136 val_str = val if val is not None else "N/A"
79 val = "Auto-selected batch size by Ludwig:<br>" 137 if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential
80 resolved_val = training_progress.get("batch_size") 138 continue
81 val += f"<span style='font-size: 0.85em;'>{resolved_val}</span><br>"
82 else:
83 val = "auto"
84 if key == "learning_rate":
85 resolved_val = None
86 if val is None or val == "auto":
87 if training_progress:
88 resolved_val = training_progress.get("learning_rate")
89 val = (
90 "Auto-selected learning rate by Ludwig:<br>"
91 f"<span style='font-size: 0.85em;'>"
92 f"{resolved_val if resolved_val else val}</span><br>"
93 "<span style='font-size: 0.85em;'>"
94 "Based on model architecture and training setup "
95 "(e.g., fine-tuning).<br>"
96 "See <a href='https://ludwig.ai/latest/configuration/trainer/"
97 "#trainer-parameters' target='_blank'>"
98 "Ludwig Trainer Parameters</a> for details."
99 "</span>"
100 )
101 else:
102 val = (
103 "Auto-selected by Ludwig<br>"
104 "<span style='font-size: 0.85em;'>"
105 "Automatically tuned based on architecture and dataset.<br>"
106 "See <a href='https://ludwig.ai/latest/configuration/trainer/"
107 "#trainer-parameters' target='_blank'>"
108 "Ludwig Trainer Parameters</a> for details."
109 "</span>"
110 )
111 else:
112 val = f"{val:.6f}"
113 if key == "epochs":
114 if (
115 training_progress
116 and "epoch" in training_progress
117 and val > training_progress["epoch"]
118 ):
119 val = (
120 f"Because of early stopping: the training "
121 f"stopped at epoch {training_progress['epoch']}"
122 )
123
124 if val is None:
125 continue
126 rows.append( 139 rows.append(
127 f"<tr>" 140 f"<tr>"
128 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" 141 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>"
129 f"{key.replace('_', ' ').title()}</td>" 142 f"{key.replace('_', ' ').title()}</td>"
130 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" 143 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>"
131 f"{val}</td>" 144 f"{val_str}</td>"
132 f"</tr>" 145 f"</tr>"
133 ) 146 )
134
135 aug_cfg = config.get("augmentation") 147 aug_cfg = config.get("augmentation")
136 if aug_cfg: 148 if aug_cfg:
137 types = [str(a.get("type", "")) for a in aug_cfg] 149 types = [str(a.get("type", "")) for a in aug_cfg]
138 aug_val = ", ".join(types) 150 aug_val = ", ".join(types)
139 rows.append( 151 rows.append(
140 "<tr>" 152 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>"
141 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Augmentation</td>" 153 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{aug_val}</td></tr>"
142 "<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" 154 )
143 f"{aug_val}</td>"
144 "</tr>"
145 )
146
147 if split_info: 155 if split_info:
148 rows.append( 156 rows.append(
149 f"<tr>" 157 f"<tr><td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>Data Split</td>"
150 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: left;'>" 158 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>{split_info}</td></tr>"
151 f"Data Split</td>" 159 )
152 f"<td style='padding: 6px 12px; border: 1px solid #ccc; text-align: center;'>" 160 html = f"""
153 f"{split_info}</td>" 161 <h2 style="text-align: center;">Model and Training Summary</h2>
154 f"</tr>" 162 <div style="display: flex; justify-content: center;">
155 ) 163 <table style="border-collapse: collapse; width: 100%; table-layout: fixed;">
156 164 <thead><tr>
157 return ( 165 <th style="padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Parameter</th>
158 "<h2 style='text-align: center;'>Training Setup</h2>" 166 <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th>
159 "<div style='display: flex; justify-content: center;'>" 167 </tr></thead>
160 "<table style='border-collapse: collapse; width: 60%; table-layout: auto;'>" 168 <tbody>
161 "<thead><tr>" 169 {''.join(rows)}
162 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left;'>" 170 </tbody>
163 "Parameter</th>" 171 </table>
164 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center;'>" 172 </div><br>
165 "Value</th>" 173 <p style="text-align: center; font-size: 0.9em;">
166 "</tr></thead><tbody>" + "".join(rows) + "</tbody></table></div><br>" 174 Model trained using <a href="https://ludwig.ai/" target="_blank" rel="noopener noreferrer">Ludwig</a>.
167 "<p style='text-align: center; font-size: 0.9em;'>" 175 <a href="https://ludwig.ai/latest/configuration/" target="_blank" rel="noopener noreferrer">
168 "Model trained using Ludwig.<br>" 176 Ludwig documentation provides detailed information about default model and training parameters
169 "If want to learn more about Ludwig default settings," 177 </a>
170 "please check their <a href='https://ludwig.ai' target='_blank'>" 178 </p><hr>
171 "website(ludwig.ai)</a>." 179 """
172 "</p><hr>" 180 return html
173 )
174 181
175 182
176 def detect_output_type(test_stats): 183 def detect_output_type(test_stats):
177 """Detects if the output type is 'binary' or 'category' based on test statistics.""" 184 """Detects if the output type is 'binary' or 'category' based on test statistics."""
178 label_stats = test_stats.get("label", {}) 185 label_stats = test_stats.get("label", {})
242 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"), 249 "accuracy_micro": get_last_value(label_stats, "accuracy_micro"),
243 "loss": get_last_value(label_stats, "loss"), 250 "loss": get_last_value(label_stats, "loss"),
244 "roc_auc": get_last_value(label_stats, "roc_auc"), 251 "roc_auc": get_last_value(label_stats, "roc_auc"),
245 "hits_at_k": get_last_value(label_stats, "hits_at_k"), 252 "hits_at_k": get_last_value(label_stats, "hits_at_k"),
246 } 253 }
247
248 # Test metrics: dynamic extraction according to exclusions 254 # Test metrics: dynamic extraction according to exclusions
249 test_label_stats = test_stats.get("label", {}) 255 test_label_stats = test_stats.get("label", {})
250 if not test_label_stats: 256 if not test_label_stats:
251 logging.warning("No label statistics found for test split") 257 logging.warning("No label statistics found for test split")
252 else: 258 else:
253 combined_stats = test_stats.get("combined", {}) 259 combined_stats = test_stats.get("combined", {})
254 overall_stats = test_label_stats.get("overall_stats", {}) 260 overall_stats = test_label_stats.get("overall_stats", {})
255
256 # Define exclusions 261 # Define exclusions
257 if output_type == "binary": 262 if output_type == "binary":
258 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} 263 exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
259 else: 264 else:
260 exclude = {"per_class_stats", "confusion_matrix"} 265 exclude = {"per_class_stats", "confusion_matrix"}
261
262 # 1. Get all scalar test_label_stats not excluded 266 # 1. Get all scalar test_label_stats not excluded
263 test_metrics = {} 267 test_metrics = {}
264 for k, v in test_label_stats.items(): 268 for k, v in test_label_stats.items():
265 if k in exclude: 269 if k in exclude:
266 continue 270 continue
267 if k == "overall_stats": 271 if k == "overall_stats":
268 continue 272 continue
269 if isinstance(v, (int, float, str, bool)): 273 if isinstance(v, (int, float, str, bool)):
270 test_metrics[k] = v 274 test_metrics[k] = v
271
272 # 2. Add overall_stats (flattened) 275 # 2. Add overall_stats (flattened)
273 for k, v in overall_stats.items(): 276 for k, v in overall_stats.items():
274 test_metrics[k] = v 277 test_metrics[k] = v
275
276 # 3. Optionally include combined/loss if present and not already 278 # 3. Optionally include combined/loss if present and not already
277 if "loss" in combined_stats and "loss" not in test_metrics: 279 if "loss" in combined_stats and "loss" not in test_metrics:
278 test_metrics["loss"] = combined_stats["loss"] 280 test_metrics["loss"] = combined_stats["loss"]
279
280 metrics["test"] = test_metrics 281 metrics["test"] = test_metrics
281
282 return metrics 282 return metrics
283 283
284 284
285 def generate_table_row(cells, styles): 285 def generate_table_row(cells, styles):
286 """Helper function to generate an HTML table row.""" 286 """Helper function to generate an HTML table row."""
287 return ( 287 return (
288 "<tr>" 288 "<tr>"
289 + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells) 289 + "".join(f"<td style='{styles}'>{cell}</td>" for cell in cells)
290 + "</tr>" 290 + "</tr>"
291 ) 291 )
292
293
294 # -----------------------------------------
295 # 2) MODEL PERFORMANCE (Train/Val/Test) TABLE
296 # -----------------------------------------
292 297
293 298
294 def format_stats_table_html(train_stats: dict, test_stats: dict) -> str: 299 def format_stats_table_html(train_stats: dict, test_stats: dict) -> str:
295 """Formats a combined HTML table for training, validation, and test metrics.""" 300 """Formats a combined HTML table for training, validation, and test metrics."""
296 output_type = detect_output_type(test_stats) 301 output_type = detect_output_type(test_stats)
308 t = all_metrics["training"].get(metric_key) 313 t = all_metrics["training"].get(metric_key)
309 v = all_metrics["validation"].get(metric_key) 314 v = all_metrics["validation"].get(metric_key)
310 te = all_metrics["test"].get(metric_key) 315 te = all_metrics["test"].get(metric_key)
311 if all(x is not None for x in [t, v, te]): 316 if all(x is not None for x in [t, v, te]):
312 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) 317 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
313
314 if not rows: 318 if not rows:
315 return "<table><tr><td>No metric values found.</td></tr></table>" 319 return "<table><tr><td>No metric values found.</td></tr></table>"
316
317 html = ( 320 html = (
318 "<h2 style='text-align: center;'>Model Performance Summary</h2>" 321 "<h2 style='text-align: center;'>Model Performance Summary</h2>"
319 "<div style='display: flex; justify-content: center;'>" 322 "<div style='display: flex; justify-content: center;'>"
320 "<table style='border-collapse: collapse; table-layout: auto;'>" 323 "<table class='performance-summary' style='border-collapse: collapse;'>"
321 "<thead><tr>" 324 "<thead><tr>"
322 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " 325 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
323 "white-space: nowrap;'>Metric</th>" 326 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
324 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " 327 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
325 "white-space: nowrap;'>Train</th>" 328 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
326 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
327 "white-space: nowrap;'>Validation</th>"
328 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
329 "white-space: nowrap;'>Test</th>"
330 "</tr></thead><tbody>" 329 "</tr></thead><tbody>"
331 ) 330 )
332 for row in rows: 331 for row in rows:
333 html += generate_table_row( 332 html += generate_table_row(
334 row, 333 row,
335 "padding: 10px; border: 1px solid #ccc; text-align: center; " 334 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
336 "white-space: nowrap;",
337 ) 335 )
338 html += "</tbody></table></div><br>" 336 html += "</tbody></table></div><br>"
339 return html 337 return html
338
339
340 # -------------------------------------------
341 # 3) TRAIN/VALIDATION PERFORMANCE SUMMARY TABLE
342 # -------------------------------------------
340 343
341 344
342 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str: 345 def format_train_val_stats_table_html(train_stats: dict, test_stats: dict) -> str:
343 """Formats an HTML table for training and validation metrics.""" 346 """Formats an HTML table for training and validation metrics."""
344 output_type = detect_output_type(test_stats) 347 output_type = detect_output_type(test_stats)
352 ) 355 )
353 t = all_metrics["training"].get(metric_key) 356 t = all_metrics["training"].get(metric_key)
354 v = all_metrics["validation"].get(metric_key) 357 v = all_metrics["validation"].get(metric_key)
355 if t is not None and v is not None: 358 if t is not None and v is not None:
356 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) 359 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
357
358 if not rows: 360 if not rows:
359 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" 361 return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
360
361 html = ( 362 html = (
362 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" 363 "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
363 "<div style='display: flex; justify-content: center;'>" 364 "<div style='display: flex; justify-content: center;'>"
364 "<table style='border-collapse: collapse; table-layout: auto;'>" 365 "<table class='performance-summary' style='border-collapse: collapse;'>"
365 "<thead><tr>" 366 "<thead><tr>"
366 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " 367 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
367 "white-space: nowrap;'>Metric</th>" 368 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Train</th>"
368 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; " 369 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Validation</th>"
369 "white-space: nowrap;'>Train</th>"
370 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
371 "white-space: nowrap;'>Validation</th>"
372 "</tr></thead><tbody>" 370 "</tr></thead><tbody>"
373 ) 371 )
374 for row in rows: 372 for row in rows:
375 html += generate_table_row( 373 html += generate_table_row(
376 row, 374 row,
377 "padding: 10px; border: 1px solid #ccc; text-align: center; " 375 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
378 "white-space: nowrap;",
379 ) 376 )
380 html += "</tbody></table></div><br>" 377 html += "</tbody></table></div><br>"
381 return html 378 return html
379
380
381 # -----------------------------------------
382 # 4) TEST‐ONLY PERFORMANCE SUMMARY TABLE
383 # -----------------------------------------
382 384
383 385
384 def format_test_merged_stats_table_html( 386 def format_test_merged_stats_table_html(
385 test_metrics: Dict[str, Optional[float]], 387 test_metrics: Dict[str, Optional[float]],
386 ) -> str: 388 ) -> str:
389 for key in sorted(test_metrics.keys()): 391 for key in sorted(test_metrics.keys()):
390 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title()) 392 display_name = METRIC_DISPLAY_NAMES.get(key, key.replace("_", " ").title())
391 value = test_metrics[key] 393 value = test_metrics[key]
392 if value is not None: 394 if value is not None:
393 rows.append([display_name, f"{value:.4f}"]) 395 rows.append([display_name, f"{value:.4f}"])
394
395 if not rows: 396 if not rows:
396 return "<table><tr><td>No test metric values found.</td></tr></table>" 397 return "<table><tr><td>No test metric values found.</td></tr></table>"
397
398 html = ( 398 html = (
399 "<h2 style='text-align: center;'>Test Performance Summary</h2>" 399 "<h2 style='text-align: center;'>Test Performance Summary</h2>"
400 "<div style='display: flex; justify-content: center;'>" 400 "<div style='display: flex; justify-content: center;'>"
401 "<table style='border-collapse: collapse; table-layout: auto;'>" 401 "<table class='performance-summary' style='border-collapse: collapse;'>"
402 "<thead><tr>" 402 "<thead><tr>"
403 "<th style='padding: 10px; border: 1px solid #ccc; text-align: left; " 403 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: left; white-space: nowrap;'>Metric</th>"
404 "white-space: nowrap;'>Metric</th>" 404 "<th class='sortable' style='padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;'>Test</th>"
405 "<th style='padding: 10px; border: 1px solid #ccc; text-align: center; "
406 "white-space: nowrap;'>Test</th>"
407 "</tr></thead><tbody>" 405 "</tr></thead><tbody>"
408 ) 406 )
409 for row in rows: 407 for row in rows:
410 html += generate_table_row( 408 html += generate_table_row(
411 row, 409 row,
412 "padding: 10px; border: 1px solid #ccc; text-align: center; " 410 "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
413 "white-space: nowrap;",
414 ) 411 )
415 html += "</tbody></table></div><br>" 412 html += "</tbody></table></div><br>"
416 return html 413 return html
417 414
418 415
424 label_column: Optional[str] = None, 421 label_column: Optional[str] = None,
425 ) -> pd.DataFrame: 422 ) -> pd.DataFrame:
426 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" 423 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation)."""
427 out = df.copy() 424 out = df.copy()
428 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int) 425 out[split_column] = pd.to_numeric(out[split_column], errors="coerce").astype(int)
429
430 idx_train = out.index[out[split_column] == 0].tolist() 426 idx_train = out.index[out[split_column] == 0].tolist()
431
432 if not idx_train: 427 if not idx_train:
433 logger.info("No rows with split=0; nothing to do.") 428 logger.info("No rows with split=0; nothing to do.")
434 return out 429 return out
435
436 # Always use stratify if possible 430 # Always use stratify if possible
437 stratify_arr = None 431 stratify_arr = None
438 if label_column and label_column in out.columns: 432 if label_column and label_column in out.columns:
439 label_counts = out.loc[idx_train, label_column].value_counts() 433 label_counts = out.loc[idx_train, label_column].value_counts()
440 if label_counts.size > 1: 434 if label_counts.size > 1:
448 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") 442 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
449 stratify_arr = out.loc[idx_train, label_column] 443 stratify_arr = out.loc[idx_train, label_column]
450 logger.info("Using stratified split for validation set") 444 logger.info("Using stratified split for validation set")
451 else: 445 else:
452 logger.warning("Only one label class found; cannot stratify") 446 logger.warning("Only one label class found; cannot stratify")
453
454 if validation_size <= 0: 447 if validation_size <= 0:
455 logger.info("validation_size <= 0; keeping all as train.") 448 logger.info("validation_size <= 0; keeping all as train.")
456 return out 449 return out
457 if validation_size >= 1: 450 if validation_size >= 1:
458 logger.info("validation_size >= 1; moving all train → validation.") 451 logger.info("validation_size >= 1; moving all train → validation.")
459 out.loc[idx_train, split_column] = 1 452 out.loc[idx_train, split_column] = 1
460 return out 453 return out
461
462 # Always try stratified split first 454 # Always try stratified split first
463 try: 455 try:
464 train_idx, val_idx = train_test_split( 456 train_idx, val_idx = train_test_split(
465 idx_train, 457 idx_train,
466 test_size=validation_size, 458 test_size=validation_size,
474 idx_train, 466 idx_train,
475 test_size=validation_size, 467 test_size=validation_size,
476 random_state=random_state, 468 random_state=random_state,
477 stratify=None, 469 stratify=None,
478 ) 470 )
479
480 out.loc[train_idx, split_column] = 0 471 out.loc[train_idx, split_column] = 0
481 out.loc[val_idx, split_column] = 1 472 out.loc[val_idx, split_column] = 1
482 out[split_column] = out[split_column].astype(int) 473 out[split_column] = out[split_column].astype(int)
483 return out 474 return out
484 475
490 random_state: int = 42, 481 random_state: int = 42,
491 label_column: Optional[str] = None, 482 label_column: Optional[str] = None,
492 ) -> pd.DataFrame: 483 ) -> pd.DataFrame:
493 """Create a stratified random split when no split column exists.""" 484 """Create a stratified random split when no split column exists."""
494 out = df.copy() 485 out = df.copy()
495
496 # initialize split column 486 # initialize split column
497 out[split_column] = 0 487 out[split_column] = 0
498
499 if not label_column or label_column not in out.columns: 488 if not label_column or label_column not in out.columns:
500 logger.warning("No label column found; using random split without stratification") 489 logger.warning("No label column found; using random split without stratification")
501 # fall back to simple random assignment 490 # fall back to simple random assignment
502 indices = out.index.tolist() 491 indices = out.index.tolist()
503 np.random.seed(random_state) 492 np.random.seed(random_state)
504 np.random.shuffle(indices) 493 np.random.shuffle(indices)
505
506 n_total = len(indices) 494 n_total = len(indices)
507 n_train = int(n_total * split_probabilities[0]) 495 n_train = int(n_total * split_probabilities[0])
508 n_val = int(n_total * split_probabilities[1]) 496 n_val = int(n_total * split_probabilities[1])
509
510 out.loc[indices[:n_train], split_column] = 0 497 out.loc[indices[:n_train], split_column] = 0
511 out.loc[indices[n_train:n_train + n_val], split_column] = 1 498 out.loc[indices[n_train:n_train + n_val], split_column] = 1
512 out.loc[indices[n_train + n_val:], split_column] = 2 499 out.loc[indices[n_train + n_val:], split_column] = 2
513
514 return out.astype({split_column: int}) 500 return out.astype({split_column: int})
515
516 # check if stratification is possible 501 # check if stratification is possible
517 label_counts = out[label_column].value_counts() 502 label_counts = out[label_column].value_counts()
518 min_samples_per_class = label_counts.min() 503 min_samples_per_class = label_counts.min()
519
520 # ensure we have enough samples for stratification: 504 # ensure we have enough samples for stratification:
521 # Each class must have at least as many samples as the number of splits, 505 # Each class must have at least as many samples as the number of splits,
522 # so that each split can receive at least one sample per class. 506 # so that each split can receive at least one sample per class.
523 min_samples_required = len(split_probabilities) 507 min_samples_required = len(split_probabilities)
524 if min_samples_per_class < min_samples_required: 508 if min_samples_per_class < min_samples_required:
527 ) 511 )
528 # fall back to simple random assignment 512 # fall back to simple random assignment
529 indices = out.index.tolist() 513 indices = out.index.tolist()
530 np.random.seed(random_state) 514 np.random.seed(random_state)
531 np.random.shuffle(indices) 515 np.random.shuffle(indices)
532
533 n_total = len(indices) 516 n_total = len(indices)
534 n_train = int(n_total * split_probabilities[0]) 517 n_train = int(n_total * split_probabilities[0])
535 n_val = int(n_total * split_probabilities[1]) 518 n_val = int(n_total * split_probabilities[1])
536
537 out.loc[indices[:n_train], split_column] = 0 519 out.loc[indices[:n_train], split_column] = 0
538 out.loc[indices[n_train:n_train + n_val], split_column] = 1 520 out.loc[indices[n_train:n_train + n_val], split_column] = 1
539 out.loc[indices[n_train + n_val:], split_column] = 2 521 out.loc[indices[n_train + n_val:], split_column] = 2
540
541 return out.astype({split_column: int}) 522 return out.astype({split_column: int})
542
543 logger.info("Using stratified random split for train/validation/test sets") 523 logger.info("Using stratified random split for train/validation/test sets")
544
545 # first split: separate test set 524 # first split: separate test set
546 train_val_idx, test_idx = train_test_split( 525 train_val_idx, test_idx = train_test_split(
547 out.index.tolist(), 526 out.index.tolist(),
548 test_size=split_probabilities[2], 527 test_size=split_probabilities[2],
549 random_state=random_state, 528 random_state=random_state,
550 stratify=out[label_column], 529 stratify=out[label_column],
551 ) 530 )
552
553 # second split: separate training and validation from remaining data 531 # second split: separate training and validation from remaining data
554 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) 532 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
555 train_idx, val_idx = train_test_split( 533 train_idx, val_idx = train_test_split(
556 train_val_idx, 534 train_val_idx,
557 test_size=val_size_adjusted, 535 test_size=val_size_adjusted,
558 random_state=random_state, 536 random_state=random_state,
559 stratify=out.loc[train_val_idx, label_column], 537 stratify=out.loc[train_val_idx, label_column],
560 ) 538 )
561
562 # assign split values 539 # assign split values
563 out.loc[train_idx, split_column] = 0 540 out.loc[train_idx, split_column] = 0
564 out.loc[val_idx, split_column] = 1 541 out.loc[val_idx, split_column] = 1
565 out.loc[test_idx, split_column] = 2 542 out.loc[test_idx, split_column] = 2
566
567 logger.info("Successfully applied stratified random split") 543 logger.info("Successfully applied stratified random split")
568 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") 544 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
569
570 return out.astype({split_column: int}) 545 return out.astype({split_column: int})
571 546
572 547
573 class Backend(Protocol): 548 class Backend(Protocol):
574 """Interface for a machine learning backend.""" 549 """Interface for a machine learning backend."""
575
576 def prepare_config( 550 def prepare_config(
577 self, 551 self,
578 config_params: Dict[str, Any], 552 config_params: Dict[str, Any],
579 split_config: Dict[str, Any], 553 split_config: Dict[str, Any],
580 ) -> str: 554 ) -> str:
602 ... 576 ...
603 577
604 578
605 class LudwigDirectBackend: 579 class LudwigDirectBackend:
606 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" 580 """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
607
608 def prepare_config( 581 def prepare_config(
609 self, 582 self,
610 config_params: Dict[str, Any], 583 config_params: Dict[str, Any],
611 split_config: Dict[str, Any], 584 split_config: Dict[str, Any],
612 ) -> str: 585 ) -> str:
613 logger.info("LudwigDirectBackend: Preparing YAML configuration.") 586 logger.info("LudwigDirectBackend: Preparing YAML configuration.")
614
615 model_name = config_params.get("model_name", "resnet18") 587 model_name = config_params.get("model_name", "resnet18")
616 use_pretrained = config_params.get("use_pretrained", False) 588 use_pretrained = config_params.get("use_pretrained", False)
617 fine_tune = config_params.get("fine_tune", False) 589 fine_tune = config_params.get("fine_tune", False)
618 if use_pretrained: 590 if use_pretrained:
619 trainable = bool(fine_tune) 591 trainable = bool(fine_tune)
632 "use_pretrained": use_pretrained, 604 "use_pretrained": use_pretrained,
633 "trainable": trainable, 605 "trainable": trainable,
634 } 606 }
635 else: 607 else:
636 encoder_config = {"type": raw_encoder} 608 encoder_config = {"type": raw_encoder}
637
638 batch_size_cfg = batch_size or "auto" 609 batch_size_cfg = batch_size or "auto"
639
640 label_column_path = config_params.get("label_column_data_path") 610 label_column_path = config_params.get("label_column_data_path")
641 label_series = None 611 label_series = None
642 if label_column_path is not None and Path(label_column_path).exists(): 612 if label_column_path is not None and Path(label_column_path).exists():
643 try: 613 try:
644 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] 614 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
645 except Exception as e: 615 except Exception as e:
646 logger.warning(f"Could not read label column for task detection: {e}") 616 logger.warning(f"Could not read label column for task detection: {e}")
647
648 if ( 617 if (
649 label_series is not None 618 label_series is not None
650 and ptypes.is_numeric_dtype(label_series.dtype) 619 and ptypes.is_numeric_dtype(label_series.dtype)
651 and label_series.nunique() > 10 620 and label_series.nunique() > 10
652 ): 621 ):
653 task_type = "regression" 622 task_type = "regression"
654 else: 623 else:
655 task_type = "classification" 624 task_type = "classification"
656
657 config_params["task_type"] = task_type 625 config_params["task_type"] = task_type
658
659 image_feat: Dict[str, Any] = { 626 image_feat: Dict[str, Any] = {
660 "name": IMAGE_PATH_COLUMN_NAME, 627 "name": IMAGE_PATH_COLUMN_NAME,
661 "type": "image", 628 "type": "image",
662 "encoder": encoder_config, 629 "encoder": encoder_config,
663 } 630 }
664 if config_params.get("augmentation") is not None: 631 if config_params.get("augmentation") is not None:
665 image_feat["augmentation"] = config_params["augmentation"] 632 image_feat["augmentation"] = config_params["augmentation"]
666
667 if task_type == "regression": 633 if task_type == "regression":
668 output_feat = { 634 output_feat = {
669 "name": LABEL_COLUMN_NAME, 635 "name": LABEL_COLUMN_NAME,
670 "type": "number", 636 "type": "number",
671 "decoder": {"type": "regressor"}, 637 "decoder": {"type": "regressor"},
677 "r2", 643 "r2",
678 ] 644 ]
679 }, 645 },
680 } 646 }
681 val_metric = config_params.get("validation_metric", "mean_squared_error") 647 val_metric = config_params.get("validation_metric", "mean_squared_error")
682
683 else: 648 else:
684 num_unique_labels = ( 649 num_unique_labels = (
685 label_series.nunique() if label_series is not None else 2 650 label_series.nunique() if label_series is not None else 2
686 ) 651 )
687 output_type = "binary" if num_unique_labels == 2 else "category" 652 output_type = "binary" if num_unique_labels == 2 else "category"
688 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type} 653 output_feat = {"name": LABEL_COLUMN_NAME, "type": output_type}
654 if output_type == "binary" and config_params.get("threshold") is not None:
655 output_feat["threshold"] = float(config_params["threshold"])
689 val_metric = None 656 val_metric = None
690
691 conf: Dict[str, Any] = { 657 conf: Dict[str, Any] = {
692 "model_type": "ecd", 658 "model_type": "ecd",
693 "input_features": [image_feat], 659 "input_features": [image_feat],
694 "output_features": [output_feat], 660 "output_features": [output_feat],
695 "combiner": {"type": "concat"}, 661 "combiner": {"type": "concat"},
705 "split": split_config, 671 "split": split_config,
706 "num_processes": num_processes, 672 "num_processes": num_processes,
707 "in_memory": False, 673 "in_memory": False,
708 }, 674 },
709 } 675 }
710
711 logger.debug("LudwigDirectBackend: Config dict built.") 676 logger.debug("LudwigDirectBackend: Config dict built.")
712 try: 677 try:
713 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) 678 yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
714 logger.info("LudwigDirectBackend: YAML config generated.") 679 logger.info("LudwigDirectBackend: YAML config generated.")
715 return yaml_str 680 return yaml_str
727 output_dir: Path, 692 output_dir: Path,
728 random_seed: int = 42, 693 random_seed: int = 42,
729 ) -> None: 694 ) -> None:
730 """Invoke Ludwig's internal experiment_cli function to run the experiment.""" 695 """Invoke Ludwig's internal experiment_cli function to run the experiment."""
731 logger.info("LudwigDirectBackend: Starting experiment execution.") 696 logger.info("LudwigDirectBackend: Starting experiment execution.")
732
733 try: 697 try:
734 from ludwig.experiment import experiment_cli 698 from ludwig.experiment import experiment_cli
735 except ImportError as e: 699 except ImportError as e:
736 logger.error( 700 logger.error(
737 "LudwigDirectBackend: Could not import experiment_cli.", 701 "LudwigDirectBackend: Could not import experiment_cli.",
738 exc_info=True, 702 exc_info=True,
739 ) 703 )
740 raise RuntimeError("Ludwig import failed.") from e 704 raise RuntimeError("Ludwig import failed.") from e
741
742 output_dir.mkdir(parents=True, exist_ok=True) 705 output_dir.mkdir(parents=True, exist_ok=True)
743
744 try: 706 try:
745 experiment_cli( 707 experiment_cli(
746 dataset=str(dataset_path), 708 dataset=str(dataset_path),
747 config=str(config_path), 709 config=str(config_path),
748 output_directory=str(output_dir), 710 output_directory=str(output_dir),
769 output_dir = Path(output_dir) 731 output_dir = Path(output_dir)
770 exp_dirs = sorted( 732 exp_dirs = sorted(
771 output_dir.glob("experiment_run*"), 733 output_dir.glob("experiment_run*"),
772 key=lambda p: p.stat().st_mtime, 734 key=lambda p: p.stat().st_mtime,
773 ) 735 )
774
775 if not exp_dirs: 736 if not exp_dirs:
776 logger.warning(f"No experiment run directories found in {output_dir}") 737 logger.warning(f"No experiment run directories found in {output_dir}")
777 return None 738 return None
778
779 progress_file = exp_dirs[-1] / "model" / "training_progress.json" 739 progress_file = exp_dirs[-1] / "model" / "training_progress.json"
780 if not progress_file.exists(): 740 if not progress_file.exists():
781 logger.warning(f"No training_progress.json found in {progress_file}") 741 logger.warning(f"No training_progress.json found in {progress_file}")
782 return None 742 return None
783
784 try: 743 try:
785 with progress_file.open("r", encoding="utf-8") as f: 744 with progress_file.open("r", encoding="utf-8") as f:
786 data = json.load(f) 745 data = json.load(f)
787 return { 746 return {
788 "learning_rate": data.get("learning_rate"), 747 "learning_rate": data.get("learning_rate"),
814 logger.error(f"Error converting Parquet to CSV: {e}") 773 logger.error(f"Error converting Parquet to CSV: {e}")
815 774
816 def generate_plots(self, output_dir: Path) -> None: 775 def generate_plots(self, output_dir: Path) -> None:
817 """Generate all registered Ludwig visualizations for the latest experiment run.""" 776 """Generate all registered Ludwig visualizations for the latest experiment run."""
818 logger.info("Generating all Ludwig visualizations…") 777 logger.info("Generating all Ludwig visualizations…")
819
820 test_plots = { 778 test_plots = {
821 "compare_performance", 779 "compare_performance",
822 "compare_classifiers_performance_from_prob", 780 "compare_classifiers_performance_from_prob",
823 "compare_classifiers_performance_from_pred", 781 "compare_classifiers_performance_from_pred",
824 "compare_classifiers_performance_changing_k", 782 "compare_classifiers_performance_changing_k",
838 } 796 }
839 train_plots = { 797 train_plots = {
840 "learning_curves", 798 "learning_curves",
841 "compare_classifiers_performance_subset", 799 "compare_classifiers_performance_subset",
842 } 800 }
843
844 output_dir = Path(output_dir) 801 output_dir = Path(output_dir)
845 exp_dirs = sorted( 802 exp_dirs = sorted(
846 output_dir.glob("experiment_run*"), 803 output_dir.glob("experiment_run*"),
847 key=lambda p: p.stat().st_mtime, 804 key=lambda p: p.stat().st_mtime,
848 ) 805 )
849 if not exp_dirs: 806 if not exp_dirs:
850 logger.warning(f"No experiment run dirs found in {output_dir}") 807 logger.warning(f"No experiment run dirs found in {output_dir}")
851 return 808 return
852 exp_dir = exp_dirs[-1] 809 exp_dir = exp_dirs[-1]
853
854 viz_dir = exp_dir / "visualizations" 810 viz_dir = exp_dir / "visualizations"
855 viz_dir.mkdir(exist_ok=True) 811 viz_dir.mkdir(exist_ok=True)
856 train_viz = viz_dir / "train" 812 train_viz = viz_dir / "train"
857 test_viz = viz_dir / "test" 813 test_viz = viz_dir / "test"
858 train_viz.mkdir(parents=True, exist_ok=True) 814 train_viz.mkdir(parents=True, exist_ok=True)
863 819
864 training_stats = _check(exp_dir / "training_statistics.json") 820 training_stats = _check(exp_dir / "training_statistics.json")
865 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) 821 test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
866 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) 822 probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
867 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) 823 gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
868
869 dataset_path = None 824 dataset_path = None
870 split_file = None 825 split_file = None
871 desc = exp_dir / DESCRIPTION_FILE_NAME 826 desc = exp_dir / DESCRIPTION_FILE_NAME
872 if desc.exists(): 827 if desc.exists():
873 with open(desc, "r") as f: 828 with open(desc, "r") as f:
874 cfg = json.load(f) 829 cfg = json.load(f)
875 dataset_path = _check(Path(cfg.get("dataset", ""))) 830 dataset_path = _check(Path(cfg.get("dataset", "")))
876 split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) 831 split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
877
878 output_feature = "" 832 output_feature = ""
879 if desc.exists(): 833 if desc.exists():
880 try: 834 try:
881 output_feature = cfg["config"]["output_features"][0]["name"] 835 output_feature = cfg["config"]["output_features"][0]["name"]
882 except Exception: 836 except Exception:
883 pass 837 pass
884 if not output_feature and test_stats: 838 if not output_feature and test_stats:
885 with open(test_stats, "r") as f: 839 with open(test_stats, "r") as f:
886 stats = json.load(f) 840 stats = json.load(f)
887 output_feature = next(iter(stats.keys()), "") 841 output_feature = next(iter(stats.keys()), "")
888
889 viz_registry = get_visualizations_registry() 842 viz_registry = get_visualizations_registry()
890 for viz_name, viz_func in viz_registry.items(): 843 for viz_name, viz_func in viz_registry.items():
891 if viz_name in train_plots: 844 if viz_name in train_plots:
892 viz_dir_plot = train_viz 845 viz_dir_plot = train_viz
893 elif viz_name in test_plots: 846 elif viz_name in test_plots:
894 viz_dir_plot = test_viz 847 viz_dir_plot = test_viz
895 else: 848 else:
896 continue 849 continue
897
898 try: 850 try:
899 viz_func( 851 viz_func(
900 training_statistics=[training_stats] if training_stats else [], 852 training_statistics=[training_stats] if training_stats else [],
901 test_statistics=[test_stats] if test_stats else [], 853 test_statistics=[test_stats] if test_stats else [],
902 probabilities=[probs_path] if probs_path else [], 854 probabilities=[probs_path] if probs_path else [],
912 file_format="png", 864 file_format="png",
913 ) 865 )
914 logger.info(f"✔ Generated {viz_name}") 866 logger.info(f"✔ Generated {viz_name}")
915 except Exception as e: 867 except Exception as e:
916 logger.warning(f"✘ Skipped {viz_name}: {e}") 868 logger.warning(f"✘ Skipped {viz_name}: {e}")
917
918 logger.info(f"All visualizations written to {viz_dir}") 869 logger.info(f"All visualizations written to {viz_dir}")
919 870
920 def generate_html_report( 871 def generate_html_report(
921 self, 872 self,
922 title: str, 873 title: str,
928 cwd = Path.cwd() 879 cwd = Path.cwd()
929 report_name = title.lower().replace(" ", "_") + "_report.html" 880 report_name = title.lower().replace(" ", "_") + "_report.html"
930 report_path = cwd / report_name 881 report_path = cwd / report_name
931 output_dir = Path(output_dir) 882 output_dir = Path(output_dir)
932 output_type = None 883 output_type = None
933
934 exp_dirs = sorted( 884 exp_dirs = sorted(
935 output_dir.glob("experiment_run*"), 885 output_dir.glob("experiment_run*"),
936 key=lambda p: p.stat().st_mtime, 886 key=lambda p: p.stat().st_mtime,
937 ) 887 )
938 if not exp_dirs: 888 if not exp_dirs:
939 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") 889 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
940 exp_dir = exp_dirs[-1] 890 exp_dir = exp_dirs[-1]
941
942 base_viz_dir = exp_dir / "visualizations" 891 base_viz_dir = exp_dir / "visualizations"
943 train_viz_dir = base_viz_dir / "train" 892 train_viz_dir = base_viz_dir / "train"
944 test_viz_dir = base_viz_dir / "test" 893 test_viz_dir = base_viz_dir / "test"
945
946 html = get_html_template() 894 html = get_html_template()
947 html += f"<h1>{title}</h1>" 895 html += f"<h1>{title}</h1>"
948
949 metrics_html = "" 896 metrics_html = ""
950 train_val_metrics_html = "" 897 train_val_metrics_html = ""
951 test_metrics_html = "" 898 test_metrics_html = ""
952 try: 899 try:
953 train_stats_path = exp_dir / "training_statistics.json" 900 train_stats_path = exp_dir / "training_statistics.json"
969 ) 916 )
970 except Exception as e: 917 except Exception as e:
971 logger.warning( 918 logger.warning(
972 f"Could not load stats for HTML report: {type(e).__name__}: {e}" 919 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
973 ) 920 )
974
975 config_html = "" 921 config_html = ""
976 training_progress = self.get_training_process(output_dir) 922 training_progress = self.get_training_process(output_dir)
977 try: 923 try:
978 config_html = format_config_table_html( 924 config_html = format_config_table_html(
979 config, split_info, training_progress 925 config, split_info, training_progress
984 def render_img_section( 930 def render_img_section(
985 title: str, dir_path: Path, output_type: str = None 931 title: str, dir_path: Path, output_type: str = None
986 ) -> str: 932 ) -> str:
987 if not dir_path.exists(): 933 if not dir_path.exists():
988 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>" 934 return f"<h2>{title}</h2><p><em>Directory not found.</em></p>"
989 935 # collect every PNG
990 imgs = list(dir_path.glob("*.png")) 936 imgs = list(dir_path.glob("*.png"))
937 # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files ---
938 imgs = [
939 img for img in imgs
940 if not (
941 img.name == "confusion_matrix.png"
942 or img.name.startswith("confusion_matrix__label_top")
943 or img.name == "roc_curves.png"
944 )
945 ]
991 if not imgs: 946 if not imgs:
992 return f"<h2>{title}</h2><p><em>No plots found.</em></p>" 947 return f"<h2>{title}</h2><p><em>No plots found.</em></p>"
993 948 if output_type == "binary":
994 if title == "Test Visualizations" and output_type == "binary":
995 order = [ 949 order = [
996 "confusion_matrix__label_top2.png",
997 "roc_curves_from_prediction_statistics.png", 950 "roc_curves_from_prediction_statistics.png",
998 "compare_performance_label.png", 951 "compare_performance_label.png",
999 "confusion_matrix_entropy__label_top2.png", 952 "confusion_matrix_entropy__label_top2.png",
953 # ...you can tweak ordering as needed
1000 ] 954 ]
1001 img_names = {img.name: img for img in imgs} 955 img_names = {img.name: img for img in imgs}
1002 ordered_imgs = [ 956 ordered = [img_names[n] for n in order if n in img_names]
1003 img_names[fname] for fname in order if fname in img_names 957 others = sorted(img for img in imgs if img.name not in order)
1004 ] 958 imgs = ordered + others
1005 remaining = sorted( 959 elif output_type == "category":
1006 [
1007 img
1008 for img in imgs
1009 if img.name not in order and img.name != "roc_curves.png"
1010 ]
1011 )
1012 imgs = ordered_imgs + remaining
1013
1014 elif title == "Test Visualizations" and output_type == "category":
1015 unwanted = { 960 unwanted = {
1016 "compare_classifiers_multiclass_multimetric__label_best10.png", 961 "compare_classifiers_multiclass_multimetric__label_best10.png",
1017 "compare_classifiers_multiclass_multimetric__label_top10.png", 962 "compare_classifiers_multiclass_multimetric__label_top10.png",
1018 "compare_classifiers_multiclass_multimetric__label_worst10.png", 963 "compare_classifiers_multiclass_multimetric__label_worst10.png",
1019 } 964 }
1020 display_order = [ 965 display_order = [
1021 "confusion_matrix__label_top10.png",
1022 "roc_curves.png", 966 "roc_curves.png",
1023 "compare_performance_label.png", 967 "compare_performance_label.png",
1024 "compare_classifiers_performance_from_prob.png", 968 "compare_classifiers_performance_from_prob.png",
1025 "compare_classifiers_multiclass_multimetric__label_sorted.png",
1026 "confusion_matrix_entropy__label_top10.png", 969 "confusion_matrix_entropy__label_top10.png",
1027 ] 970 ]
1028 img_names = {img.name: img for img in imgs if img.name not in unwanted} 971 # filter and order
1029 ordered_imgs = [ 972 valid_imgs = [img for img in imgs if img.name not in unwanted]
1030 img_names[fname] for fname in display_order if fname in img_names 973 img_map = {img.name: img for img in valid_imgs}
1031 ] 974 ordered = [img_map[n] for n in display_order if n in img_map]
1032 remaining = sorted( 975 others = sorted(img for img in valid_imgs if img.name not in display_order)
1033 [img for img in img_names.values() if img.name not in display_order] 976 imgs = ordered + others
1034 )
1035 imgs = ordered_imgs + remaining
1036
1037 else: 977 else:
1038 if output_type == "category": 978 # regression: just sort whatever's left
1039 unwanted = { 979 imgs = sorted(imgs)
1040 "compare_classifiers_multiclass_multimetric__label_best10.png", 980 # render each remaining PNG
1041 "compare_classifiers_multiclass_multimetric__label_top10.png", 981 html = ""
1042 "compare_classifiers_multiclass_multimetric__label_worst10.png",
1043 }
1044 imgs = sorted([img for img in imgs if img.name not in unwanted])
1045 else:
1046 imgs = sorted(imgs)
1047
1048 section_html = f"<h2 style='text-align: center;'>{title}</h2><div>"
1049 for img in imgs: 982 for img in imgs:
1050 b64 = encode_image_to_base64(str(img)) 983 b64 = encode_image_to_base64(str(img))
1051 section_html += ( 984 img_title = img.stem.replace("_", " ").title()
985 html += (
986 f"<h2 style='text-align: center;'>{img_title}</h2>"
1052 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' 987 f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
1053 f"<h3>{img.stem.replace('_', ' ').title()}</h3>"
1054 f'<img src="data:image/png;base64,{b64}" ' 988 f'<img src="data:image/png;base64,{b64}" '
1055 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' 989 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
1056 f"</div>" 990 f"</div>"
1057 ) 991 )
1058 section_html += "</div>" 992 return html
1059 return section_html
1060 993
1061 tab1_content = config_html + metrics_html 994 tab1_content = config_html + metrics_html
1062
1063 tab2_content = train_val_metrics_html + render_img_section( 995 tab2_content = train_val_metrics_html + render_img_section(
1064 "Training & Validation Visualizations", train_viz_dir 996 "Training and Validation Visualizations", train_viz_dir
1065 ) 997 )
1066
1067 # --- Predictions vs Ground Truth table --- 998 # --- Predictions vs Ground Truth table ---
1068 preds_section = "" 999 preds_section = ""
1069 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME 1000 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME
1070 if parquet_path.exists(): 1001 if output_type == "regression" and parquet_path.exists():
1071 try: 1002 try:
1072 # 1) load predictions from Parquet 1003 # 1) load predictions from Parquet
1073 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) 1004 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True)
1074 # assume the column containing your model's prediction is named "prediction" 1005 # assume the column containing your model's prediction is named "prediction"
1075 # or contains that substring:
1076 pred_col = next( 1006 pred_col = next(
1077 (c for c in df_preds.columns if "prediction" in c.lower()), 1007 (c for c in df_preds.columns if "prediction" in c.lower()),
1078 None, 1008 None,
1079 ) 1009 )
1080 if pred_col is None: 1010 if pred_col is None:
1081 raise ValueError("No prediction column found in Parquet output") 1011 raise ValueError("No prediction column found in Parquet output")
1082 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) 1012 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
1083
1084 # 2) load ground truth for the test split from prepared CSV 1013 # 2) load ground truth for the test split from prepared CSV
1085 df_all = pd.read_csv(config["label_column_data_path"]) 1014 df_all = pd.read_csv(config["label_column_data_path"])
1086 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ 1015 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True)
1087 LABEL_COLUMN_NAME 1016 # 3) concatenate side-by-side
1088 ].reset_index(drop=True)
1089
1090 # 3) concatenate side‐by‐side
1091 df_table = pd.concat([df_gt, df_pred], axis=1) 1017 df_table = pd.concat([df_gt, df_pred], axis=1)
1092 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] 1018 df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
1093
1094 # 4) render as HTML 1019 # 4) render as HTML
1095 preds_html = df_table.to_html(index=False, classes="predictions-table") 1020 preds_html = df_table.to_html(index=False, classes="predictions-table")
1096 preds_section = ( 1021 preds_section = (
1097 "<h2 style='text-align: center;'>Predictions vs. Ground Truth</h2>" 1022 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>"
1098 "<div style='overflow-x:auto; margin-bottom:20px;'>" 1023 "<div style='overflow-y:auto; max-height:400px; overflow-x:auto; margin-bottom:20px;'>"
1099 + preds_html 1024 + preds_html
1100 + "</div>" 1025 + "</div>"
1101 ) 1026 )
1102 except Exception as e: 1027 except Exception as e:
1103 logger.warning(f"Could not build Predictions vs GT table: {e}") 1028 logger.warning(f"Could not build Predictions vs GT table: {e}")
1104 # Test tab = Metrics + Preds table + Visualizations 1029 tab3_content = test_metrics_html + preds_section
1105 1030 if output_type in ("binary", "category"):
1106 tab3_content = ( 1031 training_stats_path = exp_dir / "training_statistics.json"
1107 test_metrics_html 1032 interactive_plots = build_classification_plots(
1108 + preds_section 1033 str(test_stats_path),
1109 + render_img_section("Test Visualizations", test_viz_dir, output_type) 1034 str(training_stats_path),
1110 ) 1035 )
1111 1036 for plot in interactive_plots:
1037 # 2) inject the static "roc_curves_from_prediction_statistics.png"
1038 if plot["title"] == "ROC-AUC":
1039 static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png"
1040 if static_img.exists():
1041 b64 = encode_image_to_base64(str(static_img))
1042 tab3_content += (
1043 "<h2 style='text-align: center;'>"
1044 "Roc Curves From Prediction Statistics"
1045 "</h2>"
1046 f'<div class="plot" style="margin-bottom:20px;text-align:center;">'
1047 f'<img src="data:image/png;base64,{b64}" '
1048 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />'
1049 "</div>"
1050 )
1051 # always render the plotly panels exactly as before
1052 tab3_content += (
1053 f"<h2 style='text-align: center;'>{plot['title']}</h2>"
1054 + plot["html"]
1055 )
1056 tab3_content += render_img_section(
1057 "Test Visualizations",
1058 test_viz_dir,
1059 output_type
1060 )
1112 # assemble the tabs and help modal 1061 # assemble the tabs and help modal
1113 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) 1062 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
1114 modal_html = get_metrics_help_modal() 1063 modal_html = get_metrics_help_modal()
1115 html += tabbed_html + modal_html + get_html_closing() 1064 html += tabbed_html + modal_html + get_html_closing()
1116
1117 try: 1065 try:
1118 with open(report_path, "w") as f: 1066 with open(report_path, "w") as f:
1119 f.write(html) 1067 f.write(html)
1120 logger.info(f"HTML report generated at: {report_path}") 1068 logger.info(f"HTML report generated at: {report_path}")
1121 except Exception as e: 1069 except Exception as e:
1122 logger.error(f"Failed to write HTML report: {e}") 1070 logger.error(f"Failed to write HTML report: {e}")
1123 raise 1071 raise
1124
1125 return report_path 1072 return report_path
1126 1073
1127 1074
1128 class WorkflowOrchestrator: 1075 class WorkflowOrchestrator:
1129 """Manages the image-classification workflow.""" 1076 """Manages the image-classification workflow."""
1130
1131 def __init__(self, args: argparse.Namespace, backend: Backend): 1077 def __init__(self, args: argparse.Namespace, backend: Backend):
1132 self.args = args 1078 self.args = args
1133 self.backend = backend 1079 self.backend = backend
1134 self.temp_dir: Optional[Path] = None 1080 self.temp_dir: Optional[Path] = None
1135 self.image_extract_dir: Optional[Path] = None 1081 self.image_extract_dir: Optional[Path] = None
1165 1111
1166 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]: 1112 def _prepare_data(self) -> Tuple[Path, Dict[str, Any], str]:
1167 """Load CSV, update image paths, handle splits, and write prepared CSV.""" 1113 """Load CSV, update image paths, handle splits, and write prepared CSV."""
1168 if not self.temp_dir or not self.image_extract_dir: 1114 if not self.temp_dir or not self.image_extract_dir:
1169 raise RuntimeError("Temp dirs not initialized before data prep.") 1115 raise RuntimeError("Temp dirs not initialized before data prep.")
1170
1171 try: 1116 try:
1172 df = pd.read_csv(self.args.csv_file) 1117 df = pd.read_csv(self.args.csv_file)
1173 logger.info(f"Loaded CSV: {self.args.csv_file}") 1118 logger.info(f"Loaded CSV: {self.args.csv_file}")
1174 except Exception: 1119 except Exception:
1175 logger.error("Error loading CSV file", exc_info=True) 1120 logger.error("Error loading CSV file", exc_info=True)
1176 raise 1121 raise
1177
1178 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} 1122 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
1179 missing = required - set(df.columns) 1123 missing = required - set(df.columns)
1180 if missing: 1124 if missing:
1181 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") 1125 raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
1182
1183 try: 1126 try:
1184 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( 1127 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
1185 lambda p: str((self.image_extract_dir / p).resolve()) 1128 lambda p: str((self.image_extract_dir / p).resolve())
1186 ) 1129 )
1187 except Exception: 1130 except Exception:
1188 logger.error("Error updating image paths", exc_info=True) 1131 logger.error("Error updating image paths", exc_info=True)
1189 raise 1132 raise
1190
1191 if SPLIT_COLUMN_NAME in df.columns: 1133 if SPLIT_COLUMN_NAME in df.columns:
1192 df, split_config, split_info = self._process_fixed_split(df) 1134 df, split_config, split_info = self._process_fixed_split(df)
1193 else: 1135 else:
1194 logger.info("No split column; creating stratified random split") 1136 logger.info("No split column; creating stratified random split")
1195 df = create_stratified_random_split( 1137 df = create_stratified_random_split(
1206 split_info = ( 1148 split_info = (
1207 f"No split column in CSV. Created stratified random split: " 1149 f"No split column in CSV. Created stratified random split: "
1208 f"{[int(p * 100) for p in self.args.split_probabilities]}% " 1150 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
1209 f"for train/val/test with balanced label distribution." 1151 f"for train/val/test with balanced label distribution."
1210 ) 1152 )
1211
1212 final_csv = self.temp_dir / TEMP_CSV_FILENAME 1153 final_csv = self.temp_dir / TEMP_CSV_FILENAME
1213 try: 1154 try:
1214
1215 df.to_csv(final_csv, index=False) 1155 df.to_csv(final_csv, index=False)
1216 logger.info(f"Saved prepared data to {final_csv}") 1156 logger.info(f"Saved prepared data to {final_csv}")
1217 except Exception: 1157 except Exception:
1218 logger.error("Error saving prepared CSV", exc_info=True) 1158 logger.error("Error saving prepared CSV", exc_info=True)
1219 raise 1159 raise
1220
1221 return final_csv, split_config, split_info 1160 return final_csv, split_config, split_info
1222 1161
1223 def _process_fixed_split( 1162 def _process_fixed_split(
1224 self, df: pd.DataFrame 1163 self, df: pd.DataFrame
1225 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]: 1164 ) -> Tuple[pd.DataFrame, Dict[str, Any], str]:
1230 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype( 1169 df[SPLIT_COLUMN_NAME] = pd.to_numeric(col, errors="coerce").astype(
1231 pd.Int64Dtype() 1170 pd.Int64Dtype()
1232 ) 1171 )
1233 if df[SPLIT_COLUMN_NAME].isna().any(): 1172 if df[SPLIT_COLUMN_NAME].isna().any():
1234 logger.warning("Split column contains non-numeric/missing values.") 1173 logger.warning("Split column contains non-numeric/missing values.")
1235
1236 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) 1174 unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
1237 logger.info(f"Unique split values: {unique}") 1175 logger.info(f"Unique split values: {unique}")
1238
1239 if unique == {0, 2}: 1176 if unique == {0, 2}:
1240 df = split_data_0_2( 1177 df = split_data_0_2(
1241 df, 1178 df,
1242 SPLIT_COLUMN_NAME, 1179 SPLIT_COLUMN_NAME,
1243 validation_size=self.args.validation_size, 1180 validation_size=self.args.validation_size,
1254 elif unique.issubset({0, 1, 2}): 1191 elif unique.issubset({0, 1, 2}):
1255 split_info = "Used user-defined split column from CSV." 1192 split_info = "Used user-defined split column from CSV."
1256 logger.info("Using fixed split as-is.") 1193 logger.info("Using fixed split as-is.")
1257 else: 1194 else:
1258 raise ValueError(f"Unexpected split values: {unique}") 1195 raise ValueError(f"Unexpected split values: {unique}")
1259
1260 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info 1196 return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
1261
1262 except Exception: 1197 except Exception:
1263 logger.error("Error processing fixed split", exc_info=True) 1198 logger.error("Error processing fixed split", exc_info=True)
1264 raise 1199 raise
1265 1200
1266 def _cleanup_temp_dirs(self) -> None: 1201 def _cleanup_temp_dirs(self) -> None:
1272 1207
1273 def run(self) -> None: 1208 def run(self) -> None:
1274 """Execute the full workflow end-to-end.""" 1209 """Execute the full workflow end-to-end."""
1275 logger.info("Starting workflow...") 1210 logger.info("Starting workflow...")
1276 self.args.output_dir.mkdir(parents=True, exist_ok=True) 1211 self.args.output_dir.mkdir(parents=True, exist_ok=True)
1277
1278 try: 1212 try:
1279 self._create_temp_dirs() 1213 self._create_temp_dirs()
1280 self._extract_images() 1214 self._extract_images()
1281 csv_path, split_cfg, split_info = self._prepare_data() 1215 csv_path, split_cfg, split_info = self._prepare_data()
1282
1283 use_pretrained = self.args.use_pretrained or self.args.fine_tune 1216 use_pretrained = self.args.use_pretrained or self.args.fine_tune
1284
1285 backend_args = { 1217 backend_args = {
1286 "model_name": self.args.model_name, 1218 "model_name": self.args.model_name,
1287 "fine_tune": self.args.fine_tune, 1219 "fine_tune": self.args.fine_tune,
1288 "use_pretrained": use_pretrained, 1220 "use_pretrained": use_pretrained,
1289 "epochs": self.args.epochs, 1221 "epochs": self.args.epochs,
1293 "learning_rate": self.args.learning_rate, 1225 "learning_rate": self.args.learning_rate,
1294 "random_seed": self.args.random_seed, 1226 "random_seed": self.args.random_seed,
1295 "early_stop": self.args.early_stop, 1227 "early_stop": self.args.early_stop,
1296 "label_column_data_path": csv_path, 1228 "label_column_data_path": csv_path,
1297 "augmentation": self.args.augmentation, 1229 "augmentation": self.args.augmentation,
1230 "threshold": self.args.threshold,
1298 } 1231 }
1299 yaml_str = self.backend.prepare_config(backend_args, split_cfg) 1232 yaml_str = self.backend.prepare_config(backend_args, split_cfg)
1300
1301 config_file = self.temp_dir / TEMP_CONFIG_FILENAME 1233 config_file = self.temp_dir / TEMP_CONFIG_FILENAME
1302 config_file.write_text(yaml_str) 1234 config_file.write_text(yaml_str)
1303 logger.info(f"Wrote backend config: {config_file}") 1235 logger.info(f"Wrote backend config: {config_file}")
1304
1305 self.backend.run_experiment( 1236 self.backend.run_experiment(
1306 csv_path, 1237 csv_path,
1307 config_file, 1238 config_file,
1308 self.args.output_dir, 1239 self.args.output_dir,
1309 self.args.random_seed, 1240 self.args.random_seed,
1347 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0}, 1278 "random_contrast": {"type": "random_contrast", "min": 0.5, "max": 2.0},
1348 } 1279 }
1349 aug_list = [] 1280 aug_list = []
1350 for tok in aug_string.split(","): 1281 for tok in aug_string.split(","):
1351 key = tok.strip() 1282 key = tok.strip()
1352 if not key:
1353 continue
1354 if key not in mapping: 1283 if key not in mapping:
1355 valid = ", ".join(mapping.keys()) 1284 valid = ", ".join(mapping.keys())
1356 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}") 1285 raise ValueError(f"Unknown augmentation '{key}'. Valid choices: {valid}")
1357 aug_list.append(mapping[key]) 1286 aug_list.append(mapping[key])
1358 return aug_list 1287 return aug_list
1426 help="Where to write outputs", 1355 help="Where to write outputs",
1427 ) 1356 )
1428 parser.add_argument( 1357 parser.add_argument(
1429 "--validation-size", 1358 "--validation-size",
1430 type=float, 1359 type=float,
1431 default=0.1, 1360 default=0.15,
1432 help="Fraction for validation (0.0–1.0)", 1361 help="Fraction for validation (0.0–1.0)",
1433 ) 1362 )
1434 parser.add_argument( 1363 parser.add_argument(
1435 "--preprocessing-num-processes", 1364 "--preprocessing-num-processes",
1436 type=int, 1365 type=int,
1470 "random_horizontal_flip, random_vertical_flip, random_rotate, " 1399 "random_horizontal_flip, random_vertical_flip, random_rotate, "
1471 "random_blur, random_brightness, random_contrast. " 1400 "random_blur, random_brightness, random_contrast. "
1472 "E.g. --augmentation random_horizontal_flip,random_rotate" 1401 "E.g. --augmentation random_horizontal_flip,random_rotate"
1473 ), 1402 ),
1474 ) 1403 )
parser.add_argument(
    "--threshold",
    type=float,
    default=None,
    help=(
        # Space added between the two literals: without it the rendered
        # help text ran the sentences together ("...(0.0–1.0).Overrides...").
        "Decision threshold for binary classification (0.0–1.0). "
        "Overrides default 0.5."
        # NOTE(review): unlike --validation-size, this value is not
        # range-checked after parse_args() — confirm 0.0–1.0 is enforced
        # downstream before relying on it.
    ),
)
1476 args = parser.parse_args() 1413 args = parser.parse_args()
1477
1478 if not 0.0 <= args.validation_size <= 1.0: 1414 if not 0.0 <= args.validation_size <= 1.0:
1479 parser.error("validation-size must be between 0.0 and 1.0") 1415 parser.error("validation-size must be between 0.0 and 1.0")
1480 if not args.csv_file.is_file(): 1416 if not args.csv_file.is_file():
1481 parser.error(f"CSV not found: {args.csv_file}") 1417 parser.error(f"CSV not found: {args.csv_file}")
1482 if not args.image_zip.is_file(): 1418 if not args.image_zip.is_file():
1485 try: 1421 try:
1486 augmentation_setup = aug_parse(args.augmentation) 1422 augmentation_setup = aug_parse(args.augmentation)
1487 setattr(args, "augmentation", augmentation_setup) 1423 setattr(args, "augmentation", augmentation_setup)
1488 except ValueError as e: 1424 except ValueError as e:
1489 parser.error(str(e)) 1425 parser.error(str(e))
1490
1491 backend_instance = LudwigDirectBackend() 1426 backend_instance = LudwigDirectBackend()
1492 orchestrator = WorkflowOrchestrator(args, backend_instance) 1427 orchestrator = WorkflowOrchestrator(args, backend_instance)
1493
1494 exit_code = 0 1428 exit_code = 0
1495 try: 1429 try:
1496 orchestrator.run() 1430 orchestrator.run()
1497 logger.info("Main script finished successfully.") 1431 logger.info("Main script finished successfully.")
1498 except Exception as e: 1432 except Exception as e:
1503 1437
1504 1438
if __name__ == "__main__":
    # Verify the Ludwig dependency is importable before doing any work;
    # bail out with a clear install hint otherwise.
    try:
        import ludwig
    except ImportError:
        logger.error(
            "Ludwig library not found. Please ensure Ludwig is installed "
            "('pip install ludwig[image]')"
        )
        sys.exit(1)
    else:
        logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")

    main()