comparison ludwig_backend.py @ 18:bbf30253c99f draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
author goeckslab
date Sun, 14 Dec 2025 03:27:12 +0000
parents db9be962dc13
children
--- ludwig_backend.py  17:db9be962dc13
+++ ludwig_backend.py  18:bbf30253c99f
@@ -401,83 +401,122 @@
             logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing")
         elif not is_metaformer:
             # No explicit resize provided; keep for reporting purposes
             config_params.setdefault("image_size", "original")
 
-        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
-            """Pick a validation metric that Ludwig will accept for the resolved task."""
+        def _resolve_validation_metric(
+            task: str, requested: Optional[str], output_feature: Dict[str, Any]
+        ) -> Optional[str]:
+            """
+            Pick a validation metric that Ludwig will accept for the resolved task/output.
+            If the requested metric is invalid, fall back to a safe option or omit it entirely.
+            """
             default_map = {
-                "regression": "pearson_r",
+                "regression": "mean_squared_error",
                 "binary": "roc_auc",
                 "category": "accuracy",
             }
             allowed_map = {
                 "regression": {
-                    "pearson_r",
                     "mean_absolute_error",
                     "mean_squared_error",
                     "root_mean_squared_error",
-                    "mean_absolute_percentage_error",
-                    "r2",
-                    "explained_variance",
+                    "root_mean_squared_percentage_error",
                     "loss",
                 },
-                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
                "binary": {
                     "roc_auc",
                     "accuracy",
                     "precision",
                     "recall",
                     "specificity",
-                    "log_loss",
                     "loss",
                 },
                 "category": {
                     "accuracy",
                     "balanced_accuracy",
-                    "precision",
-                    "recall",
-                    "f1",
-                    "specificity",
-                    "log_loss",
+                    "hits_at_k",
                     "loss",
                 },
             }
             alias_map = {
                 "regression": {
                     "mae": "mean_absolute_error",
                     "mse": "mean_squared_error",
                     "rmse": "root_mean_squared_error",
-                    "mape": "mean_absolute_percentage_error",
+                    "rmspe": "root_mean_squared_percentage_error",
+                },
+                "category": {},
+                "binary": {
+                    "roc_auc": "roc_auc",
                 },
             }
 
             default_metric = default_map.get(task)
-            allowed = allowed_map.get(task, set())
             metric = requested or default_metric
-
             if metric is None:
                 return None
 
             metric = alias_map.get(task, {}).get(metric, metric)
 
-            if metric not in allowed:
+            # Prefer Ludwig's own metric registry when available; intersect with known-safe sets.
+            registry_metrics = None
+            try:
+                from ludwig.features.feature_registries import output_type_registry
+
+                feature_cls = output_type_registry.get(output_feature.get("type"))
+                if feature_cls:
+                    feature_obj = feature_cls(feature=output_feature)
+                    metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr(
+                        feature_obj, "metrics", None
+                    )
+                    if isinstance(metrics_attr, dict):
+                        registry_metrics = set(metrics_attr.keys())
+            except Exception as exc:
+                logger.debug(
+                    "Could not inspect Ludwig metrics for output type %s: %s",
+                    output_feature.get("type"),
+                    exc,
+                )
+
+            allowed = set(allowed_map.get(task, set()))
+            if registry_metrics:
+                # Only keep metrics that Ludwig actually exposes for this output type;
+                # if the intersection is empty, fall back to the registry set.
+                intersected = allowed.intersection(registry_metrics)
+                allowed = intersected or registry_metrics
+
+            if allowed and metric not in allowed:
+                fallback_candidates = [
+                    default_metric if default_metric in allowed else None,
+                    "loss" if "loss" in allowed else None,
+                    next(iter(allowed), None),
+                ]
+                fallback = next((m for m in fallback_candidates if m in allowed), None)
                 if requested:
                     logger.warning(
-                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                        "Validation metric '%s' is not supported for %s outputs; %s",
+                        requested,
+                        task,
+                        (f"using '{fallback}' instead." if fallback else "omitting validation_metric."),
                     )
-                metric = default_metric
+                metric = fallback
+
             return metric
 
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
                 "type": "number",
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
+            val_metric = _resolve_validation_metric(
+                "regression",
+                config_params.get("validation_metric"),
+                output_feat,
+            )
 
         else:
             if num_unique_labels == 2:
                 output_feat = {
                     "name": LABEL_COLUMN_NAME,
493 "loss": {"type": "softmax_cross_entropy"}, 532 "loss": {"type": "softmax_cross_entropy"},
494 } 533 }
495 val_metric = _resolve_validation_metric( 534 val_metric = _resolve_validation_metric(
496 "binary" if num_unique_labels == 2 else "category", 535 "binary" if num_unique_labels == 2 else "category",
497 config_params.get("validation_metric"), 536 config_params.get("validation_metric"),
537 output_feat,
498 ) 538 )
499 539
500 # Propagate the resolved validation metric (including any task-based fallback or alias normalization) 540 # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
501 config_params["validation_metric"] = val_metric 541 config_params["validation_metric"] = val_metric
502 542
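
For context on where the propagated value ends up: in a Ludwig config, validation_metric belongs to the trainer section. The snippet below is an illustration, not output of this backend (the actual config assembly is outside this diff); the field names are standard Ludwig config keys and the values are examples.

# Illustration only: a standard Ludwig config fragment showing where a resolved
# validation_metric typically sits. Not generated by the code in this diff.
ludwig_config_fragment = {
    "output_features": [{"name": "label", "type": "binary"}],
    "trainer": {
        "validation_field": "label",      # which output feature to validate on
        "validation_metric": "roc_auc",   # e.g. the value written into config_params above
    },
}
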
@@ -608,11 +648,13 @@
                 exc_info=True,
             )
             raise RuntimeError("Ludwig argument error.") from e
         except Exception:
             logger.error(
-                "LudwigDirectBackend: Experiment execution error.",
+                "LudwigDirectBackend: Experiment execution error. "
+                "If this relates to validation_metric, confirm the XML task selection "
+                "passes a metric that matches the inferred task type.",
                 exc_info=True,
             )
             raise
 
     def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
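
The extended error message above points users at validation_metric/task mismatches. When adapting this change to a different Ludwig release, the registry probe from the first hunk is the part most likely to drift, so it can help to check what the installed version actually exposes. The probe below is a standalone sketch, not repository code: the probe_metrics name and the minimal feature dict are illustrative, and the metric_functions / metrics attribute names are treated as unstable, which is why everything stays inside a broad try/except, just as in the diff.

# Standalone sketch: list the metric names Ludwig registers for a few output types,
# reusing only the calls that appear in the first hunk. Attribute names may differ
# between Ludwig versions, so every step degrades gracefully.
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("metric_probe")

def probe_metrics(output_type: str) -> None:
    try:
        from ludwig.features.feature_registries import output_type_registry

        feature_cls = output_type_registry.get(output_type)
        if feature_cls is None:
            log.info("No feature class registered for %s", output_type)
            return
        # Assumption: a minimal feature dict is enough to instantiate the class,
        # as in the diff; some Ludwig versions may require additional keys.
        feature_obj = feature_cls(feature={"name": "label", "type": output_type})
        metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr(
            feature_obj, "metrics", None
        )
        if isinstance(metrics_attr, dict):
            log.info("%s metrics: %s", output_type, sorted(metrics_attr))
        else:
            log.info("Could not read metric names for %s", output_type)
    except Exception as exc:  # mirror the diff: tolerate any version drift
        log.info("Registry probe failed for %s: %s", output_type, exc)

for output_type in ("binary", "category", "number"):
    probe_metrics(output_type)
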