goeckslab / extract_embeddings
changeset 1:84f96c952c2c (phase: draft, branch: default, tag: tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
| field | value |
|---|---|
| author | goeckslab |
| date | Sun, 09 Nov 2025 19:03:21 +0000 |
| parents | 38333676a029 |
| children | (none) |
| files | Docker/Dockerfile, pytorch_embedding.py, pytorch_embedding.xml |
| diffstat | 3 files changed, 331 insertions(+), 24 deletions(-) |
```diff
--- a/Docker/Dockerfile Thu Jun 19 23:33:23 2025 +0000
+++ b/Docker/Dockerfile Sun Nov 09 19:03:21 2025 +0000
@@ -1,21 +1,48 @@
-# Use a lightweight Python 3.9 base image
-FROM python:3.9-slim
+# Use NVIDIA CUDA base image for GPU support
+FROM nvidia/cuda:11.8-devel-ubuntu20.04
 
-# Install system dependencies for OpenCV and other libraries in one layer
+# Set environment variables
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+
+# Install system dependencies
 RUN apt-get update && apt-get install -y \
+    python3 \
+    python3-pip \
+    python3-dev \
     libgl1-mesa-glx \
     libglib2.0-0 \
+    git \
+    wget \
     && rm -rf /var/lib/apt/lists/*
 
-# Upgrade pip to the latest version
+# Create symbolic links for python
+RUN ln -s /usr/bin/python3 /usr/bin/python && \
+    ln -s /usr/bin/pip3 /usr/bin/pip
+
+# Upgrade pip
 RUN pip install --upgrade pip
 
-# Install PyTorch and torchvision CPU-only versions
+# Install PyTorch with CUDA support
 RUN pip install --no-cache-dir torch==2.0.0 torchvision==0.15.1 \
-    -f https://download.pytorch.org/whl/cpu/torch_stable.html
+    -f https://download.pytorch.org/whl/cu118/torch_stable.html
 
 # Install NumPy with a version compatible with PyTorch 2.0.0
 RUN pip install --no-cache-dir numpy==1.24.4
 
-# Install remaining Python dependencies
-RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet argparse logging multiprocessing
+# Install timm for model support (quote to prevent shell redirection)
+RUN pip install --no-cache-dir 'timm>=1.0.3'
+
+# Install remaining Python dependencies (exclude stdlib packages, add requests for GPFM)
+RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet requests
+
+# Install HuggingFace transformers for model loading
+RUN pip install --no-cache-dir transformers huggingface-hub
+
+# Set working directory
+WORKDIR /workspace
+
+# Create cache directory for HuggingFace models
+RUN mkdir -p /workspace/hf_cache
+ENV HF_HOME=/workspace/hf_cache
+ENV TORCH_HOME=/workspace/hf_cache
```
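Since this change swaps the CPU-only wheel index for the cu118 one, a quick way to confirm the rebuilt image actually got a CUDA-enabled PyTorch is a short check with torch. A minimal sketch, assuming the container is started with the NVIDIA runtime (for example `docker run --gpus all <image> python check_cuda.py`); the script name is illustrative:

```python
# check_cuda.py: sanity-check the CUDA-enabled PyTorch install inside the image
import torch

print(f"torch {torch.__version__}")                    # expect a cu118 build of 2.0.0
print(f"CUDA available: {torch.cuda.is_available()}")  # True only under the NVIDIA runtime
if torch.cuda.is_available():
    print(f"device: {torch.cuda.get_device_name(0)}")
```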
```diff
--- a/pytorch_embedding.py Thu Jun 19 23:33:23 2025 +0000
+++ b/pytorch_embedding.py Sun Nov 09 19:03:21 2025 +0000
@@ -18,7 +18,6 @@
 import argparse
 import csv
-import inspect
 import logging
 import os
 import zipfile
 
@@ -32,12 +31,22 @@
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 
+# GPFM imports
+try:
+    import requests
+    GPFM_AVAILABLE = True
+except ImportError as e:
+    GPFM_AVAILABLE = False
+    logging.warning(f"GPFM dependencies not available: {e}")
+
 # Configure logging
 logging.basicConfig(
-    filename="/tmp/ludwig_embeddings.log",
-    filemode="a",
     format="%(asctime)s - %(levelname)s - %(message)s",
-    level=logging.DEBUG,
+    level=logging.INFO,
+    handlers=[
+        logging.StreamHandler(),  # Console output
+        logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a")  # File output
+    ]
 )
 
 # Create a cache directory in the current working directory
@@ -49,6 +58,230 @@
     logging.error(f"Failed to create cache directory {cache_dir}: {e}")
     raise
 
+# GPFM DinoVisionTransformer Implementation
+
+
+class DinoVisionTransformer(torch.nn.Module):
+    """Simplified DinoVisionTransformer for GPFM."""
+
+    def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_features = embed_dim
+        self.patch_size = patch_size
+
+        # Patch embedding
+        self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
+        num_patches = (img_size // patch_size) ** 2
+
+        # Class token
+        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # Position embeddings
+        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+        # Transformer blocks (simplified)
+        self.blocks = torch.nn.ModuleList([
+            torch.nn.TransformerEncoderLayer(
+                d_model=embed_dim,
+                nhead=num_heads,
+                dim_feedforward=embed_dim * 4,
+                dropout=0.0,
+                batch_first=True
+            ) for _ in range(depth)
+        ])
+
+        # Layer norm
+        self.norm = torch.nn.LayerNorm(embed_dim)
+
+        # Initialize weights
+        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
+        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+    def forward(self, x):
+        B = x.shape[0]
+
+        # Patch embedding
+        x = self.patch_embed(x)  # B, embed_dim, H//patch_size, W//patch_size
+        x = x.flatten(2).transpose(1, 2)  # B, num_patches, embed_dim
+
+        # Add class token
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat([cls_tokens, x], dim=1)
+
+        # Add position embeddings
+        x = x + self.pos_embed
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            x = block(x)
+
+        # Apply layer norm and return class token
+        x = self.norm(x)
+        return x[:, 0]  # Return class token features
+
+
+# GPFM Model Implementation
+class GPFMModel(torch.nn.Module):
+    """GPFM (Generalizable Pathology Foundation Model) implementation."""
+
+    def __init__(self, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = None
+        self.transformer = None
+        self.embed_dim = 1024  # GPFM uses 1024-dimensional embeddings
+        self._load_model()
+
+    def _download_weights(self, url, filepath):
+        """Download GPFM weights from the official repository."""
+        if os.path.exists(filepath):
+            logging.info(f"GPFM weights already exist at {filepath}")
+            return True
+
+        logging.info(f"Downloading GPFM weights from {url}")
+        try:
+            response = requests.get(url, stream=True, timeout=300)
+            response.raise_for_status()
+
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+
+            # Get file size for progress tracking
+            total_size = int(response.headers.get('content-length', 0))
+            downloaded = 0
+
+            with open(filepath, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:
+                        f.write(chunk)
+                        downloaded += len(chunk)
+                        if total_size > 0:
+                            progress = (downloaded / total_size) * 100
+                            if downloaded % (1024 * 1024 * 10) == 0:  # Log every 10MB
+                                logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)")
+
+            logging.info(f"GPFM weights downloaded successfully to {filepath}")
+            return True
+
+        except Exception as e:
+            logging.error(f"Failed to download GPFM weights: {e}")
+            if os.path.exists(filepath):
+                os.remove(filepath)  # Clean up partial download
+            return False
+
+    def _load_model(self):
+        """Load GPFM model with pretrained weights."""
+        try:
+            # Create models directory
+            models_dir = os.path.join(cache_dir, 'gpfm_models')
+            os.makedirs(models_dir, exist_ok=True)
+
+            # GPFM weights URL from official repository
+            weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth"
+            weights_path = os.path.join(models_dir, 'GPFM.pth')
+
+            # Create GPFM DinoVisionTransformer architecture
+            self.model = DinoVisionTransformer(
+                img_size=224,
+                patch_size=14,
+                embed_dim=1024,
+                depth=24,
+                num_heads=16
+            )
+
+            # Try to download and load GPFM weights
+            weights_loaded = False
+            if self._download_weights(weights_url, weights_path):
+                try:
+                    logging.info("Loading GPFM pretrained weights...")
+                    checkpoint = torch.load(weights_path, map_location=self.device)
+
+                    # Extract teacher model weights (GPFM format)
+                    if 'teacher' in checkpoint:
+                        state_dict = checkpoint['teacher']
+                        logging.info("Found 'teacher' key in checkpoint")
+                    else:
+                        state_dict = checkpoint
+                        logging.info("Using checkpoint directly")
+
+                    # Rename keys to match our simplified architecture
+                    new_state_dict = {}
+                    for k, v in state_dict.items():
+                        # Remove 'backbone.' prefix if present
+                        if k.startswith('backbone.'):
+                            k = k[9:]  # Remove 'backbone.'
+
+                        # Map GPFM keys to our simplified architecture
+                        if k in ['cls_token', 'pos_embed']:
+                            new_state_dict[k] = v
+                        elif k.startswith('patch_embed.proj.'):
+                            # Map patch embedding
+                            new_k = k.replace('patch_embed.proj.', 'patch_embed.')
+                            new_state_dict[new_k] = v
+                        elif k.startswith('blocks.') and 'norm' in k:
+                            # Map layer norms
+                            if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'):
+                                # Skip intermediate norms for simplified model
+                                continue
+                            elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'):
+                                continue
+                        elif k == 'norm.weight' or k == 'norm.bias':
+                            new_state_dict[k] = v
+
+                    # Load compatible weights
+                    missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+                    if missing_keys:
+                        logging.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
+                    if unexpected_keys:
+                        logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5
+
+                    logging.info("GPFM pretrained weights loaded successfully")
+                    weights_loaded = True
+
+                except Exception as e:
+                    logging.warning(f"Could not load GPFM weights: {e}")
+
+            if not weights_loaded:
+                logging.info("Using randomly initialized GPFM architecture (no pretrained weights)")
+
+            self.model = self.model.to(self.device)
+            self.model.eval()
+
+            # GPFM preprocessing (based on official repository)
+            self.transformer = transforms.Compose([
+                transforms.Lambda(lambda x: x.convert("RGB")),  # Ensure RGB format
+                transforms.Resize((224, 224)),  # GPFM uses 224x224 (not 512x512 for features)
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406],  # ImageNet normalization
+                    std=[0.229, 0.224, 0.225]
+                )
+            ])
+
+            logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})")
+
+        except Exception as e:
+            logging.error(f"Failed to initialize GPFM model: {e}")
+            raise
+
+    def forward(self, x):
+        """Forward pass through GPFM model."""
+        with torch.no_grad():
+            return self.model(x)
+
+    def get_transformer(self, apply_normalization=True):
+        """Get the preprocessing transformer for GPFM."""
+        if apply_normalization:
+            return self.transformer
+        else:
+            # Return transformer without normalization
+            return transforms.Compose([
+                transforms.Lambda(lambda x: x.convert("RGB")),
+                transforms.Resize((224, 224)),
+                transforms.ToTensor()
+            ])
+
+
 # Available models from torchvision
 AVAILABLE_MODELS = {
     name: getattr(models, name)
@@ -58,6 +291,10 @@
     ) and "weights" in signature(getattr(models, name)).parameters
 }
 
+# Add GPFM model if available
+if GPFM_AVAILABLE:
+    AVAILABLE_MODELS['gpfm'] = GPFMModel
+
 # Default resize and normalization settings for models
 MODEL_DEFAULTS = {
     "default": {"resize": (224, 224), "normalize": (
@@ -86,6 +323,10 @@
     "vit_b_32": {"resize": (224, 224), "normalize": (
         [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
     )},
+    "gpfm": {
+        "resize": (224, 224),
+        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    },
 }
 
 for model, settings in MODEL_DEFAULTS.items():
@@ -152,7 +393,14 @@
             f"Unsupported model: {model_name}. \
             Available models: {list(AVAILABLE_MODELS.keys())}")
     try:
-        if "weights" in inspect.signature(
+        # Special handling for GPFM
+        if model_name == "gpfm":
+            model = AVAILABLE_MODELS[model_name](device=device)
+            logging.info("GPFM model loaded")
+            return model
+
+        # Standard torchvision models
+        if "weights" in signature(
                 AVAILABLE_MODELS[model_name]).parameters:
             model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
         else:
@@ -231,13 +479,25 @@
     else:
         initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))
 
-    transform_list = [initial_transform,
-                      transforms.Resize(resize),
-                      transforms.ToTensor()]
-    if apply_normalization:
-        transform_list.append(transforms.Normalize(mean=normalize[0],
-                                                   std=normalize[1]))
-    transform = transforms.Compose(transform_list)
+    # Handle GPFM separately as it has its own preprocessing
+    if model_name == "gpfm":
+        # For GPFM, combine initial transform with GPFM's custom transformer
+        if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]:
+            transform = transforms.Compose([
+                initial_transform,
+                model.get_transformer(apply_normalization=apply_normalization)
+            ])
+        else:
+            transform = model.get_transformer(apply_normalization=apply_normalization)
+    else:
+        # Standard torchvision models
+        transform_list = [initial_transform,
+                          transforms.Resize(resize),
+                          transforms.ToTensor()]
+        if apply_normalization:
+            transform_list.append(transforms.Normalize(mean=normalize[0],
+                                                       std=normalize[1]))
+        transform = transforms.Compose(transform_list)
 
     class ImageDataset(Dataset):
         def __init__(self, zip_file, file_list, transform=None):
@@ -278,7 +538,7 @@
     dataloader = DataLoader(
         dataset,
         batch_size=16,  # Reduced for lower memory usage
-        num_workers=1,  # Reduced to minimize shared memory
+        num_workers=0,  # Fix multiprocessing issues with GPFM
        shuffle=False,
         pin_memory=True if device == "cuda" else False,
         collate_fn=collate_fn,
```
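For readers following the new model code: with the defaults above, a 224x224 input is split into (224 // 14)^2 = 256 patches, a class token is prepended (257 tokens of width 1024), and the forward pass returns only the class token, so every image yields a 1024-dimensional embedding regardless of depth. A minimal shape check, assuming `DinoVisionTransformer` can be imported from `pytorch_embedding`, with a reduced `depth` so it runs quickly (the GPFM configuration in the diff uses `depth=24`):

```python
import torch

from pytorch_embedding import DinoVisionTransformer  # assumed importable

# depth=2 keeps this check fast; the embedding width is unaffected by depth
model = DinoVisionTransformer(img_size=224, patch_size=14,
                              embed_dim=1024, depth=2, num_heads=16)
model.eval()

x = torch.randn(2, 3, 224, 224)  # batch of 2 RGB images
# (224 // 14) ** 2 = 256 patches, plus 1 class token, gives 257 tokens of width 1024
with torch.no_grad():
    emb = model(x)
print(emb.shape)  # torch.Size([2, 1024]); forward returns only the class token
```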
```diff
--- a/pytorch_embedding.xml Thu Jun 19 23:33:23 2025 +0000
+++ b/pytorch_embedding.xml Sun Nov 09 19:03:21 2025 +0000
@@ -50,6 +50,7 @@
         <option value="efficientnet_v2_s">EfficientNetV2-S</option>
         <option value="efficientnet_v2_m">EfficientNetV2-M</option>
         <option value="efficientnet_v2_l">EfficientNetV2-L</option>
+        <option value="gpfm">GPFM (Generalizable Pathology Foundation Model)</option>
         <option value="googlenet">GoogLeNet</option>
         <option value="inception_v3">Inception-V3</option>
         <option value="mnasnet0_5">MNASNet-0.5</option>
@@ -114,20 +115,39 @@
             </assert_contents>
         </output>
     </test>
+    <test>
+        <param name="input_zip" value="1_digit.zip" ftype="zip" />
+        <param name="model_name" value="gpfm" />
+        <param name="apply_normalization" value="true" />
+        <param name="transform_type" value="RGB" />
+        <output name="output_csv">
+            <assert_contents>
+                <has_text text="sample_name" />
+                <has_n_columns min="1" />
+            </assert_contents>
+        </output>
+    </test>
 </tests>
 <help>
     <![CDATA[
     **What it does**
-    This tool extracts image embeddings using a selected deep learning model.
+    This tool extracts image embeddings using a selected deep learning model, including specialized pathology models like GPFM.
 
     **Inputs**
     - A zip file containing images to process.
-    - A model selection for embedding extraction.
+    - A model selection for embedding extraction (includes GPFM for pathology images).
     - An option to apply normalization to the extracted embeddings.
     - A choice of image transformation type before processing.
 
+    **Models Available**
+    - Standard computer vision models (ResNet, EfficientNet, ViT, etc.)
+    - GPFM: Generalizable Pathology Foundation Model - specialized for medical/pathology images
+      * Automatically downloads 1.2GB pretrained weights on first use
+      * Uses DinoVisionTransformer architecture (1024-dimensional embeddings)
+      * Optimized for histopathology images at 224x224 resolution
+
     **Outputs**
     - A CSV file containing embeddings. Each row corresponds to an image, with the file name in the first column and embedding vectors in the subsequent columns.
     ]]>
 </help>
-</tool>
+</tool>
\ No newline at end of file
```
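Per the help text above, the output CSV holds one row per image, with the file name in the first column and the embedding dimensions in the columns after it. A hypothetical downstream snippet (the file name is illustrative) that splits those columns with pandas:

```python
import pandas as pd

df = pd.read_csv("embeddings.csv")  # output file name is illustrative
sample_names = df.iloc[:, 0]        # first column: image file name
X = df.iloc[:, 1:].to_numpy()       # remaining columns: the embedding vector
print(X.shape)                      # e.g. (n_images, 1024) for GPFM embeddings
```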
