Mercurial repository image_learner: comparison of ludwig_backend.py @ 18:bbf30253c99f (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
| author | goeckslab |
|---|---|
| date | Sun, 14 Dec 2025 03:27:12 +0000 |
| parents | db9be962dc13 |
| children | |
| 17:db9be962dc13 | 18:bbf30253c99f |
|---|---|
| 401 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") | 401 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") |
| 402 elif not is_metaformer: | 402 elif not is_metaformer: |
| 403 # No explicit resize provided; keep for reporting purposes | 403 # No explicit resize provided; keep for reporting purposes |
| 404 config_params.setdefault("image_size", "original") | 404 config_params.setdefault("image_size", "original") |
| 405 | 405 |
| 406 def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]: | 406 def _resolve_validation_metric( |
| 407 """Pick a validation metric that Ludwig will accept for the resolved task.""" | 407 task: str, requested: Optional[str], output_feature: Dict[str, Any] |
| | 408 ) -> Optional[str]: |
| | 409 """ |
| | 410 Pick a validation metric that Ludwig will accept for the resolved task/output. |
| | 411 If the requested metric is invalid, fall back to a safe option or omit it entirely. |
| | 412 """ |
| 408 default_map = { | 413 default_map = { |
| 409 "regression": "pearson_r", | 414 "regression": "mean_squared_error", |
| 410 "binary": "roc_auc", | 415 "binary": "roc_auc", |
| 411 "category": "accuracy", | 416 "category": "accuracy", |
| 412 } | 417 } |
| 413 allowed_map = { | 418 allowed_map = { |
| 414 "regression": { | 419 "regression": { |
| 415 "pearson_r", | |
| 416 "mean_absolute_error", | 420 "mean_absolute_error", |
| 417 "mean_squared_error", | 421 "mean_squared_error", |
| 418 "root_mean_squared_error", | 422 "root_mean_squared_error", |
| 419 "mean_absolute_percentage_error", | 423 "root_mean_squared_percentage_error", |
| 420 "r2", | |
| 421 "explained_variance", | |
| 422 "loss", | 424 "loss", |
| 423 }, | 425 }, |
| 424 # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set. | |
| 425 "binary": { | 426 "binary": { |
| 426 "roc_auc", | 427 "roc_auc", |
| 427 "accuracy", | 428 "accuracy", |
| 428 "precision", | 429 "precision", |
| 429 "recall", | 430 "recall", |
| 430 "specificity", | 431 "specificity", |
| 431 "log_loss", | |
| 432 "loss", | 432 "loss", |
| 433 }, | 433 }, |
| 434 "category": { | 434 "category": { |
| 435 "accuracy", | 435 "accuracy", |
| 436 "balanced_accuracy", | 436 "balanced_accuracy", |
| 437 "precision", | 437 "hits_at_k", |
| 438 "recall", | |
| 439 "f1", | |
| 440 "specificity", | |
| 441 "log_loss", | |
| 442 "loss", | 438 "loss", |
| 443 }, | 439 }, |
| 444 } | 440 } |
| 445 alias_map = { | 441 alias_map = { |
| 446 "regression": { | 442 "regression": { |
| 447 "mae": "mean_absolute_error", | 443 "mae": "mean_absolute_error", |
| 448 "mse": "mean_squared_error", | 444 "mse": "mean_squared_error", |
| 449 "rmse": "root_mean_squared_error", | 445 "rmse": "root_mean_squared_error", |
| 450 "mape": "mean_absolute_percentage_error", | 446 "rmspe": "root_mean_squared_percentage_error", |
| | 447 }, |
| | 448 "category": {}, |
| | 449 "binary": { |
| | 450 "roc_auc": "roc_auc", |
| 451 }, | 451 }, |
| 452 } | 452 } |
| 453 | 453 |
| 454 default_metric = default_map.get(task) | 454 default_metric = default_map.get(task) |
| 455 allowed = allowed_map.get(task, set()) | |
| 456 metric = requested or default_metric | 455 metric = requested or default_metric |
| 457 | |
| 458 if metric is None: | 456 if metric is None: |
| 459 return None | 457 return None |
| 460 | 458 |
| 461 metric = alias_map.get(task, {}).get(metric, metric) | 459 metric = alias_map.get(task, {}).get(metric, metric) |
| 462 | 460 |
| 463 if metric not in allowed: | 461 # Prefer Ludwig's own metric registry when available; intersect with known-safe sets. |
| | 462 registry_metrics = None |
| | 463 try: |
| | 464 from ludwig.features.feature_registries import output_type_registry |
| | 465 |
| | 466 feature_cls = output_type_registry.get(output_feature.get("type")) |
| | 467 if feature_cls: |
| | 468 feature_obj = feature_cls(feature=output_feature) |
| | 469 metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr( |
| | 470 feature_obj, "metrics", None |
| | 471 ) |
| | 472 if isinstance(metrics_attr, dict): |
| | 473 registry_metrics = set(metrics_attr.keys()) |
| | 474 except Exception as exc: |
| | 475 logger.debug( |
| | 476 "Could not inspect Ludwig metrics for output type %s: %s", |
| | 477 output_feature.get("type"), |
| | 478 exc, |
| | 479 ) |
| | 480 |
| | 481 allowed = set(allowed_map.get(task, set())) |
| | 482 if registry_metrics: |
| | 483 # Only keep metrics that Ludwig actually exposes for this output type; |
| | 484 # if the intersection is empty, fall back to the registry set. |
| | 485 intersected = allowed.intersection(registry_metrics) |
| | 486 allowed = intersected or registry_metrics |
| | 487 |
| | 488 if allowed and metric not in allowed: |
| | 489 fallback_candidates = [ |
| | 490 default_metric if default_metric in allowed else None, |
| | 491 "loss" if "loss" in allowed else None, |
| | 492 next(iter(allowed), None), |
| | 493 ] |
| | 494 fallback = next((m for m in fallback_candidates if m in allowed), None) |
| 464 if requested: | 495 if requested: |
| 465 logger.warning( | 496 logger.warning( |
| 466 f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead." | 497 "Validation metric '%s' is not supported for %s outputs; %s", |
| | 498 requested, |
| | 499 task, |
| | 500 (f"using '{fallback}' instead." if fallback else "omitting validation_metric."), |
| 467 ) | 501 ) |
| 468 metric = default_metric | 502 metric = fallback |
| | 503 |
| 469 return metric | 504 return metric |
| 470 | 505 |
| 471 if task_type == "regression": | 506 if task_type == "regression": |
| 472 output_feat = { | 507 output_feat = { |
| 473 "name": LABEL_COLUMN_NAME, | 508 "name": LABEL_COLUMN_NAME, |
| 474 "type": "number", | 509 "type": "number", |
| 475 "decoder": {"type": "regressor"}, | 510 "decoder": {"type": "regressor"}, |
| 476 "loss": {"type": "mean_squared_error"}, | 511 "loss": {"type": "mean_squared_error"}, |
| 477 } | 512 } |
| 478 val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric")) | 513 val_metric = _resolve_validation_metric( |
| 514 "regression", | |
| 515 config_params.get("validation_metric"), | |
| 516 output_feat, | |
| 517 ) | |
| 479 | 518 |
| 480 else: | 519 else: |
| 481 if num_unique_labels == 2: | 520 if num_unique_labels == 2: |
| 482 output_feat = { | 521 output_feat = { |
| 483 "name": LABEL_COLUMN_NAME, | 522 "name": LABEL_COLUMN_NAME, |
| 493 "loss": {"type": "softmax_cross_entropy"}, | 532 "loss": {"type": "softmax_cross_entropy"}, |
| 494 } | 533 } |
| 495 val_metric = _resolve_validation_metric( | 534 val_metric = _resolve_validation_metric( |
| 496 "binary" if num_unique_labels == 2 else "category", | 535 "binary" if num_unique_labels == 2 else "category", |
| 497 config_params.get("validation_metric"), | 536 config_params.get("validation_metric"), |
| | 537 output_feat, |
| 498 ) | 538 ) |
| 499 | 539 |
| 500 # Propagate the resolved validation metric (including any task-based fallback or alias normalization) | 540 # Propagate the resolved validation metric (including any task-based fallback or alias normalization) |
| 501 config_params["validation_metric"] = val_metric | 541 config_params["validation_metric"] = val_metric |
| 502 | 542 |
| 608 exc_info=True, | 648 exc_info=True, |
| 609 ) | 649 ) |
| 610 raise RuntimeError("Ludwig argument error.") from e | 650 raise RuntimeError("Ludwig argument error.") from e |
| 611 except Exception: | 651 except Exception: |
| 612 logger.error( | 652 logger.error( |
| 613 "LudwigDirectBackend: Experiment execution error.", | 653 "LudwigDirectBackend: Experiment execution error. " |
| 654 "If this relates to validation_metric, confirm the XML task selection " | |
| 655 "passes a metric that matches the inferred task type.", | |
| 614 exc_info=True, | 656 exc_info=True, |
| 615 ) | 657 ) |
| 616 raise | 658 raise |
| 617 | 659 |
| 618 def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]: | 660 def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]: |
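
For context, the sketch below is not taken from ludwig_backend.py; names such as `resolve_metric`, `ALLOWED`, `DEFAULTS`, and `ALIASES` are illustrative. It mirrors the fallback order this changeset introduces in `_resolve_validation_metric`: normalize aliases, check the metric against the allowed set for the task, then fall back to the task default, then `loss`, then any remaining allowed metric. The real function additionally intersects the allowed set with whatever metrics Ludwig's `output_type_registry` exposes for the output feature type.

```python
# Minimal standalone sketch of the validation-metric fallback introduced in
# this changeset. These dicts are a simplified copy of the changeset's
# allowed/default/alias maps, not the backend's actual data structures.
from typing import Optional

ALLOWED = {
    "regression": {"mean_absolute_error", "mean_squared_error",
                   "root_mean_squared_error",
                   "root_mean_squared_percentage_error", "loss"},
    "binary": {"roc_auc", "accuracy", "precision", "recall",
               "specificity", "loss"},
    "category": {"accuracy", "balanced_accuracy", "hits_at_k", "loss"},
}
DEFAULTS = {"regression": "mean_squared_error", "binary": "roc_auc",
            "category": "accuracy"}
ALIASES = {"regression": {"mae": "mean_absolute_error",
                          "mse": "mean_squared_error",
                          "rmse": "root_mean_squared_error",
                          "rmspe": "root_mean_squared_percentage_error"}}


def resolve_metric(task: str, requested: Optional[str]) -> Optional[str]:
    # Use the requested metric if given, otherwise the task default.
    metric = requested or DEFAULTS.get(task)
    if metric is None:
        return None
    # Normalize short aliases such as "rmse" to Ludwig's canonical names.
    metric = ALIASES.get(task, {}).get(metric, metric)
    allowed = ALLOWED.get(task, set())
    if allowed and metric not in allowed:
        # Same preference order as the changeset: task default, then "loss",
        # then any metric remaining in the allowed set.
        for fallback in (DEFAULTS.get(task), "loss", next(iter(allowed), None)):
            if fallback in allowed:
                return fallback
        return None
    return metric


print(resolve_metric("regression", "rmse"))       # root_mean_squared_error
print(resolve_metric("regression", "pearson_r"))  # mean_squared_error (fallback)
print(resolve_metric("category", "f1"))           # accuracy (fallback)
```

The point of the fallback, as the expanded error message in the diff suggests, is to keep an unsupported `validation_metric` passed through the tool's XML task selection from aborting the Ludwig experiment: the metric is replaced with a safe choice (or omitted) and a warning is logged instead.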
