diff ludwig_backend.py @ 16:8729f69e9207 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
author goeckslab
date Wed, 03 Dec 2025 01:28:52 +0000
parents d17e3a1b8659
children
line wrap: on
line diff
--- a/ludwig_backend.py	Fri Nov 28 15:45:49 2025 +0000
+++ b/ludwig_backend.py	Wed Dec 03 01:28:52 2025 +0000
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional, Protocol, Tuple
 
@@ -162,7 +163,17 @@
                 custom_model = model_name
 
             logger.info(f"DETECTED MetaFormer model: {custom_model}")
+            # Stash the model name for patched Stacked2DCNN in case Ludwig drops custom_model from kwargs
+            try:
+                from MetaFormer.metaformer_stacked_cnn import set_current_metaformer_model
+
+                set_current_metaformer_model(custom_model)
+            except Exception:
+                logger.debug("Could not set current MetaFormer model hint; proceeding without global override")
+            # Also export via the environment so processes spawned after this point (e.g., Ray workers) inherit it
+            os.environ["GLEAM_META_FORMER_MODEL"] = custom_model
             cfg_channels, cfg_height, cfg_width = 3, 224, 224
+            model_cfg = {}
             if META_DEFAULT_CFGS:
                 model_cfg = META_DEFAULT_CFGS.get(custom_model, {})
                 input_size = model_cfg.get("input_size")
@@ -173,7 +184,22 @@
                         int(input_size[2]),
                     )
 
-            target_height, target_width = cfg_height, cfg_width
+            weights_url = None
+            if isinstance(model_cfg, dict):
+                weights_url = model_cfg.get("url")
+            logger.info(
+                "MetaFormer cfg lookup: model=%s has_cfg=%s url=%s use_pretrained=%s",
+                custom_model,
+                bool(model_cfg),
+                weights_url,
+                use_pretrained,
+            )
+            if use_pretrained and not weights_url:
+                logger.warning(
+                    "MetaFormer pretrained requested for %s but no URL found in default cfgs; model will be randomly initialized",
+                    custom_model,
+                )
+
             resize_value = config_params.get("image_resize")
             if resize_value and resize_value != "original":
                 try:
@@ -198,17 +224,15 @@
             else:
                 image_zip_path = config_params.get("image_zip", "")
                 detected_height, detected_width = self._detect_image_dimensions(image_zip_path)
-                if use_pretrained:
-                    if (detected_height, detected_width) != (cfg_height, cfg_width):
-                        logger.info(
-                            "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s",
-                            cfg_height,
-                            cfg_width,
-                            detected_height,
-                            detected_width,
-                        )
-                else:
-                    target_height, target_width = detected_height, detected_width
+                target_height, target_width = detected_height, detected_width
+                if use_pretrained and (detected_height, detected_width) != (cfg_height, cfg_width):
+                    logger.info(
+                        "MetaFormer pretrained weights expect %sx%s; proceeding with detected %sx%s",
+                        cfg_height,
+                        cfg_width,
+                        detected_height,
+                        detected_width,
+                    )
                 if target_height <= 0 or target_width <= 0:
                     raise ValueError(
                         f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}."