comparison: pytorch_embedding.py @ 1:84f96c952c2c (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07

author    goeckslab
date      Sun, 09 Nov 2025 19:03:21 +0000
parents   38333676a029
children  (none)

comparison of revisions 0:38333676a029 and 1:84f96c952c2c

@@ -16,11 +16,10 @@
 and conversion.
 """

 import argparse
 import csv
-import inspect
 import logging
 import os
 import zipfile
 from inspect import signature

@@ -30,16 +29,26 @@
 import torchvision.models as models
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms

+# GPFM imports
+try:
+    import requests
+    GPFM_AVAILABLE = True
+except ImportError as e:
+    GPFM_AVAILABLE = False
+    logging.warning(f"GPFM dependencies not available: {e}")
+
 # Configure logging
 logging.basicConfig(
-    filename="/tmp/ludwig_embeddings.log",
-    filemode="a",
     format="%(asctime)s - %(levelname)s - %(message)s",
-    level=logging.DEBUG,
+    level=logging.INFO,
+    handlers=[
+        logging.StreamHandler(),  # Console output
+        logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a")  # File output
+    ]
 )

 # Create a cache directory in the current working directory
 cache_dir = os.path.join(os.getcwd(), 'hf_cache')
 try:

@@ -47,5 +56,5 @@
     logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}")
 except OSError as e:
     logging.error(f"Failed to create cache directory {cache_dir}: {e}")
     raise

@@ -51,0 +61,62 @@
+# GPFM DinoVisionTransformer Implementation
+
+
+class DinoVisionTransformer(torch.nn.Module):
+    """Simplified DinoVisionTransformer for GPFM."""
+
+    def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_features = embed_dim
+        self.patch_size = patch_size
+
+        # Patch embedding
+        self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
+        num_patches = (img_size // patch_size) ** 2
+
+        # Class token
+        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # Position embeddings
+        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+        # Transformer blocks (simplified)
+        self.blocks = torch.nn.ModuleList([
+            torch.nn.TransformerEncoderLayer(
+                d_model=embed_dim,
+                nhead=num_heads,
+                dim_feedforward=embed_dim * 4,
+                dropout=0.0,
+                batch_first=True
+            ) for _ in range(depth)
+        ])
+
+        # Layer norm
+        self.norm = torch.nn.LayerNorm(embed_dim)
+
+        # Initialize weights
+        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
+        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+    def forward(self, x):
+        B = x.shape[0]
+
+        # Patch embedding
+        x = self.patch_embed(x)  # B, embed_dim, H//patch_size, W//patch_size
+        x = x.flatten(2).transpose(1, 2)  # B, num_patches, embed_dim
+
+        # Add class token
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat([cls_tokens, x], dim=1)
+
+        # Add position embeddings
+        x = x + self.pos_embed
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            x = block(x)
+
+        # Apply layer norm and return class token
+        x = self.norm(x)
+        return x[:, 0]  # Return class token features
+

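For reference, a minimal smoke test of the simplified transformer above (illustrative only, not part of this revision): with 224x224 inputs and patch size 14, the class token yields one 1024-dimensional feature vector per image. A shallow depth keeps the check quick; the class itself defaults to depth=24.

import torch

vit = DinoVisionTransformer(img_size=224, patch_size=14, embed_dim=1024,
                            depth=2, num_heads=16)
vit.eval()
with torch.no_grad():
    features = vit(torch.randn(2, 3, 224, 224))  # dummy batch of two RGB images
print(features.shape)  # expected: torch.Size([2, 1024])
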
@@ -51,0 +123,161 @@
+
+# GPFM Model Implementation
+class GPFMModel(torch.nn.Module):
+    """GPFM (Generalizable Pathology Foundation Model) implementation."""
+
+    def __init__(self, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = None
+        self.transformer = None
+        self.embed_dim = 1024  # GPFM uses 1024-dimensional embeddings
+        self._load_model()
+
+    def _download_weights(self, url, filepath):
+        """Download GPFM weights from the official repository."""
+        if os.path.exists(filepath):
+            logging.info(f"GPFM weights already exist at {filepath}")
+            return True
+
+        logging.info(f"Downloading GPFM weights from {url}")
+        try:
+            response = requests.get(url, stream=True, timeout=300)
+            response.raise_for_status()
+
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+
+            # Get file size for progress tracking
+            total_size = int(response.headers.get('content-length', 0))
+            downloaded = 0
+
+            with open(filepath, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:
+                        f.write(chunk)
+                        downloaded += len(chunk)
+                        if total_size > 0:
+                            progress = (downloaded / total_size) * 100
+                            if downloaded % (1024 * 1024 * 10) == 0:  # Log every 10MB
+                                logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)")
+
+            logging.info(f"GPFM weights downloaded successfully to {filepath}")
+            return True
+
+        except Exception as e:
+            logging.error(f"Failed to download GPFM weights: {e}")
+            if os.path.exists(filepath):
+                os.remove(filepath)  # Clean up partial download
+            return False
+
+    def _load_model(self):
+        """Load GPFM model with pretrained weights."""
+        try:
+            # Create models directory
+            models_dir = os.path.join(cache_dir, 'gpfm_models')
+            os.makedirs(models_dir, exist_ok=True)
+
+            # GPFM weights URL from official repository
+            weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth"
+            weights_path = os.path.join(models_dir, 'GPFM.pth')
+
+            # Create GPFM DinoVisionTransformer architecture
+            self.model = DinoVisionTransformer(
+                img_size=224,
+                patch_size=14,
+                embed_dim=1024,
+                depth=24,
+                num_heads=16
+            )
+
+            # Try to download and load GPFM weights
+            weights_loaded = False
+            if self._download_weights(weights_url, weights_path):
+                try:
+                    logging.info("Loading GPFM pretrained weights...")
+                    checkpoint = torch.load(weights_path, map_location=self.device)
+
+                    # Extract teacher model weights (GPFM format)
+                    if 'teacher' in checkpoint:
+                        state_dict = checkpoint['teacher']
+                        logging.info("Found 'teacher' key in checkpoint")
+                    else:
+                        state_dict = checkpoint
+                        logging.info("Using checkpoint directly")
+
+                    # Rename keys to match our simplified architecture
+                    new_state_dict = {}
+                    for k, v in state_dict.items():
+                        # Remove 'backbone.' prefix if present
+                        if k.startswith('backbone.'):
+                            k = k[9:]  # Remove 'backbone.'
+
+                        # Map GPFM keys to our simplified architecture
+                        if k in ['cls_token', 'pos_embed']:
+                            new_state_dict[k] = v
+                        elif k.startswith('patch_embed.proj.'):
+                            # Map patch embedding
+                            new_k = k.replace('patch_embed.proj.', 'patch_embed.')
+                            new_state_dict[new_k] = v
+                        elif k.startswith('blocks.') and 'norm' in k:
+                            # Map layer norms
+                            if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'):
+                                # Skip intermediate norms for simplified model
+                                continue
+                            elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'):
+                                continue
+                        elif k == 'norm.weight' or k == 'norm.bias':
+                            new_state_dict[k] = v
+
+                    # Load compatible weights
+                    missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+                    if missing_keys:
+                        logging.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
+                    if unexpected_keys:
+                        logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5
+
+                    logging.info("GPFM pretrained weights loaded successfully")
+                    weights_loaded = True
+
+                except Exception as e:
+                    logging.warning(f"Could not load GPFM weights: {e}")
+
+            if not weights_loaded:
+                logging.info("Using randomly initialized GPFM architecture (no pretrained weights)")
+
+            self.model = self.model.to(self.device)
+            self.model.eval()
+
+            # GPFM preprocessing (based on official repository)
+            self.transformer = transforms.Compose([
+                transforms.Lambda(lambda x: x.convert("RGB")),  # Ensure RGB format
+                transforms.Resize((224, 224)),  # GPFM uses 224x224 (not 512x512 for features)
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406],  # ImageNet normalization
+                    std=[0.229, 0.224, 0.225]
+                )
+            ])
+
+            logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})")
+
+        except Exception as e:
+            logging.error(f"Failed to initialize GPFM model: {e}")
+            raise
+
+    def forward(self, x):
+        """Forward pass through GPFM model."""
+        with torch.no_grad():
+            return self.model(x)
+
+    def get_transformer(self, apply_normalization=True):
+        """Get the preprocessing transformer for GPFM."""
+        if apply_normalization:
+            return self.transformer
+        else:
+            # Return transformer without normalization
+            return transforms.Compose([
+                transforms.Lambda(lambda x: x.convert("RGB")),
+                transforms.Resize((224, 224)),
+                transforms.ToTensor()
+            ])
+

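A usage sketch for the wrapper above (illustrative only; the tile path is a placeholder, and either a successful weight download or the random-initialization fallback handled in _load_model is assumed):

import torch
from PIL import Image

gpfm = GPFMModel(device="cpu")          # downloads GPFM.pth or falls back to random init
preprocess = gpfm.get_transformer()     # 224x224 resize + ImageNet normalization
image = Image.open("tile_0001.png")     # hypothetical input image
batch = preprocess(image).unsqueeze(0)  # shape: (1, 3, 224, 224)
embedding = gpfm(batch)                 # shape: (1, 1024) class-token features
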
@@ -52,8 +284,13 @@
+
 # Available models from torchvision
 AVAILABLE_MODELS = {
     name: getattr(models, name)
     for name in dir(models)
     if callable(
         getattr(models, name)
     ) and "weights" in signature(getattr(models, name)).parameters
 }
+
+# Add GPFM model if available
+if GPFM_AVAILABLE:
+    AVAILABLE_MODELS['gpfm'] = GPFMModel

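The comprehension above keeps only torchvision factories that accept a weights keyword (the pretrained-weights API in recent torchvision releases), and 'gpfm' is registered when its optional dependency imports. A quick, illustrative check of that filter:

from inspect import signature
import torchvision.models as models

print("weights" in signature(models.resnet50).parameters)  # True on recent torchvision
print("resnet50" in AVAILABLE_MODELS)                       # True
print("gpfm" in AVAILABLE_MODELS)                           # True when GPFM_AVAILABLE
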
@@ -60,5 +297,5 @@

 # Default resize and normalization settings for models
 MODEL_DEFAULTS = {
     "default": {"resize": (224, 224), "normalize": (
         [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

@@ -84,10 +321,14 @@
         [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
     )},
     "vit_b_32": {"resize": (224, 224), "normalize": (
         [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
     )},
+    "gpfm": {
+        "resize": (224, 224),
+        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    },
 }

 for model, settings in MODEL_DEFAULTS.items():
     if "normalize" not in settings:
         settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

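The loop above only back-fills a missing "normalize" entry; how a model's settings are resolved is not shown in this comparison. A hedged sketch of the likely lookup (helper name hypothetical):

def resolve_model_settings(model_name):
    # Hypothetical helper: fall back to the generic ImageNet defaults when the
    # model has no explicit entry in MODEL_DEFAULTS.
    settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    return settings["resize"], settings["normalize"]

resize, normalize = resolve_model_settings("gpfm")  # (224, 224), ImageNet mean/std
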
@@ -150,11 +391,18 @@
 if model_name not in AVAILABLE_MODELS:
     raise ValueError(
         f"Unsupported model: {model_name}. \
         Available models: {list(AVAILABLE_MODELS.keys())}")
 try:
-    if "weights" in inspect.signature(
+    # Special handling for GPFM
+    if model_name == "gpfm":
+        model = AVAILABLE_MODELS[model_name](device=device)
+        logging.info("GPFM model loaded")
+        return model
+
+    # Standard torchvision models
+    if "weights" in signature(
             AVAILABLE_MODELS[model_name]).parameters:
         model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
     else:
         model = AVAILABLE_MODELS[model_name]().to(device)
     logging.info("Model loaded")

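Condensed, the new control flow in this hunk routes "gpfm" to its dedicated constructor (which takes the target device at construction time) and leaves the torchvision path unchanged. The enclosing function is not shown in this hunk, so the sketch below uses a hypothetical name:

from inspect import signature

def load_model_sketch(model_name, device="cpu"):
    if model_name == "gpfm":
        return AVAILABLE_MODELS["gpfm"](device=device)  # GPFMModel moves itself to the device
    factory = AVAILABLE_MODELS[model_name]
    if "weights" in signature(factory).parameters:
        return factory(weights="DEFAULT").to(device)
    return factory().to(device)
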
@@ -229,12 +477,24 @@
 elif transform_type == "rgba_to_rgb":
     initial_transform = RGBAtoRGBTransform()
 else:
     initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))

-transform_list = [initial_transform,
-                  transforms.Resize(resize),
-                  transforms.ToTensor()]
-if apply_normalization:
-    transform_list.append(transforms.Normalize(mean=normalize[0],
-                                                std=normalize[1]))
-transform = transforms.Compose(transform_list)
+# Handle GPFM separately as it has its own preprocessing
+if model_name == "gpfm":
+    # For GPFM, combine initial transform with GPFM's custom transformer
+    if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]:
+        transform = transforms.Compose([
+            initial_transform,
+            model.get_transformer(apply_normalization=apply_normalization)
+        ])
+    else:
+        transform = model.get_transformer(apply_normalization=apply_normalization)
+else:
+    # Standard torchvision models
+    transform_list = [initial_transform,
+                      transforms.Resize(resize),
+                      transforms.ToTensor()]
+    if apply_normalization:
+        transform_list.append(transforms.Normalize(mean=normalize[0],
+                                                   std=normalize[1]))
+    transform = transforms.Compose(transform_list)

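For the standard branch, the composed pipeline is equivalent to writing the steps out explicitly. As a quick, illustrative sanity check (resize and normalization values assumed to be the ImageNet defaults above), a 512x512 RGBA image should come out as a 3x224x224 tensor:

from PIL import Image
from torchvision import transforms

pipeline = transforms.Compose([
    transforms.Lambda(lambda x: x.convert("RGB")),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
tensor = pipeline(Image.new("RGBA", (512, 512)))
print(tensor.shape)  # torch.Size([3, 224, 224])
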
@@ -241,5 +501,5 @@

 class ImageDataset(Dataset):
     def __init__(self, zip_file, file_list, transform=None):
         self.zip_file = zip_file
         self.file_list = file_list

@@ -276,11 +536,11 @@
 # Try DataLoader with reduced resource usage
 dataset = ImageDataset(zip_file, file_list, transform=transform)
 dataloader = DataLoader(
     dataset,
     batch_size=16,  # Reduced for lower memory usage
-    num_workers=1,  # Reduced to minimize shared memory
+    num_workers=0,  # Fix multiprocessing issues with GPFM
     shuffle=False,
     pin_memory=True if device == "cuda" else False,
     collate_fn=collate_fn,
 )
 for images, names in dataloader: