comparison pytorch_embedding.py @ 1:84f96c952c2c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
| author | goeckslab |
|---|---|
| date | Sun, 09 Nov 2025 19:03:21 +0000 |
| parents | 38333676a029 |
| children | |
| 0:38333676a029 | 1:84f96c952c2c |
|---|---|
| 16 and conversion. | 16 and conversion. |
| 17 """ | 17 """ |
| 18 | 18 |
| 19 import argparse | 19 import argparse |
| 20 import csv | 20 import csv |
| 21 import inspect | |
| 22 import logging | 21 import logging |
| 23 import os | 22 import os |
| 24 import zipfile | 23 import zipfile |
| 25 from inspect import signature | 24 from inspect import signature |
| 26 | 25 |
| 30 import torchvision.models as models | 29 import torchvision.models as models |
| 31 from PIL import Image | 30 from PIL import Image |
| 32 from torch.utils.data import DataLoader, Dataset | 31 from torch.utils.data import DataLoader, Dataset |
| 33 from torchvision import transforms | 32 from torchvision import transforms |
| 34 | 33 |
| 34 # GPFM imports | |
| 35 try: | |
| 36 import requests | |
| 37 GPFM_AVAILABLE = True | |
| 38 except ImportError as e: | |
| 39 GPFM_AVAILABLE = False | |
| 40 logging.warning(f"GPFM dependencies not available: {e}") | |
| 41 | |
| 35 # Configure logging | 42 # Configure logging |
| 36 logging.basicConfig( | 43 logging.basicConfig( |
| 37 filename="/tmp/ludwig_embeddings.log", | |
| 38 filemode="a", | |
| 39 format="%(asctime)s - %(levelname)s - %(message)s", | 44 format="%(asctime)s - %(levelname)s - %(message)s", |
| 40 level=logging.DEBUG, | 45 level=logging.INFO, |
| 46 handlers=[ | |
| 47 logging.StreamHandler(), # Console output | |
| 48 logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a") # File output | |
| 49 ] | |
| 41 ) | 50 ) |
| 42 | 51 |
| 43 # Create a cache directory in the current working directory | 52 # Create a cache directory in the current working directory |
| 44 cache_dir = os.path.join(os.getcwd(), 'hf_cache') | 53 cache_dir = os.path.join(os.getcwd(), 'hf_cache') |
| 45 try: | 54 try: |
| 46 os.makedirs(cache_dir, exist_ok=True) | 55 os.makedirs(cache_dir, exist_ok=True) |
| 47 logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}") | 56 logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}") |
| 48 except OSError as e: | 57 except OSError as e: |
| 49 logging.error(f"Failed to create cache directory {cache_dir}: {e}") | 58 logging.error(f"Failed to create cache directory {cache_dir}: {e}") |
| 50 raise | 59 raise |
| 51 | 60 |
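The comparison does not show how `hf_cache` is wired into the download machinery. A common approach, sketched below as an assumption rather than what this commit does, is to export the standard cache environment variables before any weights are fetched; `HF_HOME` and `TORCH_HOME` are the variables Hugging Face and torch hub respect.

```python
# Hypothetical wiring (not shown in this hunk): route framework caches
# into the local hf_cache directory before any weights are fetched.
import os

cache_dir = os.path.join(os.getcwd(), "hf_cache")
os.makedirs(cache_dir, exist_ok=True)
os.environ.setdefault("HF_HOME", cache_dir)     # Hugging Face hub cache
os.environ.setdefault("TORCH_HOME", cache_dir)  # torch.hub / torchvision weights
```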
| 61 # GPFM DinoVisionTransformer Implementation | |
| 62 | |
| 63 | |
| 64 class DinoVisionTransformer(torch.nn.Module): | |
| 65 """Simplified DinoVisionTransformer for GPFM.""" | |
| 66 | |
| 67 def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16): | |
| 68 super().__init__() | |
| 69 self.embed_dim = embed_dim | |
| 70 self.num_features = embed_dim | |
| 71 self.patch_size = patch_size | |
| 72 | |
| 73 # Patch embedding | |
| 74 self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) | |
| 75 num_patches = (img_size // patch_size) ** 2 | |
| 76 | |
| 77 # Class token | |
| 78 self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim)) | |
| 79 | |
| 80 # Position embeddings | |
| 81 self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) | |
| 82 | |
| 83 # Transformer blocks (simplified) | |
| 84 self.blocks = torch.nn.ModuleList([ | |
| 85 torch.nn.TransformerEncoderLayer( | |
| 86 d_model=embed_dim, | |
| 87 nhead=num_heads, | |
| 88 dim_feedforward=embed_dim * 4, | |
| 89 dropout=0.0, | |
| 90 batch_first=True | |
| 91 ) for _ in range(depth) | |
| 92 ]) | |
| 93 | |
| 94 # Layer norm | |
| 95 self.norm = torch.nn.LayerNorm(embed_dim) | |
| 96 | |
| 97 # Initialize weights | |
| 98 torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) | |
| 99 torch.nn.init.trunc_normal_(self.cls_token, std=0.02) | |
| 100 | |
| 101 def forward(self, x): | |
| 102 B = x.shape[0] | |
| 103 | |
| 104 # Patch embedding | |
| 105 x = self.patch_embed(x) # B, embed_dim, H//patch_size, W//patch_size | |
| 106 x = x.flatten(2).transpose(1, 2) # B, num_patches, embed_dim | |
| 107 | |
| 108 # Add class token | |
| 109 cls_tokens = self.cls_token.expand(B, -1, -1) | |
| 110 x = torch.cat([cls_tokens, x], dim=1) | |
| 111 | |
| 112 # Add position embeddings | |
| 113 x = x + self.pos_embed | |
| 114 | |
| 115 # Apply transformer blocks | |
| 116 for block in self.blocks: | |
| 117 x = block(x) | |
| 118 | |
| 119 # Apply layer norm and return class token | |
| 120 x = self.norm(x) | |
| 121 return x[:, 0] # Return class token features | |
| 122 | |
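As a quick sanity check of the architecture above, a dummy batch can be pushed through a shallow instance. Reducing `depth` to 2 is an assumption purely to keep the check fast; the GPFM configuration uses `depth=24`.

```python
# Minimal shape check of the simplified backbone defined above.
import torch

model = DinoVisionTransformer(img_size=224, patch_size=14,
                              embed_dim=1024, depth=2, num_heads=16)
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1024]) -- one CLS embedding per image
```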
| 123 | |
| 124 # GPFM Model Implementation | |
| 125 class GPFMModel(torch.nn.Module): | |
| 126 """GPFM (Generalizable Pathology Foundation Model) implementation.""" | |
| 127 | |
| 128 def __init__(self, device='cpu'): | |
| 129 super().__init__() | |
| 130 self.device = device | |
| 131 self.model = None | |
| 132 self.transformer = None | |
| 133 self.embed_dim = 1024 # GPFM uses 1024-dimensional embeddings | |
| 134 self._load_model() | |
| 135 | |
| 136 def _download_weights(self, url, filepath): | |
| 137 """Download GPFM weights from the official repository.""" | |
| 138 if os.path.exists(filepath): | |
| 139 logging.info(f"GPFM weights already exist at {filepath}") | |
| 140 return True | |
| 141 | |
| 142 logging.info(f"Downloading GPFM weights from {url}") | |
| 143 try: | |
| 144 response = requests.get(url, stream=True, timeout=300) | |
| 145 response.raise_for_status() | |
| 146 | |
| 147 os.makedirs(os.path.dirname(filepath), exist_ok=True) | |
| 148 | |
| 149 # Get file size for progress tracking | |
| 150 total_size = int(response.headers.get('content-length', 0)) | |
| 151 downloaded = 0 | |
| 152 | |
| 153 with open(filepath, 'wb') as f: | |
| 154 for chunk in response.iter_content(chunk_size=8192): | |
| 155 if chunk: | |
| 156 f.write(chunk) | |
| 157 downloaded += len(chunk) | |
| 158 if total_size > 0: | |
| 159 progress = (downloaded / total_size) * 100 | |
| 160 if downloaded % (1024 * 1024 * 10) == 0: # fires only at exact 10MB multiples (holds while chunks arrive as full 8KB) | |
| 161 logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)") | |
| 162 | |
| 163 logging.info(f"GPFM weights downloaded successfully to {filepath}") | |
| 164 return True | |
| 165 | |
| 166 except Exception as e: | |
| 167 logging.error(f"Failed to download GPFM weights: {e}") | |
| 168 if os.path.exists(filepath): | |
| 169 os.remove(filepath) # Clean up partial download | |
| 170 return False | |
| 171 | |
| 172 def _load_model(self): | |
| 173 """Load GPFM model with pretrained weights.""" | |
| 174 try: | |
| 175 # Create models directory | |
| 176 models_dir = os.path.join(cache_dir, 'gpfm_models') | |
| 177 os.makedirs(models_dir, exist_ok=True) | |
| 178 | |
| 179 # GPFM weights URL from official repository | |
| 180 weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth" | |
| 181 weights_path = os.path.join(models_dir, 'GPFM.pth') | |
| 182 | |
| 183 # Create GPFM DinoVisionTransformer architecture | |
| 184 self.model = DinoVisionTransformer( | |
| 185 img_size=224, | |
| 186 patch_size=14, | |
| 187 embed_dim=1024, | |
| 188 depth=24, | |
| 189 num_heads=16 | |
| 190 ) | |
| 191 | |
| 192 # Try to download and load GPFM weights | |
| 193 weights_loaded = False | |
| 194 if self._download_weights(weights_url, weights_path): | |
| 195 try: | |
| 196 logging.info("Loading GPFM pretrained weights...") | |
| 197 checkpoint = torch.load(weights_path, map_location=self.device) | |
| 198 | |
| 199 # Extract teacher model weights (GPFM format) | |
| 200 if 'teacher' in checkpoint: | |
| 201 state_dict = checkpoint['teacher'] | |
| 202 logging.info("Found 'teacher' key in checkpoint") | |
| 203 else: | |
| 204 state_dict = checkpoint | |
| 205 logging.info("Using checkpoint directly") | |
| 206 | |
| 207 # Rename keys to match our simplified architecture | |
| 208 new_state_dict = {} | |
| 209 for k, v in state_dict.items(): | |
| 210 # Remove 'backbone.' prefix if present | |
| 211 if k.startswith('backbone.'): | |
| 212 k = k[9:] # Remove 'backbone.' | |
| 213 | |
| 214 # Map GPFM keys to our simplified architecture | |
| 215 if k in ['cls_token', 'pos_embed']: | |
| 216 new_state_dict[k] = v | |
| 217 elif k.startswith('patch_embed.proj.'): | |
| 218 # Map patch embedding | |
| 219 new_k = k.replace('patch_embed.proj.', 'patch_embed.') | |
| 220 new_state_dict[new_k] = v | |
| 221 elif k.startswith('blocks.') and 'norm' in k: | |
| 222 # Map layer norms | |
| 223 if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'): | |
| 224 # Skip intermediate norms for simplified model | |
| 225 continue | |
| 226 elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'): | |
| 227 continue | |
| 228 elif k == 'norm.weight' or k == 'norm.bias': | |
| 229 new_state_dict[k] = v | |
| 230 | |
| 231 # Load compatible weights | |
| 232 missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False) | |
| 233 if missing_keys: | |
| 234 logging.warning(f"Missing keys: {missing_keys[:5]}...") # Show first 5 | |
| 235 if unexpected_keys: | |
| 236 logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...") # Show first 5 | |
| 237 | |
| 238 logging.info("GPFM pretrained weights loaded successfully") | |
| 239 weights_loaded = True | |
| 240 | |
| 241 except Exception as e: | |
| 242 logging.warning(f"Could not load GPFM weights: {e}") | |
| 243 | |
| 244 if not weights_loaded: | |
| 245 logging.info("Using randomly initialized GPFM architecture (no pretrained weights)") | |
| 246 | |
| 247 self.model = self.model.to(self.device) | |
| 248 self.model.eval() | |
| 249 | |
| 250 # GPFM preprocessing (based on official repository) | |
| 251 self.transformer = transforms.Compose([ | |
| 252 transforms.Lambda(lambda x: x.convert("RGB")), # Ensure RGB format | |
| 253 transforms.Resize((224, 224)), # GPFM uses 224x224 (not 512x512 for features) | |
| 254 transforms.ToTensor(), | |
| 255 transforms.Normalize( | |
| 256 mean=[0.485, 0.456, 0.406], # ImageNet normalization | |
| 257 std=[0.229, 0.224, 0.225] | |
| 258 ) | |
| 259 ]) | |
| 260 | |
| 261 logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})") | |
| 262 | |
| 263 except Exception as e: | |
| 264 logging.error(f"Failed to initialize GPFM model: {e}") | |
| 265 raise | |
| 266 | |
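The key remapping in `_load_model` is easiest to see on a toy checkpoint. The sketch below uses hypothetical keys in the GPFM/DINO naming scheme and applies the same rules: strip the `backbone.` prefix, rename `patch_embed.proj.*`, keep the class token, position embeddings, and final norm, and drop the per-block norms the simplified architecture does not expose.

```python
# Illustration of the checkpoint key remapping on a toy state dict.
# Keys here are hypothetical examples of the GPFM/DINO naming scheme.
import torch

state_dict = {
    "backbone.cls_token": torch.zeros(1, 1, 1024),
    "backbone.patch_embed.proj.weight": torch.zeros(1024, 3, 14, 14),
    "backbone.blocks.0.norm1.weight": torch.zeros(1024),  # dropped
    "backbone.norm.weight": torch.zeros(1024),
}

new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith("backbone."):
        k = k[len("backbone."):]
    if k in ("cls_token", "pos_embed"):
        new_state_dict[k] = v
    elif k.startswith("patch_embed.proj."):
        new_state_dict[k.replace("patch_embed.proj.", "patch_embed.")] = v
    elif k in ("norm.weight", "norm.bias"):
        new_state_dict[k] = v

print(sorted(new_state_dict))
# ['cls_token', 'norm.weight', 'patch_embed.weight']
```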
| 267 def forward(self, x): | |
| 268 """Forward pass through GPFM model.""" | |
| 269 with torch.no_grad(): | |
| 270 return self.model(x) | |
| 271 | |
| 272 def get_transformer(self, apply_normalization=True): | |
| 273 """Get the preprocessing transformer for GPFM.""" | |
| 274 if apply_normalization: | |
| 275 return self.transformer | |
| 276 else: | |
| 277 # Return transformer without normalization | |
| 278 return transforms.Compose([ | |
| 279 transforms.Lambda(lambda x: x.convert("RGB")), | |
| 280 transforms.Resize((224, 224)), | |
| 281 transforms.ToTensor() | |
| 282 ]) | |
| 283 | |
| 284 | |
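Putting the class together, a minimal extraction sketch looks like the following. It assumes network access for the weight download; as coded above, the model falls back to random initialization when the download or key matching fails.

```python
# End-to-end sketch: preprocess one tile and extract its embedding.
import torch
from PIL import Image

model = GPFMModel(device="cpu")
transform = model.get_transformer(apply_normalization=True)

img = Image.new("RGB", (512, 512))   # placeholder tile
batch = transform(img).unsqueeze(0)  # (1, 3, 224, 224)
embedding = model(batch)             # (1, 1024) CLS features
print(embedding.shape)
```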
| 52 # Available models from torchvision | 285 # Available models from torchvision |
| 53 AVAILABLE_MODELS = { | 286 AVAILABLE_MODELS = { |
| 54 name: getattr(models, name) | 287 name: getattr(models, name) |
| 55 for name in dir(models) | 288 for name in dir(models) |
| 56 if callable( | 289 if callable( |
| 57 getattr(models, name) | 290 getattr(models, name) |
| 58 ) and "weights" in signature(getattr(models, name)).parameters | 291 ) and "weights" in signature(getattr(models, name)).parameters |
| 59 } | 292 } |
| 293 | |
| 294 # Add GPFM model if available | |
| 295 if GPFM_AVAILABLE: | |
| 296 AVAILABLE_MODELS['gpfm'] = GPFMModel | |
| 60 | 297 |
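The dictionary comprehension above selects exactly the torchvision factory functions: in torchvision >= 0.13 every factory (`resnet50`, `vit_b_16`, ...) takes a `weights` keyword, while classes such as `ResNet` and helpers such as `list_models` do not and are filtered out. A standalone probe:

```python
# Probe which torchvision callables the signature check admits.
from inspect import signature
import torchvision.models as models

factories = [
    name for name in dir(models)
    if callable(getattr(models, name))
    and "weights" in signature(getattr(models, name)).parameters
]
print(len(factories), factories[:5])
```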
| 61 # Default resize and normalization settings for models | 298 # Default resize and normalization settings for models |
| 62 MODEL_DEFAULTS = { | 299 MODEL_DEFAULTS = { |
| 63 "default": {"resize": (224, 224), "normalize": ( | 300 "default": {"resize": (224, 224), "normalize": ( |
| 64 [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | 301 [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] |
| 84 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] | 321 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] |
| 85 )}, | 322 )}, |
| 86 "vit_b_32": {"resize": (224, 224), "normalize": ( | 323 "vit_b_32": {"resize": (224, 224), "normalize": ( |
| 87 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] | 324 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] |
| 88 )}, | 325 )}, |
| 326 "gpfm": { | |
| 327 "resize": (224, 224), | |
| 328 "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
| 329 }, | |
| 89 } | 330 } |
| 90 | 331 |
| 91 for model, settings in MODEL_DEFAULTS.items(): | 332 for model, settings in MODEL_DEFAULTS.items(): |
| 92 if "normalize" not in settings: | 333 if "normalize" not in settings: |
| 93 settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | 334 settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) |
| 150 if model_name not in AVAILABLE_MODELS: | 391 if model_name not in AVAILABLE_MODELS: |
| 151 raise ValueError( | 392 raise ValueError( |
| 152 f"Unsupported model: {model_name}. \ | 393 f"Unsupported model: {model_name}. \ |
| 153 Available models: {list(AVAILABLE_MODELS.keys())}") | 394 Available models: {list(AVAILABLE_MODELS.keys())}") |
| 154 try: | 395 try: |
| 155 if "weights" in inspect.signature( | 396 # Special handling for GPFM |
| 397 if model_name == "gpfm": | |
| 398 model = AVAILABLE_MODELS[model_name](device=device) | |
| 399 logging.info("GPFM model loaded") | |
| 400 return model | |
| 401 | |
| 402 # Standard torchvision models | |
| 403 if "weights" in signature( | |
| 156 AVAILABLE_MODELS[model_name]).parameters: | 404 AVAILABLE_MODELS[model_name]).parameters: |
| 157 model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device) | 405 model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device) |
| 158 else: | 406 else: |
| 159 model = AVAILABLE_MODELS[model_name]().to(device) | 407 model = AVAILABLE_MODELS[model_name]().to(device) |
| 160 logging.info("Model loaded") | 408 logging.info("Model loaded") |
| 229 elif transform_type == "rgba_to_rgb": | 477 elif transform_type == "rgba_to_rgb": |
| 230 initial_transform = RGBAtoRGBTransform() | 478 initial_transform = RGBAtoRGBTransform() |
| 231 else: | 479 else: |
| 232 initial_transform = transforms.Lambda(lambda x: x.convert("RGB")) | 480 initial_transform = transforms.Lambda(lambda x: x.convert("RGB")) |
| 233 | 481 |
| 234 transform_list = [initial_transform, | 482 # Handle GPFM separately as it has its own preprocessing |
| 235 transforms.Resize(resize), | 483 if model_name == "gpfm": |
| 236 transforms.ToTensor()] | 484 # For GPFM, combine initial transform with GPFM's custom transformer |
| 237 if apply_normalization: | 485 if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]: |
| 238 transform_list.append(transforms.Normalize(mean=normalize[0], | 486 transform = transforms.Compose([ |
| 239 std=normalize[1])) | 487 initial_transform, |
| 240 transform = transforms.Compose(transform_list) | 488 model.get_transformer(apply_normalization=apply_normalization) |
| 489 ]) | |
| 490 else: | |
| 491 transform = model.get_transformer(apply_normalization=apply_normalization) | |
| 492 else: | |
| 493 # Standard torchvision models | |
| 494 transform_list = [initial_transform, | |
| 495 transforms.Resize(resize), | |
| 496 transforms.ToTensor()] | |
| 497 if apply_normalization: | |
| 498 transform_list.append(transforms.Normalize(mean=normalize[0], | |
| 499 std=normalize[1])) | |
| 500 transform = transforms.Compose(transform_list) | |
| 241 | 501 |
| 242 class ImageDataset(Dataset): | 502 class ImageDataset(Dataset): |
| 243 def __init__(self, zip_file, file_list, transform=None): | 503 def __init__(self, zip_file, file_list, transform=None): |
| 244 self.zip_file = zip_file | 504 self.zip_file = zip_file |
| 245 self.file_list = file_list | 505 self.file_list = file_list |
| 276 # Try DataLoader with reduced resource usage | 536 # Try DataLoader with reduced resource usage |
| 277 dataset = ImageDataset(zip_file, file_list, transform=transform) | 537 dataset = ImageDataset(zip_file, file_list, transform=transform) |
| 278 dataloader = DataLoader( | 538 dataloader = DataLoader( |
| 279 dataset, | 539 dataset, |
| 280 batch_size=16, # Reduced for lower memory usage | 540 batch_size=16, # Reduced for lower memory usage |
| 281 num_workers=1, # Reduced to minimize shared memory | 541 num_workers=0, # Fix multiprocessing issues with GPFM |
| 282 shuffle=False, | 542 shuffle=False, |
| 283 pin_memory=True if device == "cuda" else False, | 543 pin_memory=True if device == "cuda" else False, |
| 284 collate_fn=collate_fn, | 544 collate_fn=collate_fn, |
| 285 ) | 545 ) |
| 286 for images, names in dataloader: | 546 for images, names in dataloader: |
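The switch to `num_workers=0` is consistent with the lambda-based transforms above: `transforms.Lambda(lambda ...)` cannot be pickled, so spawn-based DataLoader worker processes (the default on macOS and Windows) would crash while serializing the dataset. A one-line demonstration:

```python
# Lambdas are not picklable, which breaks multi-worker loading
# whenever the transform pipeline embeds one.
import pickle
from torchvision import transforms

t = transforms.Lambda(lambda x: x.convert("RGB"))
try:
    pickle.dumps(t)
except Exception as e:
    print(type(e).__name__, e)  # pickling error for the lambda
```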
