# HG changeset patch
# User goeckslab
# Date 1765682832 0
# Node ID bbf30253c99f0800f1b9411ccb9fd060c2d6049b
# Parent db9be962dc13398731eaad271704660ed1629b9a
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
diff -r db9be962dc13 -r bbf30253c99f constants.py
--- a/constants.py Wed Dec 10 00:24:13 2025 +0000
+++ b/constants.py Sun Dec 14 03:27:12 2025 +0000
@@ -174,6 +174,7 @@
}
METRIC_DISPLAY_NAMES = {
"accuracy": "Accuracy",
+ "balanced_accuracy": "Balanced Accuracy",
"accuracy_micro": "Micro Accuracy",
"loss": "Loss",
"roc_auc": "ROC-AUC",
diff -r db9be962dc13 -r bbf30253c99f image_learner.xml
--- a/image_learner.xml Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner.xml Sun Dec 14 03:27:12 2025 +0000
@@ -106,33 +106,26 @@
[hunk body lost in extraction: the XML markup on these lines was stripped, leaving only bare +/- markers]
diff -r db9be962dc13 -r bbf30253c99f image_learner_cli.py
--- a/image_learner_cli.py Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner_cli.py Sun Dec 14 03:27:12 2025 +0000
@@ -145,26 +145,11 @@
parser.add_argument(
"--validation-metric",
type=str,
- default="roc_auc",
- choices=[
- "accuracy",
- "loss",
- "roc_auc",
- "balanced_accuracy",
- "precision",
- "recall",
- "f1",
- "specificity",
- "log_loss",
- "pearson_r",
- "mae",
- "mse",
- "rmse",
- "mape",
- "r2",
- "explained_variance",
- ],
- help="Metric Ludwig uses to select the best model during training/validation.",
+ default=None,
+ help=(
+ "Metric Ludwig uses to select the best model during training/validation. "
+ "Leave unset to let the tool pick a default for the inferred task."
+ ),
)
parser.add_argument(
"--target-column",
diff -r db9be962dc13 -r bbf30253c99f ludwig_backend.py
--- a/ludwig_backend.py Wed Dec 10 00:24:13 2025 +0000
+++ b/ludwig_backend.py Sun Dec 14 03:27:12 2025 +0000
@@ -403,42 +403,38 @@
# No explicit resize provided; keep for reporting purposes
config_params.setdefault("image_size", "original")
- def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
- """Pick a validation metric that Ludwig will accept for the resolved task."""
+ def _resolve_validation_metric(
+ task: str, requested: Optional[str], output_feature: Dict[str, Any]
+ ) -> Optional[str]:
+ """
+ Pick a validation metric that Ludwig will accept for the resolved task/output.
+ If the requested metric is invalid, fall back to a safe option or omit it entirely.
+ """
default_map = {
- "regression": "pearson_r",
+ "regression": "mean_squared_error",
"binary": "roc_auc",
"category": "accuracy",
}
allowed_map = {
"regression": {
- "pearson_r",
"mean_absolute_error",
"mean_squared_error",
"root_mean_squared_error",
- "mean_absolute_percentage_error",
- "r2",
- "explained_variance",
+ "root_mean_squared_percentage_error",
"loss",
},
- # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
"binary": {
"roc_auc",
"accuracy",
"precision",
"recall",
"specificity",
- "log_loss",
"loss",
},
"category": {
"accuracy",
"balanced_accuracy",
- "precision",
- "recall",
- "f1",
- "specificity",
- "log_loss",
+ "hits_at_k",
"loss",
},
}
@@ -447,25 +443,64 @@
"mae": "mean_absolute_error",
"mse": "mean_squared_error",
"rmse": "root_mean_squared_error",
- "mape": "mean_absolute_percentage_error",
+ "rmspe": "root_mean_squared_percentage_error",
+ },
+ "category": {},
+ "binary": {
+ "roc_auc": "roc_auc",
},
}
default_metric = default_map.get(task)
- allowed = allowed_map.get(task, set())
metric = requested or default_metric
-
if metric is None:
return None
metric = alias_map.get(task, {}).get(metric, metric)
- if metric not in allowed:
+ # Prefer Ludwig's own metric registry when available; intersect with known-safe sets.
+ registry_metrics = None
+ try:
+ from ludwig.features.feature_registries import output_type_registry
+
+ feature_cls = output_type_registry.get(output_feature.get("type"))
+ if feature_cls:
+ feature_obj = feature_cls(feature=output_feature)
+ metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr(
+ feature_obj, "metrics", None
+ )
+ if isinstance(metrics_attr, dict):
+ registry_metrics = set(metrics_attr.keys())
+ except Exception as exc:
+ logger.debug(
+ "Could not inspect Ludwig metrics for output type %s: %s",
+ output_feature.get("type"),
+ exc,
+ )
+
+ allowed = set(allowed_map.get(task, set()))
+ if registry_metrics:
+ # Only keep metrics that Ludwig actually exposes for this output type;
+ # if the intersection is empty, fall back to the registry set.
+ intersected = allowed.intersection(registry_metrics)
+ allowed = intersected or registry_metrics
+
+ if allowed and metric not in allowed:
+ fallback_candidates = [
+ default_metric if default_metric in allowed else None,
+ "loss" if "loss" in allowed else None,
+ next(iter(allowed), None),
+ ]
+ fallback = next((m for m in fallback_candidates if m in allowed), None)
if requested:
logger.warning(
- f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+ "Validation metric '%s' is not supported for %s outputs; %s",
+ requested,
+ task,
+ (f"using '{fallback}' instead." if fallback else "omitting validation_metric."),
)
- metric = default_metric
+ metric = fallback
+
return metric
if task_type == "regression":
@@ -475,7 +510,11 @@
"decoder": {"type": "regressor"},
"loss": {"type": "mean_squared_error"},
}
- val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
+ val_metric = _resolve_validation_metric(
+ "regression",
+ config_params.get("validation_metric"),
+ output_feat,
+ )
else:
if num_unique_labels == 2:
@@ -495,6 +534,7 @@
val_metric = _resolve_validation_metric(
"binary" if num_unique_labels == 2 else "category",
config_params.get("validation_metric"),
+ output_feat,
)
# Propagate the resolved validation metric (including any task-based fallback or alias normalization)
@@ -610,7 +650,9 @@
raise RuntimeError("Ludwig argument error.") from e
except Exception:
logger.error(
- "LudwigDirectBackend: Experiment execution error.",
+ "LudwigDirectBackend: Experiment execution error. "
+ "If this relates to validation_metric, confirm the XML task selection "
+ "passes a metric that matches the inferred task type.",
exc_info=True,
)
raise
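
To make the new fallback chain concrete, here is a self-contained sketch of the resolution order the ludwig_backend.py hunks implement (alias normalization, allowed-set check, then task default, then "loss", then any allowed metric). The Ludwig registry probe is deliberately omitted, so this is an approximation of the shipped function, not a copy of it.

    # Simplified re-implementation of the fallback chain in
    # _resolve_validation_metric; the maps mirror the patched values above.
    from typing import Optional

    DEFAULTS = {"regression": "mean_squared_error", "binary": "roc_auc",
                "category": "accuracy"}
    ALLOWED = {
        "regression": {"mean_absolute_error", "mean_squared_error",
                       "root_mean_squared_error",
                       "root_mean_squared_percentage_error", "loss"},
        "binary": {"roc_auc", "accuracy", "precision", "recall",
                   "specificity", "loss"},
        "category": {"accuracy", "balanced_accuracy", "hits_at_k", "loss"},
    }
    ALIASES = {"regression": {"mae": "mean_absolute_error",
                              "mse": "mean_squared_error",
                              "rmse": "root_mean_squared_error",
                              "rmspe": "root_mean_squared_percentage_error"}}

    def resolve(task: str, requested: Optional[str]) -> Optional[str]:
        metric = requested or DEFAULTS.get(task)
        if metric is None:
            return None
        metric = ALIASES.get(task, {}).get(metric, metric)
        allowed = ALLOWED.get(task, set())
        if allowed and metric not in allowed:
            # Fallback chain: task default, then "loss", then anything allowed.
            for candidate in (DEFAULTS.get(task), "loss",
                              next(iter(allowed), None)):
                if candidate in allowed:
                    return candidate
            return None
        return metric

    assert resolve("regression", "mse") == "mean_squared_error"  # alias path
    assert resolve("binary", "f1") == "roc_auc"                  # fallback path
    assert resolve("category", None) == "accuracy"               # default path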