comparison MetaFormer/metaformer_stacked_cnn.py @ 16:8729f69e9207 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
author goeckslab
date Wed, 03 Dec 2025 01:28:52 +0000
parents c5150cceab47
children
old: 15:d17e3a1b8659    new: 16:8729f69e9207
  97   97          if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
  98   98              expected_channels, expected_height, expected_width = input_size
  99   99          else:
 100  100              expected_channels, expected_height, expected_width = 3, 224, 224
 101  101
      102          # Use legacy behavior: keep requested size for adapters but align backbone to 224 for stability
      103          if expected_height != 224 or expected_width != 224:
      104              logger.info(
      105                  "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)",
      106                  expected_height,
      107                  expected_width,
      108              )
      109              expected_height = expected_width = 224
      110
 102  111          self.expected_channels = expected_channels
 103  112          self.expected_height = expected_height
 104  113          self.expected_width = expected_width
 105  114
 106  115          logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
 162  171
 163  172          ctor = _resolve_metaformer_ctor(self.custom_model)
 164  173          if ctor is None:
 165  174              raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")
 166  175
      176          logger.info("MetaFormer backbone requested: %s, use_pretrained=%s", self.custom_model, self.use_pretrained)
 167  177          cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
 168              weights_url = cfg.get('url')
      178          logger.info("MetaFormer cfg present=%s", bool(cfg))
      179          if not cfg:
      180              logger.warning("MetaFormer config missing for %s; will fall back to random initialization", self.custom_model)
      181          weights_url = cfg.get('url') if isinstance(cfg, dict) else None
      182          logger.info("MetaFormer weights_url=%s", weights_url)
 169  183          # track loading
 170  184          self._pretrained_loaded = False
 171  185          self._loaded_weights_url: Optional[str] = None
 172  186          if self.use_pretrained and weights_url:
 173  187              print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
 174  188              logger.info(f"Loading pretrained weights from: {weights_url}")
      189          elif self.use_pretrained and not weights_url:
      190              logger.warning("MetaFormer: no pretrained URL found for %s; continuing with random weights", self.custom_model)
      191              self.use_pretrained = False
 175  192          # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
 176  193          orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)
 177  194
 178  195          def _wrapped_loader(url, *args, **kwargs):
 179  196              print(f"DOWNLOADING weights from: {url}")
 367  384
 368  385      def patched_stacked_cnn_init(self, *args, **kwargs):
 369  386          custom_model = kwargs.pop("custom_model", None)
 370  387          if custom_model is None:
 371  388              custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)
      389          if custom_model is None:
      390              # Fallback for multi-process contexts
      391              custom_model = os.environ.get("GLEAM_META_FORMER_MODEL")
      392              if custom_model:
      393                  logger.info("Recovered MetaFormer model from env: %s", custom_model)
      394
      395          logger.info(
      396              "patched Stacked2DCNN init called; custom_model=%s kwargs_keys=%s",
      397              custom_model,
      398              list(kwargs.keys()),
      399          )
 372  400
 373  401          try:
 374  402              if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
 375  403                  print(f"DETECTED MetaFormer model: {custom_model}")
 376  404                  print("MetaFormer encoder is being loaded and used.")