diff MetaFormer/metaformer_stacked_cnn.py @ changeset 16:8729f69e9207 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
| field | value |
|---|---|
| author | goeckslab |
| date | Wed, 03 Dec 2025 01:28:52 +0000 |
| parents | c5150cceab47 |
| children | (none) |
```diff
--- a/MetaFormer/metaformer_stacked_cnn.py	Fri Nov 28 15:45:49 2025 +0000
+++ b/MetaFormer/metaformer_stacked_cnn.py	Wed Dec 03 01:28:52 2025 +0000
@@ -99,6 +99,15 @@
         else:
             expected_channels, expected_height, expected_width = 3, 224, 224
 
+        # Use legacy behavior: keep requested size for adapters but align backbone to 224 for stability
+        if expected_height != 224 or expected_width != 224:
+            logger.info(
+                "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)",
+                expected_height,
+                expected_width,
+            )
+            expected_height = expected_width = 224
+
         self.expected_channels = expected_channels
         self.expected_height = expected_height
         self.expected_width = expected_width
@@ -164,14 +173,22 @@
         if ctor is None:
             raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")
 
+        logger.info("MetaFormer backbone requested: %s, use_pretrained=%s", self.custom_model, self.use_pretrained)
         cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
-        weights_url = cfg.get('url')
+        logger.info("MetaFormer cfg present=%s", bool(cfg))
+        if not cfg:
+            logger.warning("MetaFormer config missing for %s; will fall back to random initialization", self.custom_model)
+        weights_url = cfg.get('url') if isinstance(cfg, dict) else None
+        logger.info("MetaFormer weights_url=%s", weights_url)
         # track loading
         self._pretrained_loaded = False
         self._loaded_weights_url: Optional[str] = None
         if self.use_pretrained and weights_url:
             print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
             logger.info(f"Loading pretrained weights from: {weights_url}")
+        elif self.use_pretrained and not weights_url:
+            logger.warning("MetaFormer: no pretrained URL found for %s; continuing with random weights", self.custom_model)
+            self.use_pretrained = False
 
         # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
         orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)
@@ -369,6 +386,17 @@
         custom_model = kwargs.pop("custom_model", None)
         if custom_model is None:
             custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)
+        if custom_model is None:
+            # Fallback for multi-process contexts
+            custom_model = os.environ.get("GLEAM_META_FORMER_MODEL")
+            if custom_model:
+                logger.info("Recovered MetaFormer model from env: %s", custom_model)
+
+        logger.info(
+            "patched Stacked2DCNN init called; custom_model=%s kwargs_keys=%s",
+            custom_model,
+            list(kwargs.keys()),
+        )
 
         try:
             if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
```
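The second hunk keeps a handle on `torch.hub.load_state_dict_from_url` (`orig_loader`) so that every weight download triggered by the backbone factories can be logged. A minimal sketch of that wrapping pattern, assuming a standard `logging` setup; the shim below is illustrative, not the repository's exact code:

```python
import logging

import torch

logger = logging.getLogger(__name__)

# Keep a handle on the real loader so the shim can delegate to it.
orig_loader = getattr(torch.hub, "load_state_dict_from_url", None)

if orig_loader is not None:
    def logging_loader(url, *args, **kwargs):
        # Log before and after, so failed downloads are still visible in the log.
        logger.info("Fetching pretrained weights from %s", url)
        state_dict = orig_loader(url, *args, **kwargs)
        logger.info("Loaded %d tensors from %s", len(state_dict), url)
        return state_dict

    # Any factory that calls torch.hub.load_state_dict_from_url now goes
    # through the shim.
    torch.hub.load_state_dict_from_url = logging_loader
```

Since the patch is process-global, restoring `orig_loader` afterwards (for example in a `try`/`finally`) keeps the instrumentation from leaking into unrelated code.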

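The last hunk recovers the model name from the `GLEAM_META_FORMER_MODEL` environment variable when neither the kwarg nor the attribute stashed on `patch_ludwig_direct` survives, which happens when training spawns worker processes that re-import the module. A sketch of that resolution order, with `resolve_custom_model` as a hypothetical standalone helper standing in for the patched `__init__`:

```python
import logging
import os

logger = logging.getLogger(__name__)

def resolve_custom_model(kwargs: dict):
    """Resolve the MetaFormer model name with the precedence used in the diff:
    explicit kwarg, then an attribute stashed on this function, then the
    GLEAM_META_FORMER_MODEL environment variable (set by the parent process
    so spawned workers can recover the choice)."""
    custom_model = kwargs.pop("custom_model", None)
    if custom_model is None:
        custom_model = getattr(resolve_custom_model, "_metaformer_model", None)
    if custom_model is None:
        custom_model = os.environ.get("GLEAM_META_FORMER_MODEL")
        if custom_model:
            logger.info("Recovered MetaFormer model from env: %s", custom_model)
    return custom_model

# The parent process advertises the choice before spawning workers:
os.environ["GLEAM_META_FORMER_MODEL"] = "caformer_s18"  # illustrative model name
assert resolve_custom_model({}) == "caformer_s18"
```

An environment variable survives `fork` and `spawn` alike, which is why it works as the fallback of last resort here where module-level state on `patch_ludwig_direct` may not.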