pytorch_embedding.py @ 1:84f96c952c2c (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
| author | goeckslab |
|---|---|
| date | Sun, 09 Nov 2025 19:03:21 +0000 |
| parents | 38333676a029 |
""" This module provides functionality to extract image embeddings using a specified pretrained model from the torchvision library. It includes functions to: - List image files directly from a ZIP file without extraction. - Apply model-specific preprocessing and transformations. - Extract embeddings using various models. - Save the resulting embeddings into a CSV file. Modules required: - argparse: For command-line argument parsing. - os, csv, zipfile: For file handling (ZIP file reading, CSV writing). - inspect: For inspecting function signatures and models. - torch, torchvision: For loading and using pretrained models to extract embeddings. - PIL, cv2: For image processing tasks such as resizing, normalization, and conversion. """ import argparse import csv import logging import os import zipfile from inspect import signature import cv2 import numpy as np import torch import torchvision.models as models from PIL import Image from torch.utils.data import DataLoader, Dataset from torchvision import transforms # GPFM imports try: import requests GPFM_AVAILABLE = True except ImportError as e: GPFM_AVAILABLE = False logging.warning(f"GPFM dependencies not available: {e}") # Configure logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO, handlers=[ logging.StreamHandler(), # Console output logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a") # File output ] ) # Create a cache directory in the current working directory cache_dir = os.path.join(os.getcwd(), 'hf_cache') try: os.makedirs(cache_dir, exist_ok=True) logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}") except OSError as e: logging.error(f"Failed to create cache directory {cache_dir}: {e}") raise # GPFM DinoVisionTransformer Implementation class DinoVisionTransformer(torch.nn.Module): """Simplified DinoVisionTransformer for GPFM.""" def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16): super().__init__() self.embed_dim = embed_dim self.num_features = embed_dim self.patch_size = patch_size # Patch embedding self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) num_patches = (img_size // patch_size) ** 2 # Class token self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim)) # Position embeddings self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # Transformer blocks (simplified) self.blocks = torch.nn.ModuleList([ torch.nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim * 4, dropout=0.0, batch_first=True ) for _ in range(depth) ]) # Layer norm self.norm = torch.nn.LayerNorm(embed_dim) # Initialize weights torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) torch.nn.init.trunc_normal_(self.cls_token, std=0.02) def forward(self, x): B = x.shape[0] # Patch embedding x = self.patch_embed(x) # B, embed_dim, H//patch_size, W//patch_size x = x.flatten(2).transpose(1, 2) # B, num_patches, embed_dim # Add class token cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat([cls_tokens, x], dim=1) # Add position embeddings x = x + self.pos_embed # Apply transformer blocks for block in self.blocks: x = block(x) # Apply layer norm and return class token x = self.norm(x) return x[:, 0] # Return class token features # GPFM Model Implementation class GPFMModel(torch.nn.Module): """GPFM (Generalizable Pathology Foundation Model) implementation.""" def __init__(self, device='cpu'): super().__init__() 
# GPFM Model Implementation
class GPFMModel(torch.nn.Module):
    """GPFM (Generalizable Pathology Foundation Model) implementation."""

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.model = None
        self.transformer = None
        self.embed_dim = 1024  # GPFM uses 1024-dimensional embeddings
        self._load_model()

    def _download_weights(self, url, filepath):
        """Download GPFM weights from the official repository."""
        if os.path.exists(filepath):
            logging.info(f"GPFM weights already exist at {filepath}")
            return True

        logging.info(f"Downloading GPFM weights from {url}")
        try:
            response = requests.get(url, stream=True, timeout=300)
            response.raise_for_status()
            os.makedirs(os.path.dirname(filepath), exist_ok=True)

            # Get file size for progress tracking
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0

            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 0:
                            progress = (downloaded / total_size) * 100
                            if downloaded % (1024 * 1024 * 10) == 0:  # Log every 10MB
                                logging.info(
                                    f"Downloaded {downloaded // (1024 * 1024)}MB / "
                                    f"{total_size // (1024 * 1024)}MB ({progress:.1f}%)"
                                )

            logging.info(f"GPFM weights downloaded successfully to {filepath}")
            return True
        except Exception as e:
            logging.error(f"Failed to download GPFM weights: {e}")
            if os.path.exists(filepath):
                os.remove(filepath)  # Clean up partial download
            return False
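
    # Illustration of the checkpoint remapping done in _load_model below
    # (hypothetical key names): DINO-style GPFM checkpoints nest the weights
    # under 'teacher' and prefix them with 'backbone.', e.g.
    #
    #     'backbone.cls_token'               -> 'cls_token'
    #     'backbone.patch_embed.proj.weight' -> 'patch_embed.weight'
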
    def _load_model(self):
        """Load GPFM model with pretrained weights."""
        try:
            # Create models directory
            models_dir = os.path.join(cache_dir, 'gpfm_models')
            os.makedirs(models_dir, exist_ok=True)

            # GPFM weights URL from official repository
            weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth"
            weights_path = os.path.join(models_dir, 'GPFM.pth')

            # Create GPFM DinoVisionTransformer architecture
            self.model = DinoVisionTransformer(
                img_size=224,
                patch_size=14,
                embed_dim=1024,
                depth=24,
                num_heads=16
            )

            # Try to download and load GPFM weights
            weights_loaded = False
            if self._download_weights(weights_url, weights_path):
                try:
                    logging.info("Loading GPFM pretrained weights...")
                    checkpoint = torch.load(weights_path, map_location=self.device)

                    # Extract teacher model weights (GPFM format)
                    if 'teacher' in checkpoint:
                        state_dict = checkpoint['teacher']
                        logging.info("Found 'teacher' key in checkpoint")
                    else:
                        state_dict = checkpoint
                        logging.info("Using checkpoint directly")

                    # Rename keys to match our simplified architecture
                    new_state_dict = {}
                    for k, v in state_dict.items():
                        # Remove 'backbone.' prefix if present
                        if k.startswith('backbone.'):
                            k = k[9:]  # Remove 'backbone.'

                        # Map GPFM keys to our simplified architecture
                        if k in ['cls_token', 'pos_embed']:
                            new_state_dict[k] = v
                        elif k.startswith('patch_embed.proj.'):
                            # Map patch embedding
                            new_k = k.replace('patch_embed.proj.', 'patch_embed.')
                            new_state_dict[new_k] = v
                        elif k.startswith('blocks.') and 'norm' in k:
                            # Map layer norms
                            if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'):
                                # Skip intermediate norms for simplified model
                                continue
                            elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'):
                                continue
                        elif k == 'norm.weight' or k == 'norm.bias':
                            new_state_dict[k] = v

                    # Load compatible weights
                    missing_keys, unexpected_keys = self.model.load_state_dict(
                        new_state_dict, strict=False
                    )
                    if missing_keys:
                        logging.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
                    if unexpected_keys:
                        logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5

                    logging.info("GPFM pretrained weights loaded successfully")
                    weights_loaded = True
                except Exception as e:
                    logging.warning(f"Could not load GPFM weights: {e}")

            if not weights_loaded:
                logging.info("Using randomly initialized GPFM architecture (no pretrained weights)")

            self.model = self.model.to(self.device)
            self.model.eval()

            # GPFM preprocessing (based on official repository)
            self.transformer = transforms.Compose([
                transforms.Lambda(lambda x: x.convert("RGB")),  # Ensure RGB format
                transforms.Resize((224, 224)),  # GPFM uses 224x224 (not 512x512 for features)
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                    std=[0.229, 0.224, 0.225]
                )
            ])

            logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})")
        except Exception as e:
            logging.error(f"Failed to initialize GPFM model: {e}")
            raise

    def forward(self, x):
        """Forward pass through GPFM model."""
        with torch.no_grad():
            return self.model(x)

    def get_transformer(self, apply_normalization=True):
        """Get the preprocessing transformer for GPFM."""
        if apply_normalization:
            return self.transformer
        else:
            # Return transformer without normalization
            return transforms.Compose([
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])


# Available models from torchvision
AVAILABLE_MODELS = {
    name: getattr(models, name)
    for name in dir(models)
    if callable(getattr(models, name))
    and "weights" in signature(getattr(models, name)).parameters
}

# Add GPFM model if available
if GPFM_AVAILABLE:
    AVAILABLE_MODELS['gpfm'] = GPFMModel

# Default resize and normalization settings for models
MODEL_DEFAULTS = {
    "default": {
        "resize": (224, 224),
        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    },
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
    "swin_b": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    },
    "swin_s": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    },
    "swin_t": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    },
    "vit_b_16": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    },
    "vit_b_32": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    },
    "gpfm": {
        "resize": (224, 224),
        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    },
}

# Fall back to ImageNet normalization for models without an explicit setting
for model, settings in MODEL_DEFAULTS.items():
    if "normalize" not in settings:
        settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
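
# Hypothetical illustration (not executed): after the loop above, entries
# without an explicit "normalize" resolve to ImageNet statistics, e.g.
#
#     MODEL_DEFAULTS["efficientnet_b3"]
#     # -> {'resize': (300, 300),
#     #     'normalize': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}
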
"normalize" not in settings: settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Custom transform classes class CLAHETransform: def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)): self.clahe = cv2.createCLAHE( clipLimit=clip_limit, tileGridSize=tile_grid_size ) def __call__(self, img): img = np.array(img.convert("L")) img = self.clahe.apply(img) return Image.fromarray(img).convert("RGB") class CannyTransform: def __init__(self, threshold1=100, threshold2=200): self.threshold1 = threshold1 self.threshold2 = threshold2 def __call__(self, img): img = np.array(img.convert("L")) edges = cv2.Canny(img, self.threshold1, self.threshold2) return Image.fromarray(edges).convert("RGB") class RGBAtoRGBTransform: def __call__(self, img): if img.mode == "RGBA": background = Image.new("RGBA", img.size, (255, 255, 255, 255)) img = Image.alpha_composite(background, img).convert("RGB") else: img = img.convert("RGB") return img def get_image_files_from_zip(zip_file): """Returns a list of image file names in the ZIP file.""" try: with zipfile.ZipFile(zip_file, "r") as zip_ref: file_list = [ f for f in zip_ref.namelist() if f.lower().endswith( (".png", ".jpg", ".jpeg", ".bmp", ".gif") ) ] return file_list except zipfile.BadZipFile as exc: raise RuntimeError("Invalid ZIP file.") from exc except Exception as exc: raise RuntimeError("Error reading ZIP file.") from exc def load_model(model_name, device): """Loads a specified torchvision model and modifies it for feature extraction.""" if model_name not in AVAILABLE_MODELS: raise ValueError( f"Unsupported model: {model_name}. \ Available models: {list(AVAILABLE_MODELS.keys())}") try: # Special handling for GPFM if model_name == "gpfm": model = AVAILABLE_MODELS[model_name](device=device) logging.info("GPFM model loaded") return model # Standard torchvision models if "weights" in signature( AVAILABLE_MODELS[model_name]).parameters: model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device) else: model = AVAILABLE_MODELS[model_name]().to(device) logging.info("Model loaded") except Exception as e: logging.error(f"Failed to load model {model_name}: {e}") raise if hasattr(model, "fc"): model.fc = torch.nn.Identity() elif hasattr(model, "classifier"): model.classifier = torch.nn.Identity() elif hasattr(model, "head"): model.head = torch.nn.Identity() model.eval() return model def write_csv(output_csv, list_embeddings, ludwig_format=False): """Writes embeddings to a CSV file, optionally in Ludwig format.""" with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file: csv_writer = csv.writer(csv_file) if list_embeddings: if ludwig_format: header = ["sample_name", "embedding"] formatted_embeddings = [] for embedding in list_embeddings: sample_name = embedding[0] vector = embedding[1:] embedding_str = " ".join(map(str, vector)) formatted_embeddings.append([sample_name, embedding_str]) csv_writer.writerow(header) csv_writer.writerows(formatted_embeddings) logging.info("CSV created in Ludwig format") else: header = ["sample_name"] + [f"vector{i + 1}" for i in range( len(list_embeddings[0]) - 1 )] csv_writer.writerow(header) csv_writer.writerows(list_embeddings) logging.info("CSV created") else: csv_writer.writerow(["sample_name"] if not ludwig_format else ["sample_name", "embedding"]) logging.info("No valid images found. 
Empty CSV created.") def extract_embeddings( model_name, apply_normalization, zip_file, file_list, transform_type="rgb"): """Extracts embeddings from images using batch processing or sequential fallback.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = load_model(model_name, device) model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"]) resize = model_settings["resize"] normalize = model_settings.get("normalize", ( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] )) # Define transform pipeline if transform_type == "grayscale": initial_transform = transforms.Grayscale(num_output_channels=3) elif transform_type == "clahe": initial_transform = CLAHETransform() elif transform_type == "edges": initial_transform = CannyTransform() elif transform_type == "rgba_to_rgb": initial_transform = RGBAtoRGBTransform() else: initial_transform = transforms.Lambda(lambda x: x.convert("RGB")) # Handle GPFM separately as it has its own preprocessing if model_name == "gpfm": # For GPFM, combine initial transform with GPFM's custom transformer if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]: transform = transforms.Compose([ initial_transform, model.get_transformer(apply_normalization=apply_normalization) ]) else: transform = model.get_transformer(apply_normalization=apply_normalization) else: # Standard torchvision models transform_list = [initial_transform, transforms.Resize(resize), transforms.ToTensor()] if apply_normalization: transform_list.append(transforms.Normalize(mean=normalize[0], std=normalize[1])) transform = transforms.Compose(transform_list) class ImageDataset(Dataset): def __init__(self, zip_file, file_list, transform=None): self.zip_file = zip_file self.file_list = file_list self.transform = transform def __len__(self): return len(self.file_list) def __getitem__(self, idx): with zipfile.ZipFile(self.zip_file, "r") as zip_ref: with zip_ref.open(self.file_list[idx]) as file: try: image = Image.open(file) if self.transform: image = self.transform(image) return image, os.path.basename(self.file_list[idx]) except Exception as e: logging.warning( "Skipping %s: %s", self.file_list[idx], e ) return None, os.path.basename(self.file_list[idx]) # Custom collate function def collate_fn(batch): batch = [item for item in batch if item[0] is not None] if not batch: return None, None images, names = zip(*batch) return torch.stack(images), names list_embeddings = [] with torch.inference_mode(): try: # Try DataLoader with reduced resource usage dataset = ImageDataset(zip_file, file_list, transform=transform) dataloader = DataLoader( dataset, batch_size=16, # Reduced for lower memory usage num_workers=0, # Fix multiprocessing issues with GPFM shuffle=False, pin_memory=True if device == "cuda" else False, collate_fn=collate_fn, ) for images, names in dataloader: if images is None: continue images = images.to(device) embeddings = model(images).cpu().numpy() for name, embedding in zip(names, embeddings): list_embeddings.append([name] + embedding.tolist()) except RuntimeError as e: logging.warning( f"DataLoader failed: {e}. \ Falling back to sequential processing." 
            # Fallback to sequential processing
            for file in file_list:
                with zipfile.ZipFile(zip_file, "r") as zip_ref:
                    with zip_ref.open(file) as img_file:
                        try:
                            image = Image.open(img_file)
                            image = transform(image)
                            input_tensor = image.unsqueeze(0).to(device)
                            embedding = model(
                                input_tensor
                            ).squeeze().cpu().numpy()
                            list_embeddings.append(
                                [os.path.basename(file)] + embedding.tolist()
                            )
                        except Exception as e:
                            logging.warning("Skipping %s: %s", file, e)
    return list_embeddings


def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Main entry point for processing the zip file and extracting embeddings."""
    file_list = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")
    list_embeddings = extract_embeddings(
        model_name, apply_normalization, zip_file, file_list, transform_type
    )
    logging.info("Embeddings extracted")
    write_csv(output_csv, list_embeddings, ludwig_format)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument(
        "--zip_file", required=True,
        help="Path to the ZIP file containing images."
    )
    parser.add_argument(
        "--model_name", required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model for embedding extraction."
    )
    parser.add_argument(
        "--normalize", action="store_true",
        help="Whether to apply normalization."
    )
    parser.add_argument(
        "--transform_type", required=True,
        help="Image transformation type."
    )
    parser.add_argument(
        "--output_csv", required=True,
        help="Path to the output CSV file"
    )
    parser.add_argument(
        "--ludwig_format", action="store_true",
        help="Prepare CSV file in Ludwig input format"
    )
    args = parser.parse_args()
    main(
        args.zip_file,
        args.output_csv,
        args.model_name,
        args.normalize,
        args.transform_type,
        args.ludwig_format
    )
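
# Example invocation (hypothetical file names; the flags mirror the argparse
# definitions above):
#
#     python pytorch_embedding.py \
#         --zip_file images.zip \
#         --model_name resnet50 \
#         --transform_type rgb \
#         --output_csv embeddings.csv \
#         --normalize \
#         --ludwig_format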
