# image_learner: changeset 18:bbf30253c99f (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
| field | value |
|---|---|
| author | goeckslab |
| date | Sun, 14 Dec 2025 03:27:12 +0000 |
| parents | db9be962dc13 |
| children | |
| files | constants.py, image_learner.xml, image_learner_cli.py, ludwig_backend.py |
| diffstat | 4 files changed, 76 insertions(+), 55 deletions(-) |
```diff
--- a/constants.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/constants.py	Sun Dec 14 03:27:12 2025 +0000
@@ -174,6 +174,7 @@
 }
 METRIC_DISPLAY_NAMES = {
     "accuracy": "Accuracy",
+    "balanced_accuracy": "Balanced Accuracy",
     "accuracy_micro": "Micro Accuracy",
     "loss": "Loss",
     "roc_auc": "ROC-AUC",
```
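The new key only changes how the metric is labeled in reports; selection logic is untouched. A minimal sketch of how the mapping is presumably consulted when rendering results (the `dict.get` fallback to the raw key is an assumption, not part of this changeset):

```python
from constants import METRIC_DISPLAY_NAMES

# Map a raw Ludwig metric key to its human-readable report label.
# Falling back to the raw key for unknown metrics is an assumed convention.
raw_key = "balanced_accuracy"
label = METRIC_DISPLAY_NAMES.get(raw_key, raw_key)
print(label)  # Balanced Accuracy
```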
```diff
--- a/image_learner.xml	Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner.xml	Sun Dec 14 03:27:12 2025 +0000
@@ -106,33 +106,26 @@
                 <param name="validation_metric_binary" type="select" optional="true" label="Validation metric (binary)" help="Metrics accepted by Ludwig for binary outputs.">
                     <option value="roc_auc" selected="true">ROC-AUC</option>
                     <option value="accuracy">Accuracy</option>
-                    <option value="balanced_accuracy">Balanced Accuracy</option>
                     <option value="precision">Precision</option>
                     <option value="recall">Recall</option>
-                    <option value="f1">F1</option>
                     <option value="specificity">Specificity</option>
-                    <option value="log_loss">Log Loss</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
             <when value="classification">
                 <param name="validation_metric_multiclass" type="select" optional="true" label="Validation metric (multi-class)" help="Metrics accepted by Ludwig for multi-class outputs.">
                     <option value="accuracy" selected="true">Accuracy</option>
-                    <option value="roc_auc">ROC-AUC</option>
+                    <option value="balanced_accuracy">Balanced Accuracy</option>
+                    <option value="hits_at_k">Hits at K (top-k)</option>
                     <option value="loss">Loss</option>
-                    <option value="balanced_accuracy">Balanced Accuracy</option>
-                    <option value="precision">Precision</option>
-                    <option value="recall">Recall</option>
-                    <option value="f1">F1</option>
-                    <option value="specificity">Specificity</option>
-                    <option value="log_loss">Log Loss</option>
                 </param>
             </when>
             <when value="regression">
                 <param name="validation_metric_regression" type="select" optional="true" label="Validation metric (regression)" help="Metrics accepted by Ludwig for regression outputs.">
-                    <option value="mae" selected="true">MAE</option>
-                    <option value="mse">MSE</option>
-                    <option value="rmse">RMSE</option>
+                    <option value="mean_squared_error" selected="true">Mean Squared Error</option>
+                    <option value="mean_absolute_error">Mean Absolute Error</option>
+                    <option value="root_mean_squared_error">Root Mean Squared Error</option>
+                    <option value="root_mean_squared_percentage_error">Root Mean Squared Percentage Error</option>
                     <option value="loss">Loss</option>
                 </param>
             </when>
```
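The pruned option lists keep the XML in step with what the backend will accept for each task. A hedged sketch of where a selected value presumably lands in the Ludwig config (the output feature name `label` and the surrounding structure are illustrative; the real assembly happens in `ludwig_backend.py`):

```python
# Illustrative only: how a chosen <option value="..."> maps onto Ludwig's
# trainer section. "label" is a placeholder output feature name.
config = {
    "output_features": [{"name": "label", "type": "category"}],
    "trainer": {
        "validation_field": "label",
        # Must be one of the option values offered for the selected task,
        # e.g. "balanced_accuracy" or "hits_at_k" for multi-class outputs.
        "validation_metric": "balanced_accuracy",
    },
}
```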
```diff
--- a/image_learner_cli.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/image_learner_cli.py	Sun Dec 14 03:27:12 2025 +0000
@@ -145,26 +145,11 @@
     parser.add_argument(
         "--validation-metric",
         type=str,
-        default="roc_auc",
-        choices=[
-            "accuracy",
-            "loss",
-            "roc_auc",
-            "balanced_accuracy",
-            "precision",
-            "recall",
-            "f1",
-            "specificity",
-            "log_loss",
-            "pearson_r",
-            "mae",
-            "mse",
-            "rmse",
-            "mape",
-            "r2",
-            "explained_variance",
-        ],
-        help="Metric Ludwig uses to select the best model during training/validation.",
+        default=None,
+        help=(
+            "Metric Ludwig uses to select the best model during training/validation. "
+            "Leave unset to let the tool pick a default for the inferred task."
+        ),
     )
     parser.add_argument(
         "--target-column",
```
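With the hard-coded `choices` list gone, validation moves entirely into the backend resolver. A small self-contained sketch of the new flag behavior, mirroring the argument definition above (the parse calls are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--validation-metric",
    type=str,
    default=None,
    help=(
        "Metric Ludwig uses to select the best model during training/validation. "
        "Leave unset to let the tool pick a default for the inferred task."
    ),
)

# Unset: the backend later resolves a task default (e.g. roc_auc for binary).
assert parser.parse_args([]).validation_metric is None
# Set: the value is validated downstream in _resolve_validation_metric.
assert parser.parse_args(["--validation-metric", "hits_at_k"]).validation_metric == "hits_at_k"
```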
```diff
--- a/ludwig_backend.py	Wed Dec 10 00:24:13 2025 +0000
+++ b/ludwig_backend.py	Sun Dec 14 03:27:12 2025 +0000
@@ -403,42 +403,38 @@
         # No explicit resize provided; keep for reporting purposes
         config_params.setdefault("image_size", "original")
 
-        def _resolve_validation_metric(task: str, requested: Optional[str]) -> Optional[str]:
-            """Pick a validation metric that Ludwig will accept for the resolved task."""
+        def _resolve_validation_metric(
+            task: str, requested: Optional[str], output_feature: Dict[str, Any]
+        ) -> Optional[str]:
+            """
+            Pick a validation metric that Ludwig will accept for the resolved task/output.
+            If the requested metric is invalid, fall back to a safe option or omit it entirely.
+            """
             default_map = {
-                "regression": "pearson_r",
+                "regression": "mean_squared_error",
                 "binary": "roc_auc",
                 "category": "accuracy",
             }
             allowed_map = {
                 "regression": {
-                    "pearson_r",
                     "mean_absolute_error",
                     "mean_squared_error",
                     "root_mean_squared_error",
-                    "mean_absolute_percentage_error",
-                    "r2",
-                    "explained_variance",
+                    "root_mean_squared_percentage_error",
                     "loss",
                 },
-                # Ludwig rejects f1 and balanced_accuracy for binary outputs; keep to known-safe set.
                 "binary": {
                     "roc_auc",
                     "accuracy",
                     "precision",
                     "recall",
                     "specificity",
-                    "log_loss",
                     "loss",
                 },
                 "category": {
                     "accuracy",
                     "balanced_accuracy",
-                    "precision",
-                    "recall",
-                    "f1",
-                    "specificity",
-                    "log_loss",
+                    "hits_at_k",
                     "loss",
                 },
             }
@@ -447,25 +443,64 @@
                     "mae": "mean_absolute_error",
                     "mse": "mean_squared_error",
                     "rmse": "root_mean_squared_error",
-                    "mape": "mean_absolute_percentage_error",
+                    "rmspe": "root_mean_squared_percentage_error",
+                },
+                "category": {},
+                "binary": {
+                    "roc_auc": "roc_auc",
                 },
             }
 
             default_metric = default_map.get(task)
-            allowed = allowed_map.get(task, set())
             metric = requested or default_metric
-
             if metric is None:
                 return None
             metric = alias_map.get(task, {}).get(metric, metric)
 
-            if metric not in allowed:
+            # Prefer Ludwig's own metric registry when available; intersect with known-safe sets.
+            registry_metrics = None
+            try:
+                from ludwig.features.feature_registries import output_type_registry
+
+                feature_cls = output_type_registry.get(output_feature.get("type"))
+                if feature_cls:
+                    feature_obj = feature_cls(feature=output_feature)
+                    metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr(
+                        feature_obj, "metrics", None
+                    )
+                    if isinstance(metrics_attr, dict):
+                        registry_metrics = set(metrics_attr.keys())
+            except Exception as exc:
+                logger.debug(
+                    "Could not inspect Ludwig metrics for output type %s: %s",
+                    output_feature.get("type"),
+                    exc,
+                )
+
+            allowed = set(allowed_map.get(task, set()))
+            if registry_metrics:
+                # Only keep metrics that Ludwig actually exposes for this output type;
+                # if the intersection is empty, fall back to the registry set.
+                intersected = allowed.intersection(registry_metrics)
+                allowed = intersected or registry_metrics
+
+            if allowed and metric not in allowed:
+                fallback_candidates = [
+                    default_metric if default_metric in allowed else None,
+                    "loss" if "loss" in allowed else None,
+                    next(iter(allowed), None),
+                ]
+                fallback = next((m for m in fallback_candidates if m in allowed), None)
                 if requested:
                     logger.warning(
-                        f"Validation metric '{requested}' is not supported for {task} outputs; using '{default_metric}' instead."
+                        "Validation metric '%s' is not supported for %s outputs; %s",
+                        requested,
+                        task,
+                        (f"using '{fallback}' instead." if fallback else "omitting validation_metric."),
                     )
-                metric = default_metric
+                metric = fallback
+
             return metric
 
         if task_type == "regression":
@@ -475,7 +510,11 @@
                 "decoder": {"type": "regressor"},
                 "loss": {"type": "mean_squared_error"},
             }
-            val_metric = _resolve_validation_metric("regression", config_params.get("validation_metric"))
+            val_metric = _resolve_validation_metric(
+                "regression",
+                config_params.get("validation_metric"),
+                output_feat,
+            )
 
         else:
             if num_unique_labels == 2:
@@ -495,6 +534,7 @@
             val_metric = _resolve_validation_metric(
                 "binary" if num_unique_labels == 2 else "category",
                 config_params.get("validation_metric"),
+                output_feat,
             )
 
             # Propagate the resolved validation metric (including any task-based fallback or alias normalization)
@@ -610,7 +650,9 @@
             raise RuntimeError("Ludwig argument error.") from e
         except Exception:
             logger.error(
-                "LudwigDirectBackend: Experiment execution error.",
+                "LudwigDirectBackend: Experiment execution error. "
+                "If this relates to validation_metric, confirm the XML task selection "
+                "passes a metric that matches the inferred task type.",
                 exc_info=True,
            )
            raise
```
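To make the fallback cascade concrete, here is a self-contained sketch of the selection order introduced above, with the Ludwig registry lookup replaced by a plain `allowed` set (the name `pick_metric` and the example values are illustrative, not part of the changeset):

```python
DEFAULT_MAP = {
    "regression": "mean_squared_error",
    "binary": "roc_auc",
    "category": "accuracy",
}

def pick_metric(task: str, requested: str | None, allowed: set[str]) -> str | None:
    """Simplified mirror of the cascade in _resolve_validation_metric."""
    metric = requested or DEFAULT_MAP.get(task)
    if metric is None or not allowed or metric in allowed:
        return metric
    # Same candidate order as the diff: task default, then loss, then anything allowed.
    default_metric = DEFAULT_MAP.get(task)
    candidates = [
        default_metric if default_metric in allowed else None,
        "loss" if "loss" in allowed else None,
        next(iter(allowed), None),
    ]
    return next((m for m in candidates if m in allowed), None)

# "f1" is no longer offered for multi-class outputs, so it falls back to "accuracy":
print(pick_metric("category", "f1", {"accuracy", "balanced_accuracy", "hits_at_k", "loss"}))
```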
