Mercurial > repos > goeckslab > image_learner
annotate MetaFormer/metaformer_stacked_cnn.py @ 16:8729f69e9207 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
| author | goeckslab |
|---|---|
| date | Wed, 03 Dec 2025 01:28:52 +0000 |
| parents | c5150cceab47 |
| children |
| rev | line source |
|---|---|
|
11
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
1 import logging |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
2 import os |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
3 import sys |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
4 from typing import Dict, List, Optional |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
5 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
6 import torch |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
7 import torch.nn as nn |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
8 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
# Put this file's directory first on sys.path so sibling modules such as
# ``metaformer_models`` can be imported by plain name.
sys.path.insert(0, os.path.dirname(__file__))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Recognised MetaFormer model-name prefixes.
SUPPORTED_PREFIXES = (
    'identityformer_',
    'randformer_',
    'poolformerv2_',
    'convformer_',
    'caformer_',
)

try:
    from metaformer_models import default_cfgs as META_DEFAULT_CFGS
    META_MODELS_AVAILABLE = True
    logger.info("MetaFormer models imported successfully")
except Exception as e:
    # Best-effort import: keep this module importable when the MetaFormer
    # definitions (or their dependencies) are unavailable. Bind an empty
    # config table so lookups like ``META_DEFAULT_CFGS.get(name, {})`` fall
    # back to their defaults instead of raising NameError.
    META_DEFAULT_CFGS = {}
    META_MODELS_AVAILABLE = False
    logger.warning("MetaFormer models not available: %s", e)
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
32 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
33 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def _resolve_metaformer_ctor(model_name: str):
    """Return the factory callable named ``model_name`` from ``metaformer_models``.

    The lookup is done on the module's own attribute dictionary (``vars``),
    not ``getattr``. Only names actually defined in ``metaformer_models`` are
    considered; attributes every module inherits (e.g. ``__class__``) are not.

    Args:
        model_name: Name of the factory function, e.g. ``"identityformer_s12"``.

    Returns:
        The callable bound to ``model_name`` in ``metaformer_models``, or
        ``None`` in three cases: the module cannot be imported, the name is
        not a string, or the name is missing or not callable.
    """
    # Module-dict keys are strings, so any other key can never match; bail out
    # early rather than risk a TypeError on unhashable input.
    if not isinstance(model_name, str):
        return None

    try:
        # Imported lazily so resolution can be attempted without making the
        # module a hard dependency at import time of this file.
        import metaformer_models
    except Exception as exc:
        # Best-effort: an unavailable module simply means "no factory", but
        # record why rather than swallowing the error silently.
        # logging.getLogger(__name__) is the same logger object as this
        # module's ``logger``.
        logging.getLogger(__name__).debug(
            "Could not import metaformer_models to resolve %r: %s",
            model_name,
            exc,
        )
        return None

    factory = vars(metaformer_models).get(model_name)
    return factory if callable(factory) else None
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
45 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
46 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
47 class MetaFormerStackedCNN(nn.Module): |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
48 def __init__( |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
49 self, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
50 height: int = 224, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
51 width: int = 224, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
52 num_channels: int = 3, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
53 output_size: int = 128, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
54 custom_model: str = "identityformer_s12", |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
55 use_pretrained: bool = True, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
56 trainable: bool = True, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
57 conv_layers: Optional[List[Dict]] = None, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
58 num_conv_layers: Optional[int] = None, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
59 conv_activation: str = "relu", |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
60 conv_dropout: float = 0.0, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
61 conv_norm: Optional[str] = None, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
62 conv_use_bias: bool = True, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
63 fc_layers: Optional[List[Dict]] = None, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
64 num_fc_layers: int = 1, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
65 fc_activation: str = "relu", |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
66 fc_dropout: float = 0.0, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
67 fc_norm: Optional[str] = None, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
68 fc_use_bias: bool = True, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
69 **kwargs, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
70 ): |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
71 super().__init__() |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
72 logger.info("MetaFormerStackedCNN encoder instantiated") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
73 logger.info(f"Using MetaFormer model: {custom_model}") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
74 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
75 try: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
76 height = int(height) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
77 width = int(width) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
78 num_channels = int(num_channels) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
79 except (TypeError, ValueError) as exc: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
80 raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
81 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
82 if height <= 0 or width <= 0: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
83 raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
84 if num_channels <= 0: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
85 raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
86 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
87 self.height = height |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
88 self.width = width |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
89 self.num_channels = num_channels |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
90 self.output_size = output_size |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
91 self.custom_model = custom_model |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
92 self.use_pretrained = use_pretrained |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
93 self.trainable = trainable |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
94 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
95 cfg = META_DEFAULT_CFGS.get(custom_model, {}) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
96 input_size = cfg.get('input_size', (3, 224, 224)) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
97 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
98 expected_channels, expected_height, expected_width = input_size |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
99 else: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
100 expected_channels, expected_height, expected_width = 3, 224, 224 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
101 |
|
16
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
102 # Use legacy behavior: keep requested size for adapters but align backbone to 224 for stability |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
103 if expected_height != 224 or expected_width != 224: |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
104 logger.info( |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
105 "Overriding expected backbone size to 224x224 for compatibility (was %sx%s)", |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
106 expected_height, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
107 expected_width, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
108 ) |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
109 expected_height = expected_width = 224 |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
11
diff
changeset
|
110 |
|
11
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
111 self.expected_channels = expected_channels |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
112 self.expected_height = expected_height |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
113 self.expected_width = expected_width |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
114 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
115 logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
116 logger.info( |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
117 "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)", |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
118 num_channels, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
119 height, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
120 width, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
121 output_size, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
122 self.expected_height, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
123 self.expected_width, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
124 ) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
125 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
126 self.channel_adapter: Optional[nn.Conv2d] = None |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
127 if num_channels != self.expected_channels: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
128 self.channel_adapter = nn.Conv2d( |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
129 num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
130 ) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
131 logger.info( |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
132 "Added channel adapter: %s -> %s channels", |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
133 num_channels, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
134 self.expected_channels, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
135 ) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
136 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
137 self.size_adapter: Optional[nn.Module] = None |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
138 if height != self.expected_height or width != self.expected_width: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
139 self.size_adapter = nn.AdaptiveAvgPool2d((height, width)) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
140 logger.info( |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
141 "Configured size adapter to requested input: %sx%s", |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
142 height, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
143 width, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
144 ) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
145 self.backbone_adapter: Optional[nn.Module] = None |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
146 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
147 self.backbone = self._load_metaformer_backbone() |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
148 self.feature_dim = self._get_feature_dim() |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
149 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
150 self.fc_layers = self._create_fc_layers( |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
151 input_dim=self.feature_dim, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
152 output_dim=output_size, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
153 num_layers=num_fc_layers, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
154 activation=fc_activation, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
155 dropout=fc_dropout, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
156 norm=fc_norm, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
157 use_bias=fc_use_bias, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
158 fc_layers_config=fc_layers, |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
159 ) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
160 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
161 if not trainable: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
162 for param in self.backbone.parameters(): |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
163 param.requires_grad = False |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
164 logger.info("MetaFormer backbone frozen (trainable=False)") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
165 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
166 logger.info("MetaFormerStackedCNN initialized successfully") |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
167 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def _load_metaformer_backbone(self):
    """Construct the MetaFormer backbone named by ``self.custom_model``.

    The constructor is resolved with ``_resolve_metaformer_ctor`` and is always
    built with ``num_classes=1000``.

    Pretrained-weight handling:
    * The URL comes from ``META_DEFAULT_CFGS[self.custom_model]['url']``.
      If ``self.use_pretrained`` is set but no URL is found,
      ``self.use_pretrained`` is switched to False.
    * While the constructor runs, ``torch.hub.load_state_dict_from_url`` is
      temporarily wrapped. The wrapper logs every download and records the URL
      once the download has succeeded. The original function is always restored.
    * If construction with ``pretrained=True`` raises, the model is rebuilt with
      ``pretrained=False`` and ``self.use_pretrained`` is set to False.

    Side effects: sets ``self._pretrained_loaded``, ``self._loaded_weights_url``,
    ``self._metaformer_weights_url`` and may reset ``self.use_pretrained``.

    Returns:
        The constructed backbone model.

    Raises:
        ImportError: if ``META_MODELS_AVAILABLE`` is false.
        ValueError: if ``self.custom_model`` does not resolve to a constructor.
        Exception: whatever the constructor raises when ``use_pretrained`` is False.
    """
    if not META_MODELS_AVAILABLE:
        raise ImportError("MetaFormer models are not available")

    ctor = _resolve_metaformer_ctor(self.custom_model)
    if ctor is None:
        raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")

    logger.info("MetaFormer backbone requested: %s, use_pretrained=%s", self.custom_model, self.use_pretrained)
    cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
    logger.info("MetaFormer cfg present=%s", bool(cfg))
    if not cfg:
        logger.warning("MetaFormer config missing for %s; will fall back to random initialization", self.custom_model)
    weights_url = cfg.get('url') if isinstance(cfg, dict) else None
    logger.info("MetaFormer weights_url=%s", weights_url)

    # Track whether a pretrained state dict was actually fetched.
    self._pretrained_loaded = False
    self._loaded_weights_url: Optional[str] = None
    if self.use_pretrained and weights_url:
        print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
        logger.info(f"Loading pretrained weights from: {weights_url}")
    elif self.use_pretrained and not weights_url:
        logger.warning("MetaFormer: no pretrained URL found for %s; continuing with random weights", self.custom_model)
        self.use_pretrained = False

    # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
    orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)

    def _wrapped_loader(url, *args, **kwargs):
        print(f"DOWNLOADING weights from: {url}")
        logger.info(f"DOWNLOADING weights from: {url}")
        result = orig_loader(url, *args, **kwargs)
        # Record success only after the download returned. Setting these before
        # the call left them claiming "loaded" when the download raised.
        self._pretrained_loaded = True
        self._loaded_weights_url = url
        print(f"WEIGHTS DOWNLOADED successfully from: {url}")
        return result

    try:
        if self.use_pretrained and orig_loader is not None:
            torch.hub.load_state_dict_from_url = _wrapped_loader  # type: ignore[attr-defined]
        print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
        try:
            model = ctor(pretrained=self.use_pretrained, num_classes=1000)
            print(f"MetaFormer model CREATED: {self.custom_model}")
        except Exception as model_error:
            if not self.use_pretrained:
                raise
            print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
            print("Attempting to load without pretrained weights as fallback...")
            logger.warning(f"Failed to load {self.custom_model} with pretrained weights: {model_error}")
            # The download may have succeeded before construction failed.
            # The fallback model is randomly initialized, so clear the record.
            self._pretrained_loaded = False
            self._loaded_weights_url = None
            model = ctor(pretrained=False, num_classes=1000)
            print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
            self.use_pretrained = False  # Update state to reflect actual loading
    finally:
        # Always restore the global torch.hub function, even on failure.
        if orig_loader is not None:
            torch.hub.load_state_dict_from_url = orig_loader  # type: ignore[attr-defined]

    self._metaformer_weights_url = weights_url
    if self.use_pretrained:
        if self._pretrained_loaded:
            print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
            logger.info(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
        else:
            # Warn but don't fail. The factory may load weights through a path
            # that does not go via the wrapped torch.hub function.
            print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
            logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
    else:
        print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
        logger.info(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
    logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})")
    return model
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
237 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def _get_feature_dim(self):
    """Return the size of the last axis of ``self.backbone.forward_features``.

    Runs one forward pass on a random ``(1, 3, 224, 224)`` tensor with
    gradients disabled.

    The probe runs in eval mode, so it cannot update normalization running
    statistics or depend on dropout. The backbone's previous training/eval mode
    is restored afterwards, even if the forward pass raises.

    The dummy tensor is created on the device and dtype of the backbone's first
    parameter, falling back to CPU/float32 for a parameterless backbone.

    Returns:
        int-like: ``features.shape[-1]``.
    """
    try:
        ref_param = next(self.backbone.parameters())
        device, dtype = ref_param.device, ref_param.dtype
    except StopIteration:
        device, dtype = torch.device("cpu"), torch.float32

    was_training = self.backbone.training
    self.backbone.eval()
    try:
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224, device=device, dtype=dtype)
            features = self.backbone.forward_features(dummy_input)
    finally:
        self.backbone.train(was_training)

    # NOTE(review): assumes forward_features returns channels on the last axis
    # (e.g. pooled (B, C)) — confirm for every supported MetaFormer variant.
    feature_dim = features.shape[-1]
    logger.info(f"MetaFormer feature dimension: {feature_dim}")
    return feature_dim
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
245 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def _create_fc_layers(self, input_dim, output_dim, num_layers, activation, dropout, norm, use_bias, fc_layers_config):
    """Build the fully connected head as an ``nn.Sequential``.

    Two modes:

    * ``fc_layers_config`` is non-empty: one ``nn.Linear`` is built per entry.
      Each entry's ``'output_size'`` sets that layer's width. If the key is
      absent, the last layer defaults to ``output_dim`` and earlier layers keep
      the current width.
    * Otherwise: ``num_layers`` linear layers whose hidden widths are
      ``int(input_dim * 0.5 ** k)`` (never below ``output_dim``), ending at
      ``output_dim``. ``num_layers <= 0`` yields an empty ``nn.Sequential``.

    Every linear layer except the last is followed, in this order, by:
    1. the activation (only "relu", "tanh", "sigmoid", "leaky_relu" are
       recognized; anything else adds no module);
    2. ``nn.Dropout`` when ``dropout`` is a positive number;
    3. ``nn.BatchNorm1d`` / ``nn.LayerNorm`` when ``norm`` is "batch" / "layer".

    This order is kept fixed because ``nn.Sequential`` indices appear in saved
    state-dict keys.

    NOTE(review): in config mode only ``'output_size'`` is read from each entry.
    ``activation``, ``dropout``, ``norm`` and ``use_bias`` apply to every layer.

    Args:
        input_dim: width of the incoming feature vector.
        output_dim: width of the final layer's output.
        num_layers: number of linear layers when no config list is given.
        activation: activation name applied after hidden layers.
        dropout: dropout probability; ``None`` or ``<= 0`` disables it.
        norm: ``"batch"``, ``"layer"`` or anything else for no normalization.
        use_bias: passed as ``bias=`` to every ``nn.Linear``.
        fc_layers_config: optional list of per-layer mappings.

    Returns:
        nn.Sequential: the assembled head.
    """
    activation_classes = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": nn.LeakyReLU,
    }
    layers = []

    def _append_hidden_tail(dim):
        # Activation -> dropout -> norm after a hidden Linear of width `dim`.
        act_cls = activation_classes.get(activation)
        if act_cls is not None:
            layers.append(act_cls())
        if dropout and dropout > 0:
            layers.append(nn.Dropout(dropout))
        if norm == "batch":
            layers.append(nn.BatchNorm1d(dim))
        elif norm == "layer":
            layers.append(nn.LayerNorm(dim))

    if fc_layers_config:
        current_dim = input_dim
        last_index = len(fc_layers_config) - 1
        for i, layer_config in enumerate(fc_layers_config):
            default_dim = output_dim if i == last_index else current_dim
            layer_output_dim = layer_config.get('output_size', default_dim)
            layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias))
            if i < last_index:
                _append_hidden_tail(layer_output_dim)
            current_dim = layer_output_dim
    else:
        # Widths halve per hidden layer (floored at output_dim). With
        # num_layers == 1 this is a single Linear(input_dim, output_dim).
        dims = [input_dim]
        dims.extend(
            max(int(input_dim * (0.5 ** (i + 1))), output_dim)
            for i in range(num_layers - 1)
        )
        dims.append(output_dim)
        for i in range(num_layers):
            layers.append(nn.Linear(dims[i], dims[i + 1], bias=use_bias))
            if i < num_layers - 1:
                _append_hidden_tail(dims[i + 1])
    return nn.Sequential(*layers)
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
297 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def forward(self, x):
    """Adapt ``x`` to the MetaFormer backbone and encode it.

    Up to three adapters are applied in order. Each is created lazily the
    first time it is needed and reused while it still matches:

    1. a 1x1 ``nn.Conv2d`` mapping ``x.shape[1]`` channels to
       ``self.expected_channels``;
    2. an ``nn.AdaptiveAvgPool2d`` resizing the spatial dims to
       ``(self.height, self.width)``;
    3. an ``nn.AdaptiveAvgPool2d`` resizing ``(self.height, self.width)`` to
       the backbone size ``(self.expected_height, self.expected_width)``.

    The adapted tensor then goes through ``self.backbone.forward_features``
    and ``self.fc_layers``.

    Args:
        x: image batch, indexed below as (batch, channels, height, width);
            presumably an NCHW float tensor (TODO confirm against caller).

    Returns:
        dict with the single key ``'encoder_output'`` holding the output of
        ``self.fc_layers``.
    """
    in_channels = x.shape[1]
    if in_channels != self.expected_channels:
        adapter = self.channel_adapter
        if (
            adapter is None
            or adapter.in_channels != in_channels
            or adapter.out_channels != self.expected_channels
        ):
            # Build the adapter directly on the input's device and, for
            # floating-point inputs, in the input's dtype. A float32 adapter
            # fails on half/bfloat16 inputs when autocast is not active.
            self.channel_adapter = nn.Conv2d(
                in_channels,
                self.expected_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                device=x.device,
                dtype=x.dtype if x.is_floating_point() else None,
            )
            # NOTE(review): this module is created inside forward(), so an
            # optimizer constructed before the first forward pass presumably
            # does not include its weights. Confirm it is meant to be trained.
            logger.info(
                "Created dynamic channel adapter: %s -> %s channels",
                in_channels,
                self.expected_channels,
            )
        x = self.channel_adapter(x)

    target_height, target_width = self.height, self.width
    if x.shape[2] != target_height or x.shape[3] != target_width:
        if (
            self.size_adapter is None
            or getattr(self.size_adapter, "output_size", None)
            != (target_height, target_width)
        ):
            # Pooling has no parameters, so only the device has to match.
            self.size_adapter = nn.AdaptiveAvgPool2d(
                (target_height, target_width)
            ).to(x.device)
            logger.info(
                "Created size adapter: %sx%s -> %sx%s",
                x.shape[2],
                x.shape[3],
                target_height,
                target_width,
            )
        x = self.size_adapter(x)

    # The configured image size may differ from the size the backbone expects.
    if target_height != self.expected_height or target_width != self.expected_width:
        if (
            self.backbone_adapter is None
            or getattr(self.backbone_adapter, "output_size", None)
            != (self.expected_height, self.expected_width)
        ):
            self.backbone_adapter = nn.AdaptiveAvgPool2d(
                (self.expected_height, self.expected_width)
            ).to(x.device)
            logger.info(
                "Aligning to MetaFormer backbone size: %sx%s",
                self.expected_height,
                self.expected_width,
            )
        x = self.backbone_adapter(x)

    features = self.backbone.forward_features(x)
    output = self.fc_layers(features)
    return {'encoder_output': output}
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
357 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
@property
def output_shape(self):
    """Encoder output shape: a one-element list holding ``self.output_size``."""
    shape = [self.output_size]
    return shape
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
361 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
362 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
    """Build a ``MetaFormerStackedCNN`` for ``model_name``.

    ``model_name`` is forwarded as the ``custom_model`` argument; every other
    keyword argument is passed through unchanged.
    """
    return MetaFormerStackedCNN(custom_model=model_name, **kwargs)
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
366 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
367 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def patch_ludwig_stacked_cnn():
    """Install the MetaFormer patch on Ludwig's stacked_cnn encoder if possible.

    Returns the result of ``patch_ludwig_direct()`` when MetaFormer models are
    available in this runtime; otherwise logs a warning and returns False
    without touching Ludwig.
    """
    if META_MODELS_AVAILABLE:
        return patch_ludwig_direct()
    logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
    return False
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
374 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
375 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
376 def _is_supported_metaformer(custom_model: Optional[str]) -> bool: |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
377 return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES) |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
378 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
379 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def patch_ludwig_direct():
    """Monkey-patch Ludwig's ``Stacked2DCNN.__init__`` to build MetaFormer encoders.

    The patched initializer resolves the MetaFormer model name from, in
    order: the ``custom_model`` keyword argument, the hint stored by
    ``set_current_metaformer_model``, and the ``GLEAM_META_FORMER_MODEL``
    environment variable.

    If the name is supported and MetaFormer models are available, Ludwig's
    original initializer runs first. The instance then takes its forward
    pass, backbone and fully connected layers from a ``MetaFormerStackedCNN``.
    Otherwise the original initializer runs unchanged. In both cases the
    stored hint is reset to None afterwards.

    The patch is installed at most once. A later call sees the marker on the
    installed wrapper and returns True without wrapping it again.

    Returns:
        True if the patch is in place. False if it could not be applied
        (e.g. Ludwig is not importable); the failure is logged, not raised.
    """
    try:
        from ludwig.encoders.image.base import Stacked2DCNN

        if getattr(Stacked2DCNN.__init__, "_metaformer_patched", False):
            # Re-wrapping would capture the patched init as the "original" and
            # build the MetaFormer encoder twice for every Stacked2DCNN.
            logger.debug("Stacked2DCNN already patched for MetaFormer; skipping.")
            return True

        original_stacked_cnn_init = Stacked2DCNN.__init__

        def patched_stacked_cnn_init(self, *args, **kwargs):
            # Resolve the requested model: explicit kwarg, then the in-process
            # hint, then the environment.
            custom_model = kwargs.pop("custom_model", None)
            if custom_model is None:
                custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)
            if custom_model is None:
                # Fallback for multi-process contexts
                custom_model = os.environ.get("GLEAM_META_FORMER_MODEL")
                if custom_model:
                    logger.info("Recovered MetaFormer model from env: %s", custom_model)

            logger.info(
                "patched Stacked2DCNN init called; custom_model=%s kwargs_keys=%s",
                custom_model,
                list(kwargs.keys()),
            )

            try:
                if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
                    print(f"DETECTED MetaFormer model: {custom_model}")
                    print("MetaFormer encoder is being loaded and used.")
                    # Initialize base class to keep Ludwig internals intact
                    original_stacked_cnn_init(self, *args, **kwargs)
                    # Create our MetaFormer encoder and graft behavior
                    mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
                    # ensure base attributes won't be used accidentally
                    for attr in (
                        "conv_layers",
                        "fc_layers",
                        "combiner",
                        "output_shape",
                        "reduce_output",
                    ):
                        if hasattr(self, attr):
                            try:
                                setattr(self, attr, getattr(mf_encoder, attr, None))
                            except Exception:
                                # Best effort: some of these may be read-only
                                # properties on the Ludwig class; keep the base
                                # value in that case.
                                logger.debug(
                                    "Could not override Stacked2DCNN.%s", attr, exc_info=True
                                )
                    # NOTE(review): mf_encoder itself is not attached to self.
                    # Modules its forward() creates lazily (the adapters) live
                    # on mf_encoder, so they are presumably missing from
                    # self.parameters()/state_dict(). Confirm this is intended.
                    self.forward = mf_encoder.forward
                    if hasattr(mf_encoder, 'backbone'):
                        self.backbone = mf_encoder.backbone
                    if hasattr(mf_encoder, 'fc_layers'):
                        self.fc_layers = mf_encoder.fc_layers
                    if hasattr(mf_encoder, 'custom_model'):
                        self.custom_model = mf_encoder.custom_model
                    # explicit confirmation logs
                    try:
                        url_info = getattr(mf_encoder, '_loaded_weights_url', None)
                        loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
                        if loaded_flag and url_info:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                        else:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights")
                    except Exception:
                        # Reporting is informational only; never fail init over it.
                        logger.debug("Could not report MetaFormer weight status", exc_info=True)
                else:
                    original_stacked_cnn_init(self, *args, **kwargs)
            finally:
                # Reset the hint so it only applies to this construction.
                if hasattr(patch_ludwig_direct, '_metaformer_model'):
                    patch_ludwig_direct._metaformer_model = None

        # Marker checked above so repeated calls leave the single patch in place.
        patched_stacked_cnn_init._metaformer_patched = True
        Stacked2DCNN.__init__ = patched_stacked_cnn_init
        return True
    except Exception as e:
        logger.exception("Failed to apply MetaFormer direct patch: %s", e)
        return False
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
446 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
447 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def set_current_metaformer_model(model_name: str):
    """Store ``model_name`` as the MetaFormer hint for the patched Stacked2DCNN init.

    The hint lives in the ``_metaformer_model`` attribute of
    ``patch_ludwig_direct``. The patched initializer reads it when no
    ``custom_model`` keyword is given.
    """
    patch_ludwig_direct._metaformer_model = model_name
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
451 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
452 |
|
c5150cceab47
planemo upload for repository https://github.com/goeckslab/gleam.git commit 0fe927b618cd4dfc87af7baaa827034cc6813225
goeckslab
parents:
diff
changeset
|
def clear_current_metaformer_model():
    """Delete the stored MetaFormer model hint; do nothing if none is set."""
    try:
        del patch_ludwig_direct._metaformer_model
    except AttributeError:
        # No hint was stored, so there is nothing to clear.
        pass
