# HG changeset patch
# User goeckslab
# Date 1765682832 0
# Node ID bbf30253c99f0800f1b9411ccb9fd060c2d6049b
# Parent  db9be962dc13398731eaad271704660ed1629b9a
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39

diff -r db9be962dc13 -r bbf30253c99f constants.py
--- a/constants.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/constants.py	Sun Dec 14 03:27:12 2025 +0000
@@ -174,6 +174,7 @@
 }
 METRIC_DISPLAY_NAMES = {
     "accuracy": "Accuracy",
+    "balanced_accuracy": "Balanced Accuracy",
     "accuracy_micro": "Micro Accuracy",
     "loss": "Loss",
     "roc_auc": "ROC-AUC",
diff -r db9be962dc13 -r bbf30253c99f image_learner.xml
--- a/image_learner.xml	Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner.xml	Sun Dec 14 03:27:12 2025 +0000
@@ -106,33 +106,26 @@
 [hunk body not recoverable: the XML element content was stripped in extraction, leaving only bare +/- markers]
diff -r db9be962dc13 -r bbf30253c99f image_learner_cli.py
--- a/image_learner_cli.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner_cli.py	Sun Dec 14 03:27:12 2025 +0000
@@ -145,26 +145,11 @@
     parser.add_argument(
         "--validation-metric",
         type=str,
-        default="roc_auc",
-        choices=[
-            "accuracy",
-            "loss",
-            "roc_auc",
-            "balanced_accuracy",
-            "precision",
-            "recall",
-            "f1",
-            "specificity",
-            "log_loss",
-            "pearson_r",
-            "mae",
-            "mse",
-            "rmse",
-            "mape",
-            "r2",
-            "explained_variance",
-        ],
-        help="Metric Ludwig uses to select the best model during training/validation.",
+        default=None,
+        help=(
+            "Metric Ludwig uses to select the best model during training/validation. "
+            "Leave unset to let the tool pick a default for the inferred task."
+        ),
     )
     parser.add_argument(
         "--target-column",
diff -r db9be962dc13 -r bbf30253c99f ludwig_backend.py
--- a/ludwig_backend.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/ludwig_backend.py	Sun Dec 14 03:27:12 2025 +0000
@@ -403,42 +403,38 @@
             # No explicit resize provided; keep for reporting purposes
             config_params.setdefault("image_size", "original")
 
-        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
-            """Pick a validation metric that Ludwig will accept for the resolved task."""
+        def _resolve_validation_metric(
+            task: str, requested: Optional[str], output_feature: Dict[str, Any]
+        ) -> Optional[str]:
+            """
+            Pick a validation metric that Ludwig will accept for the resolved task/output.
+            If the requested metric is invalid, fall back to a safe option or omit it entirely.
+            """
             default_map = {
-                "regression": "pearson_r",
+                "regression": "mean_squared_error",
                 "binary": "roc_auc",
                 "category": "accuracy",
             }
             allowed_map = {
                 "regression": {
-                    "pearson_r",
                     "mean_absolute_error",
                     "mean_squared_error",
                     "root_mean_squared_error",
-                    "mean_absolute_percentage_error",
-                    "r2",
-                    "explained_variance",
+                    "root_mean_squared_percentage_error",
                     "loss",
                 },
-                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
                 "binary": {
                     "roc_auc",
                     "accuracy",
                     "precision",
                     "recall",
                     "specificity",
-                    "log_loss",
                     "loss",
                 },
                 "category": {
                     "accuracy",
                     "balanced_accuracy",
-                    "precision",
-                    "recall",
-                    "f1",
-                    "specificity",
-                    "log_loss",
+                    "hits_at_k",
                     "loss",
                 },
             }
@@ -447,25 +443,64 @@
                     "mae": "mean_absolute_error",
                     "mse": "mean_squared_error",
                     "rmse": "root_mean_squared_error",
-                    "mape": "mean_absolute_percentage_error",
+                    "rmspe": "root_mean_squared_percentage_error",
+                },
+                "category": {},
+                "binary": {
+                    "roc_auc": "roc_auc",
                 },
             }
 
             default_metric = default_map.get(task)
-            allowed = allowed_map.get(task, set())
             metric = requested or default_metric
-
             if metric is None:
                 return None
 
             metric = alias_map.get(task, {}).get(metric, metric)
 
-            if metric not in allowed:
+            # Prefer Ludwig's own metric registry when available; intersect with known-safe sets.
+            registry_metrics = None
+            try:
+                from ludwig.features.feature_registries import output_type_registry
+
+                feature_cls = output_type_registry.get(output_feature.get("type"))
+                if feature_cls:
+                    feature_obj = feature_cls(feature=output_feature)
+                    metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr(
+                        feature_obj, "metrics", None
+                    )
+                    if isinstance(metrics_attr, dict):
+                        registry_metrics = set(metrics_attr.keys())
+            except Exception as exc:
+                logger.debug(
+                    "Could not inspect Ludwig metrics for output type %s: %s",
+                    output_feature.get("type"),
+                    exc,
+                )
+
+            allowed = set(allowed_map.get(task, set()))
+            if registry_metrics:
+                # Only keep metrics that Ludwig actually exposes for this output type;
+                # if the intersection is empty, fall back to the registry set.
+                intersected = allowed.intersection(registry_metrics)
+                allowed = intersected or registry_metrics
+
+            if allowed and metric not in allowed:
+                fallback_candidates = [
+                    default_metric if default_metric in allowed else None,
+                    "loss" if "loss" in allowed else None,
+                    next(iter(allowed), None),
+                ]
+                fallback = next((m for m in fallback_candidates if m in allowed), None)
                 if requested:
                     logger.warning(
-                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                        "Validation metric '%s' is not supported for %s outputs; %s",
+                        requested,
+                        task,
+                        (f"using '{fallback}' instead." if fallback else "omitting validation_metric."),
                     )
-                metric = default_metric
+                metric = fallback
+
             return metric
 
         if task_type == "regression":
@@ -475,7 +510,11 @@
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
+            val_metric = _resolve_validation_metric(
+                "regression",
+                config_params.get("validation_metric"),
+                output_feat,
+            )
 
         else:
             if num_unique_labels == 2:
@@ -495,6 +534,7 @@
             val_metric = _resolve_validation_metric(
                 "binary" if num_unique_labels == 2 else "category",
                 config_params.get("validation_metric"),
+                output_feat,
             )
 
         # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
@@ -610,7 +650,9 @@
             raise RuntimeError("Ludwig argument error.") from e
         except Exception:
             logger.error(
-                "LudwigDirectBackend: Experiment execution error.",
+                "LudwigDirectBackend: Experiment execution error. "
+                "If this relates to validation_metric, confirm the XML task selection "
+                "passes a metric that matches the inferred task type.",
                 exc_info=True,
             )
             raise