changeset 1:84f96c952c2c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
author goeckslab
date Sun, 09 Nov 2025 19:03:21 +0000
parents 38333676a029
children
files Docker/Dockerfile pytorch_embedding.py pytorch_embedding.xml
diffstat 3 files changed, 331 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/Docker/Dockerfile	Thu Jun 19 23:33:23 2025 +0000
+++ b/Docker/Dockerfile	Sun Nov 09 19:03:21 2025 +0000
@@ -1,21 +1,48 @@
-# Use a lightweight Python 3.9 base image
-FROM python:3.9-slim
+# Use NVIDIA CUDA base image for GPU support
+FROM nvidia/cuda:11.8-devel-ubuntu20.04
 
-# Install system dependencies for OpenCV and other libraries in one layer
+# Set environment variables
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+
+# Install system dependencies
 RUN apt-get update && apt-get install -y \
+    python3 \
+    python3-pip \
+    python3-dev \
     libgl1-mesa-glx \
     libglib2.0-0 \
+    git \
+    wget \
     && rm -rf /var/lib/apt/lists/*
 
-# Upgrade pip to the latest version
+# Create symbolic links for python
+RUN ln -s /usr/bin/python3 /usr/bin/python && \
+    ln -s /usr/bin/pip3 /usr/bin/pip
+
+# Upgrade pip
 RUN pip install --upgrade pip
 
-# Install PyTorch and torchvision CPU-only versions
+# Install PyTorch with CUDA support
 RUN pip install --no-cache-dir torch==2.0.0 torchvision==0.15.1 \
-    -f https://download.pytorch.org/whl/cpu/torch_stable.html
+    -f https://download.pytorch.org/whl/cu118/torch_stable.html
 
 # Install NumPy with a version compatible with PyTorch 2.0.0
 RUN pip install --no-cache-dir numpy==1.24.4
 
-# Install remaining Python dependencies
-RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet argparse logging multiprocessing
+# Install timm for model support (quote to prevent shell redirection)
+RUN pip install --no-cache-dir 'timm>=1.0.3'
+
+# Install remaining Python dependencies (exclude stdlib packages, add requests for GPFM)
+RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet requests
+
+# Install HuggingFace transformers for model loading
+RUN pip install --no-cache-dir transformers huggingface-hub
+
+# Set working directory
+WORKDIR /workspace
+
+# Create cache directory for HuggingFace models
+RUN mkdir -p /workspace/hf_cache
+ENV HF_HOME=/workspace/hf_cache
+ENV TORCH_HOME=/workspace/hf_cache
--- a/pytorch_embedding.py	Thu Jun 19 23:33:23 2025 +0000
+++ b/pytorch_embedding.py	Sun Nov 09 19:03:21 2025 +0000
@@ -18,7 +18,6 @@
 
 import argparse
 import csv
-import inspect
 import logging
 import os
 import zipfile
@@ -32,12 +31,22 @@
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 
+# GPFM imports
+try:
+    import requests
+    GPFM_AVAILABLE = True
+except ImportError as e:
+    GPFM_AVAILABLE = False
+    logging.warning(f"GPFM dependencies not available: {e}")
+
 # Configure logging
 logging.basicConfig(
-    filename="/tmp/ludwig_embeddings.log",
-    filemode="a",
     format="%(asctime)s - %(levelname)s - %(message)s",
-    level=logging.DEBUG,
+    level=logging.INFO,
+    handlers=[
+        logging.StreamHandler(),  # Console output
+        logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a")  # File output
+    ]
 )
 
 # Create a cache directory in the current working directory
@@ -49,6 +58,230 @@
     logging.error(f"Failed to create cache directory {cache_dir}: {e}")
     raise
 
+# GPFM DinoVisionTransformer Implementation
+
+
+class DinoVisionTransformer(torch.nn.Module):
+    """Simplified DinoVisionTransformer for GPFM."""
+
+    def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_features = embed_dim
+        self.patch_size = patch_size
+
+        # Patch embedding
+        self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
+        num_patches = (img_size // patch_size) ** 2
+
+        # Class token
+        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # Position embeddings
+        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+        # Transformer blocks (simplified)
+        self.blocks = torch.nn.ModuleList([
+            torch.nn.TransformerEncoderLayer(
+                d_model=embed_dim,
+                nhead=num_heads,
+                dim_feedforward=embed_dim * 4,
+                dropout=0.0,
+                batch_first=True
+            ) for _ in range(depth)
+        ])
+
+        # Layer norm
+        self.norm = torch.nn.LayerNorm(embed_dim)
+
+        # Initialize weights
+        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
+        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
+
+    def forward(self, x):
+        B = x.shape[0]
+
+        # Patch embedding
+        x = self.patch_embed(x)  # B, embed_dim, H//patch_size, W//patch_size
+        x = x.flatten(2).transpose(1, 2)  # B, num_patches, embed_dim
+
+        # Add class token
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat([cls_tokens, x], dim=1)
+
+        # Add position embeddings
+        x = x + self.pos_embed
+
+        # Apply transformer blocks
+        for block in self.blocks:
+            x = block(x)
+
+        # Apply layer norm and return class token
+        x = self.norm(x)
+        return x[:, 0]  # Return class token features
+
+
+# GPFM Model Implementation
+class GPFMModel(torch.nn.Module):
+    """GPFM (Generalizable Pathology Foundation Model) implementation."""
+
+    def __init__(self, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = None
+        self.transformer = None
+        self.embed_dim = 1024  # GPFM uses 1024-dimensional embeddings
+        self._load_model()
+
+    def _download_weights(self, url, filepath):
+        """Download GPFM weights from the official repository."""
+        if os.path.exists(filepath):
+            logging.info(f"GPFM weights already exist at {filepath}")
+            return True
+
+        logging.info(f"Downloading GPFM weights from {url}")
+        try:
+            response = requests.get(url, stream=True, timeout=300)
+            response.raise_for_status()
+
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+
+            # Get file size for progress tracking
+            total_size = int(response.headers.get('content-length', 0))
+            downloaded = 0
+
+            with open(filepath, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:
+                        f.write(chunk)
+                        downloaded += len(chunk)
+                        if total_size > 0:
+                            progress = (downloaded / total_size) * 100
+                            if downloaded % (1024 * 1024 * 10) == 0:  # Log every 10MB
+                                logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)")
+
+            logging.info(f"GPFM weights downloaded successfully to {filepath}")
+            return True
+
+        except Exception as e:
+            logging.error(f"Failed to download GPFM weights: {e}")
+            if os.path.exists(filepath):
+                os.remove(filepath)  # Clean up partial download
+            return False
+
+    def _load_model(self):
+        """Load GPFM model with pretrained weights."""
+        try:
+            # Create models directory
+            models_dir = os.path.join(cache_dir, 'gpfm_models')
+            os.makedirs(models_dir, exist_ok=True)
+
+            # GPFM weights URL from official repository
+            weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth"
+            weights_path = os.path.join(models_dir, 'GPFM.pth')
+
+            # Create GPFM DinoVisionTransformer architecture
+            self.model = DinoVisionTransformer(
+                img_size=224,
+                patch_size=14,
+                embed_dim=1024,
+                depth=24,
+                num_heads=16
+            )
+
+            # Try to download and load GPFM weights
+            weights_loaded = False
+            if self._download_weights(weights_url, weights_path):
+                try:
+                    logging.info("Loading GPFM pretrained weights...")
+                    checkpoint = torch.load(weights_path, map_location=self.device)
+
+                    # Extract teacher model weights (GPFM format)
+                    if 'teacher' in checkpoint:
+                        state_dict = checkpoint['teacher']
+                        logging.info("Found 'teacher' key in checkpoint")
+                    else:
+                        state_dict = checkpoint
+                        logging.info("Using checkpoint directly")
+
+                    # Rename keys to match our simplified architecture
+                    new_state_dict = {}
+                    for k, v in state_dict.items():
+                        # Remove 'backbone.' prefix if present
+                        if k.startswith('backbone.'):
+                            k = k[9:]  # Remove 'backbone.'
+
+                        # Map GPFM keys to our simplified architecture
+                        if k in ['cls_token', 'pos_embed']:
+                            new_state_dict[k] = v
+                        elif k.startswith('patch_embed.proj.'):
+                            # Map patch embedding
+                            new_k = k.replace('patch_embed.proj.', 'patch_embed.')
+                            new_state_dict[new_k] = v
+                        elif k.startswith('blocks.') and 'norm' in k:
+                            # Map layer norms
+                            if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'):
+                                # Skip intermediate norms for simplified model
+                                continue
+                            elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'):
+                                continue
+                        elif k == 'norm.weight' or k == 'norm.bias':
+                            new_state_dict[k] = v
+
+                    # Load compatible weights
+                    missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+                    if missing_keys:
+                        logging.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
+                    if unexpected_keys:
+                        logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5
+
+                    logging.info("GPFM pretrained weights loaded successfully")
+                    weights_loaded = True
+
+                except Exception as e:
+                    logging.warning(f"Could not load GPFM weights: {e}")
+
+            if not weights_loaded:
+                logging.info("Using randomly initialized GPFM architecture (no pretrained weights)")
+
+            self.model = self.model.to(self.device)
+            self.model.eval()
+
+            # GPFM preprocessing (based on official repository)
+            self.transformer = transforms.Compose([
+                transforms.Lambda(lambda x: x.convert("RGB")),  # Ensure RGB format
+                transforms.Resize((224, 224)),  # GPFM uses 224x224 (not 512x512 for features)
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406],  # ImageNet normalization
+                    std=[0.229, 0.224, 0.225]
+                )
+            ])
+
+            logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})")
+
+        except Exception as e:
+            logging.error(f"Failed to initialize GPFM model: {e}")
+            raise
+
+    def forward(self, x):
+        """Forward pass through GPFM model."""
+        with torch.no_grad():
+            return self.model(x)
+
+    def get_transformer(self, apply_normalization=True):
+        """Get the preprocessing transformer for GPFM."""
+        if apply_normalization:
+            return self.transformer
+        else:
+            # Return transformer without normalization
+            return transforms.Compose([
+                transforms.Lambda(lambda x: x.convert("RGB")),
+                transforms.Resize((224, 224)),
+                transforms.ToTensor()
+            ])
+
+
 # Available models from torchvision
 AVAILABLE_MODELS = {
     name: getattr(models, name)
@@ -58,6 +291,10 @@
     ) and "weights" in signature(getattr(models, name)).parameters
 }
 
+# Add GPFM model if available
+if GPFM_AVAILABLE:
+    AVAILABLE_MODELS['gpfm'] = GPFMModel
+
 # Default resize and normalization settings for models
 MODEL_DEFAULTS = {
     "default": {"resize": (224, 224), "normalize": (
@@ -86,6 +323,10 @@
     "vit_b_32": {"resize": (224, 224), "normalize": (
         [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
     )},
+    "gpfm": {
+        "resize": (224, 224),
+        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    },
 }
 
 for model, settings in MODEL_DEFAULTS.items():
@@ -152,7 +393,14 @@
             f"Unsupported model: {model_name}. \
             Available models: {list(AVAILABLE_MODELS.keys())}")
     try:
-        if "weights" in inspect.signature(
+        # Special handling for GPFM
+        if model_name == "gpfm":
+            model = AVAILABLE_MODELS[model_name](device=device)
+            logging.info("GPFM model loaded")
+            return model
+
+        # Standard torchvision models
+        if "weights" in signature(
                 AVAILABLE_MODELS[model_name]).parameters:
             model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
         else:
@@ -231,13 +479,25 @@
     else:
         initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))
 
-    transform_list = [initial_transform,
-                      transforms.Resize(resize),
-                      transforms.ToTensor()]
-    if apply_normalization:
-        transform_list.append(transforms.Normalize(mean=normalize[0],
-                                                   std=normalize[1]))
-    transform = transforms.Compose(transform_list)
+    # Handle GPFM separately as it has its own preprocessing
+    if model_name == "gpfm":
+        # For GPFM, combine initial transform with GPFM's custom transformer
+        if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]:
+            transform = transforms.Compose([
+                initial_transform,
+                model.get_transformer(apply_normalization=apply_normalization)
+            ])
+        else:
+            transform = model.get_transformer(apply_normalization=apply_normalization)
+    else:
+        # Standard torchvision models
+        transform_list = [initial_transform,
+                          transforms.Resize(resize),
+                          transforms.ToTensor()]
+        if apply_normalization:
+            transform_list.append(transforms.Normalize(mean=normalize[0],
+                                                       std=normalize[1]))
+        transform = transforms.Compose(transform_list)
 
     class ImageDataset(Dataset):
         def __init__(self, zip_file, file_list, transform=None):
@@ -278,7 +538,7 @@
             dataloader = DataLoader(
                 dataset,
                 batch_size=16,  # Reduced for lower memory usage
-                num_workers=1,  # Reduced to minimize shared memory
+                num_workers=0,  # Fix multiprocessing issues with GPFM
                 shuffle=False,
                 pin_memory=True if device == "cuda" else False,
                 collate_fn=collate_fn,
--- a/pytorch_embedding.xml	Thu Jun 19 23:33:23 2025 +0000
+++ b/pytorch_embedding.xml	Sun Nov 09 19:03:21 2025 +0000
@@ -50,6 +50,7 @@
             <option value="efficientnet_v2_s">EfficientNetV2-S</option>
             <option value="efficientnet_v2_m">EfficientNetV2-M</option>
             <option value="efficientnet_v2_l">EfficientNetV2-L</option>
+            <option value="gpfm">GPFM (Generalizable Pathology Foundation Model)</option>
             <option value="googlenet">GoogLeNet</option>
             <option value="inception_v3">Inception-V3</option>
             <option value="mnasnet0_5">MNASNet-0.5</option>
@@ -114,20 +115,39 @@
                 </assert_contents>
             </output>
         </test>
+        <test>
+            <param name="input_zip" value="1_digit.zip" ftype="zip" />
+            <param name="model_name" value="gpfm" />
+            <param name="apply_normalization" value="true" />
+            <param name="transform_type" value="RGB" />
+            <output name="output_csv">
+                <assert_contents>
+                    <has_text text="sample_name" />
+                    <has_n_columns min="1" />
+                </assert_contents>
+            </output>
+        </test>
     </tests>
     <help>
         <![CDATA[
         **What it does**
-        This tool extracts image embeddings using a selected deep learning model.
+        This tool extracts image embeddings using a selected deep learning model, including specialized pathology models like GPFM.
 
         **Inputs**
         - A zip file containing images to process.
-        - A model selection for embedding extraction.
+        - A model selection for embedding extraction (includes GPFM for pathology images).
         - An option to apply normalization to the extracted embeddings.
         - A choice of image transformation type before processing.
 
+        **Models Available**
+        - Standard computer vision models (ResNet, EfficientNet, ViT, etc.)
+        - GPFM: Generalizable Pathology Foundation Model - specialized for medical/pathology images
+          * Automatically downloads 1.2GB pretrained weights on first use
+          * Uses DinoVisionTransformer architecture (1024-dimensional embeddings)
+          * Optimized for histopathology images at 224x224 resolution
+
         **Outputs**
         - A CSV file containing embeddings. Each row corresponds to an image, with the file name in the first column and embedding vectors in the subsequent columns.
         ]]>
     </help>
-</tool>
+</tool>
\ No newline at end of file