diff MetaFormer/metaformer_stacked_cnn.py @ 16:8729f69e9207 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
author goeckslab
date Wed, 03 Dec 2025 01:28:52 +0000
parents c5150cceab47
children
--- a/MetaFormer/metaformer_stacked_cnn.py	Fri Nov 28 15:45:49 2025 +0000
+++ b/MetaFormer/metaformer_stacked_cnn.py	Wed Dec 03 01:28:52 2025 +0000
@@ -99,6 +99,15 @@
         else:
             expected_channels, expected_height, expected_width = 3, 224, 224
 
+        # Legacy behavior: adapters keep the requested size; the backbone expectation is aligned to 224x224 for stability
+        if expected_height != 224 or expected_width != 224:
+            logger.info(
+                "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)",
+                expected_height,
+                expected_width,
+            )
+            expected_height = expected_width = 224
+
         self.expected_channels = expected_channels
         self.expected_height = expected_height
         self.expected_width = expected_width
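
The hunk above forces any non-224 spatial size down to 224x224 before it is stored on the encoder. A minimal, self-contained sketch of that rule follows; the function name align_backbone_size is illustrative and not part of the module, and it assumes expected_height/expected_width come from the input feature config as in the surrounding code.

    import logging

    logger = logging.getLogger(__name__)

    def align_backbone_size(expected_height: int, expected_width: int):
        # Anything other than 224x224 is logged once and forced to 224x224.
        if expected_height != 224 or expected_width != 224:
            logger.info(
                "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)",
                expected_height,
                expected_width,
            )
            expected_height = expected_width = 224
        return expected_height, expected_width

    # e.g. align_backbone_size(384, 384) -> (224, 224), with one INFO line logged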
@@ -164,14 +173,22 @@
         if ctor is None:
             raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")
 
+        logger.info("MetaFormer backbone requested: %s, use_pretrained=%s", self.custom_model, self.use_pretrained)
         cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
-        weights_url = cfg.get('url')
+        logger.info("MetaFormer cfg present=%s", bool(cfg))
+        if not cfg:
+            logger.warning("MetaFormer config missing for %s; will fall back to random initialization", self.custom_model)
+        weights_url = cfg.get('url') if isinstance(cfg, dict) else None
+        logger.info("MetaFormer weights_url=%s", weights_url)
         # track loading
         self._pretrained_loaded = False
         self._loaded_weights_url: Optional[str] = None
         if self.use_pretrained and weights_url:
             print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
             logger.info(f"Loading pretrained weights from: {weights_url}")
+        elif self.use_pretrained and not weights_url:
+            logger.warning("MetaFormer: no pretrained URL found for %s; continuing with random weights", self.custom_model)
+            self.use_pretrained = False
         # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
         orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)
 
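
The end of this hunk captures torch.hub.load_state_dict_from_url into orig_loader so that any checkpoint download triggered by the MetaFormer factories can be logged. A sketch of how such a shim typically works, assuming the wrapper is installed temporarily and restored afterwards; the diff itself only shows orig_loader being captured, so the install/restore steps here are an assumption.

    import logging
    import torch.hub

    logger = logging.getLogger(__name__)

    orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)

    def logging_loader(url, *args, **kwargs):
        # Log every checkpoint URL the factories fetch, then delegate.
        logger.info("load_state_dict_from_url called with: %s", url)
        return orig_loader(url, *args, **kwargs)

    if orig_loader is not None:
        torch.hub.load_state_dict_from_url = logging_loader
        try:
            pass  # build the backbone here; any weight download is now logged
        finally:
            # Restore the original loader so the patch does not leak.
            torch.hub.load_state_dict_from_url = orig_loader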
@@ -369,6 +386,17 @@
             custom_model = kwargs.pop("custom_model", None)
             if custom_model is None:
                 custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)
+            if custom_model is None:
+                # Fallback for multi-process contexts
+                custom_model = os.environ.get("GLEAM_META_FORMER_MODEL")
+                if custom_model:
+                    logger.info("Recovered MetaFormer model from env: %s", custom_model)
+
+            logger.info(
+                "patched Stacked2DCNN init called; custom_model=%s kwargs_keys=%s",
+                custom_model,
+                list(kwargs.keys()),
+            )
 
             try:
                 if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
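
The fallback added above reads GLEAM_META_FORMER_MODEL so that spawned workers, which do not share the attribute stashed on patch_ludwig_direct, can still recover the model name. A sketch of both sides of that handoff; the setter function is an assumption, since the diff only shows the reading side.

    import logging
    import os

    logger = logging.getLogger(__name__)

    def remember_metaformer_model(name: str) -> None:
        # Parent process: stash the model name where spawned workers can read it.
        os.environ["GLEAM_META_FORMER_MODEL"] = name

    def recover_metaformer_model(kwargs: dict):
        # Worker process: mirror the lookup order in the patched __init__.
        custom_model = kwargs.pop("custom_model", None)
        if custom_model is None:
            custom_model = os.environ.get("GLEAM_META_FORMER_MODEL")
            if custom_model:
                logger.info("Recovered MetaFormer model from env: %s", custom_model)
        return custom_model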