comparison MetaFormer/metaformer_stacked_cnn.py @ 16:8729f69e9207 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
| author | goeckslab |
|---|---|
| date | Wed, 03 Dec 2025 01:28:52 +0000 |
| parents | c5150cceab47 |
| children | |
| 15:d17e3a1b8659 | 16:8729f69e9207 |
|---|---|
| 97 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: | 97 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: |
| 98 expected_channels, expected_height, expected_width = input_size | 98 expected_channels, expected_height, expected_width = input_size |
| 99 else: | 99 else: |
| 100 expected_channels, expected_height, expected_width = 3, 224, 224 | 100 expected_channels, expected_height, expected_width = 3, 224, 224 |
| 101 | 101 |
| | 102 # Use legacy behavior: keep requested size for adapters but align backbone to 224 for stability |
| | 103 if expected_height != 224 or expected_width != 224: |
| | 104 logger.info( |
| | 105 "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)", |
| | 106 expected_height, |
| | 107 expected_width, |
| | 108 ) |
| | 109 expected_height = expected_width = 224 |
| | 110 |
| 102 self.expected_channels = expected_channels | 111 self.expected_channels = expected_channels |
| 103 self.expected_height = expected_height | 112 self.expected_height = expected_height |
| 104 self.expected_width = expected_width | 113 self.expected_width = expected_width |
| 105 | 114 |
| 106 logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}") | 115 logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}") |
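
New lines 102-109 pin the backbone's expected input to 224x224 while the originally requested size stays available to the adapters. A minimal standalone sketch of that normalization, assuming the same (C, H, W) convention; the helper name `normalize_backbone_size` is illustrative and not from this file:

```python
import logging

logger = logging.getLogger(__name__)


def normalize_backbone_size(input_size, default=(3, 224, 224)):
    """Parse a (C, H, W) input_size and pin H/W to 224 for the backbone,
    mirroring the override added in this changeset."""
    if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
        channels, height, width = input_size
    else:
        channels, height, width = default
    if height != 224 or width != 224:
        logger.info(
            "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)",
            height,
            width,
        )
        height = width = 224
    return channels, height, width
```

| 15:d17e3a1b8659 | 16:8729f69e9207 |
|---|---|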
| 162 | 171 |
| 163 ctor = _resolve_metaformer_ctor(self.custom_model) | 172 ctor = _resolve_metaformer_ctor(self.custom_model) |
| 164 if ctor is None: | 173 if ctor is None: |
| 165 raise ValueError(f"Unknown MetaFormer model: {self.custom_model}") | 174 raise ValueError(f"Unknown MetaFormer model: {self.custom_model}") |
| 166 | 175 |
| | 176 logger.info("MetaFormer backbone requested: %s, use_pretrained=%s", self.custom_model, self.use_pretrained) |
| 167 cfg = META_DEFAULT_CFGS.get(self.custom_model, {}) | 177 cfg = META_DEFAULT_CFGS.get(self.custom_model, {}) |
| 168 weights_url = cfg.get('url') | 178 logger.info("MetaFormer cfg present=%s", bool(cfg)) |
| | 179 if not cfg: |
| | 180 logger.warning("MetaFormer config missing for %s; will fall back to random initialization", self.custom_model) |
| | 181 weights_url = cfg.get('url') if isinstance(cfg, dict) else None |
| | 182 logger.info("MetaFormer weights_url=%s", weights_url) |
| 169 # track loading | 183 # track loading |
| 170 self._pretrained_loaded = False | 184 self._pretrained_loaded = False |
| 171 self._loaded_weights_url: Optional[str] = None | 185 self._loaded_weights_url: Optional[str] = None |
| 172 if self.use_pretrained and weights_url: | 186 if self.use_pretrained and weights_url: |
| 173 print(f"LOADING MetaFormer pretrained weights from: {weights_url}") | 187 print(f"LOADING MetaFormer pretrained weights from: {weights_url}") |
| 174 logger.info(f"Loading pretrained weights from: {weights_url}") | 188 logger.info(f"Loading pretrained weights from: {weights_url}") |
| | 189 elif self.use_pretrained and not weights_url: |
| | 190 logger.warning("MetaFormer: no pretrained URL found for %s; continuing with random weights", self.custom_model) |
| | 191 self.use_pretrained = False |
| 175 # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url | 192 # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url |
| 176 orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None) | 193 orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None) |
| 177 | 194 |
| 178 def _wrapped_loader(url, *args, **kwargs): | 195 def _wrapped_loader(url, *args, **kwargs): |
| 179 print(f"DOWNLOADING weights from: {url}") | 196 print(f"DOWNLOADING weights from: {url}") |
| 367 | 384 |
| 368 def patched_stacked_cnn_init(self, *args, **kwargs): | 385 def patched_stacked_cnn_init(self, *args, **kwargs): |
| 369 custom_model = kwargs.pop("custom_model", None) | 386 custom_model = kwargs.pop("custom_model", None) |
| 370 if custom_model is None: | 387 if custom_model is None: |
| 371 custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None) | 388 custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None) |
| | 389 if custom_model is None: |
| | 390 # Fallback for multi-process contexts |
| | 391 custom_model = os.environ.get("GLEAM_META_FORMER_MODEL") |
| | 392 if custom_model: |
| | 393 logger.info("Recovered MetaFormer model from env: %s", custom_model) |
| | 394 |
| | 395 logger.info( |
| | 396 "patched Stacked2DCNN init called; custom_model=%s kwargs_keys=%s", |
| | 397 custom_model, |
| | 398 list(kwargs.keys()), |
| | 399 ) |
| 372 | 400 |
| 373 try: | 401 try: |
| 374 if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model): | 402 if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model): |
| 375 print(f"DETECTED MetaFormer model: {custom_model}") | 403 print(f"DETECTED MetaFormer model: {custom_model}") |
| 376 print("MetaFormer encoder is being loaded and used.") | 404 print("MetaFormer encoder is being loaded and used.") |
