Mercurial > repos > goeckslab > extract_embeddings
view pytorch_embedding.py @ 0:38333676a029 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
author | goeckslab |
---|---|
date | Thu, 19 Jun 2025 23:33:23 +0000 |
parents | |
children |
line wrap: on
line source
"""
This module provides functionality to extract image embeddings
using a specified pretrained model from the torchvision library.
It includes functions to:
    - List image files directly from a ZIP file without extraction.
    - Apply model-specific preprocessing and transformations.
    - Extract embeddings using various models.
    - Save the resulting embeddings into a CSV file.
Modules required:
    - argparse: For command-line argument parsing.
    - os, csv, zipfile: For file handling (ZIP file reading, CSV writing).
    - inspect: For inspecting function signatures and models.
    - torch, torchvision: For loading and using pretrained models
      to extract embeddings.
    - PIL, cv2: For image processing tasks such as resizing,
      normalization, and conversion.
"""

import argparse
import csv
import inspect
import logging
import os
import zipfile
from inspect import signature

import cv2
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Configure logging
logging.basicConfig(
    filename="/tmp/ludwig_embeddings.log",
    filemode="a",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)

# Create a cache directory in the current working directory
cache_dir = os.path.join(os.getcwd(), 'hf_cache')
try:
    os.makedirs(cache_dir, exist_ok=True)
    logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}")
except OSError as e:
    logging.error(f"Failed to create cache directory {cache_dir}: {e}")
    raise

# Available models from torchvision: every callable exported by
# torchvision.models whose signature accepts a ``weights`` parameter.
AVAILABLE_MODELS = {
    name: getattr(models, name)
    for name in dir(models)
    if callable(getattr(models, name))
    and "weights" in signature(getattr(models, name)).parameters
}

# Default resize and normalization settings for models.
# "normalize" is a (mean, std) pair; entries without one are back-filled
# with the "default" values just below the table.
MODEL_DEFAULTS = {
    "default": {"resize": (224, 224), "normalize": (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )},
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
    # NOTE(review): the swin_* means have 0.0 for the second channel while
    # every other entry is uniform -- confirm this is intentional.
    "swin_b": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "swin_s": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "swin_t": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "vit_b_16": {"resize": (224, 224), "normalize": (
        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    )},
    "vit_b_32": {"resize": (224, 224), "normalize": (
        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    )},
}

# Back-fill the default normalization for entries that only set "resize".
for settings in MODEL_DEFAULTS.values():
    settings.setdefault(
        "normalize", ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )


# Custom transform classes
class CLAHETransform:
    """Apply OpenCV CLAHE to the grayscale version of a PIL image.

    The result is converted back to a 3-channel RGB PIL image.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit,
            tileGridSize=tile_grid_size
        )

    def __call__(self, img):
        img = np.array(img.convert("L"))
        img = self.clahe.apply(img)
        return Image.fromarray(img).convert("RGB")


class CannyTransform:
    """Replace a PIL image with its Canny edge map, returned as RGB."""

    def __init__(self, threshold1=100, threshold2=200):
        self.threshold1 = threshold1
        self.threshold2 = threshold2

    def __call__(self, img):
        img = np.array(img.convert("L"))
        edges = cv2.Canny(img, self.threshold1, self.threshold2)
        return Image.fromarray(edges).convert("RGB")


class RGBAtoRGBTransform:
    """Convert a PIL image to RGB, compositing RGBA onto a white background."""

    def __call__(self, img):
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img).convert("RGB")
        else:
            img = img.convert("RGB")
        return img


def get_image_files_from_zip(zip_file):
    """Return the names of image entries in the ZIP file.

    An entry counts as an image when its lower-cased name ends with
    .png, .jpg, .jpeg, .bmp or .gif. Nothing is extracted.

    Raises:
        RuntimeError: if the archive is not a valid ZIP file or any
            other error occurs while reading it.
    """
    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            file_list = [
                f for f in zip_ref.namelist()
                if f.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
                )
            ]
        return file_list
    except zipfile.BadZipFile as exc:
        raise RuntimeError("Invalid ZIP file.") from exc
    except Exception as exc:
        raise RuntimeError("Error reading ZIP file.") from exc


def load_model(model_name, device):
    """Load a torchvision model and strip its classifier for feature extraction.

    The model is created with ``weights="DEFAULT"`` when its constructor
    accepts a ``weights`` parameter, moved to ``device``, has its
    classification layer replaced by ``torch.nn.Identity`` and is put in
    eval mode.

    Raises:
        ValueError: if ``model_name`` is not a key of AVAILABLE_MODELS.
        Exception: whatever the model constructor raises is logged and
            re-raised unchanged.
    """
    if model_name not in AVAILABLE_MODELS:
        # Implicit concatenation instead of a backslash continuation, which
        # embedded the source indentation in the message.
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Available models: {list(AVAILABLE_MODELS)}"
        )
    try:
        if "weights" in inspect.signature(
                AVAILABLE_MODELS[model_name]).parameters:
            model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
        else:
            model = AVAILABLE_MODELS[model_name]().to(device)
        logging.info("Model loaded")
    except Exception as e:
        logging.error(f"Failed to load model {model_name}: {e}")
        raise

    # Replace the classification layer so the forward pass yields features.
    # torchvision's VisionTransformer keeps its classifier in ``heads``
    # (plural), so it needs its own branch; without it the ViT models would
    # emit class logits rather than embeddings.
    if hasattr(model, "fc"):
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):
        model.classifier = torch.nn.Identity()
    elif hasattr(model, "head"):
        model.head = torch.nn.Identity()
    elif hasattr(model, "heads"):
        model.heads = torch.nn.Identity()
    else:
        logging.warning(
            "No known classifier attribute on %s; "
            "using raw model output as embedding.",
            model_name,
        )

    model.eval()
    return model


def write_csv(output_csv, list_embeddings, ludwig_format=False):
    """Write embeddings to a CSV file, optionally in Ludwig format.

    Args:
        output_csv: path of the CSV file to create (overwritten).
        list_embeddings: rows shaped ``[sample_name, v1, v2, ...]``.
        ludwig_format: when true, each row is written as
            ``sample_name, "v1 v2 ..."`` (vector joined by spaces) under
            the header ``sample_name,embedding``; otherwise one column per
            vector component named ``vector1..vectorN``.

    When ``list_embeddings`` is empty only the header row is written.
    """
    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)

        if list_embeddings:
            if ludwig_format:
                header = ["sample_name", "embedding"]
                formatted_embeddings = []
                for embedding in list_embeddings:
                    sample_name = embedding[0]
                    vector = embedding[1:]
                    embedding_str = " ".join(map(str, vector))
                    formatted_embeddings.append([sample_name, embedding_str])
                csv_writer.writerow(header)
                csv_writer.writerows(formatted_embeddings)
                logging.info("CSV created in Ludwig format")
            else:
                # Column count is taken from the first row.
                header = ["sample_name"] + [f"vector{i + 1}" for i in range(
                    len(list_embeddings[0]) - 1
                )]
                csv_writer.writerow(header)
                csv_writer.writerows(list_embeddings)
                logging.info("CSV created")
        else:
            csv_writer.writerow(["sample_name"] if not ludwig_format
                                else ["sample_name", "embedding"])
            logging.info("No valid images found. Empty CSV created.")


def _to_rgb(img):
    """Convert a PIL image to RGB mode (the default initial transform)."""
    return img.convert("RGB")


def _build_transform(model_name, apply_normalization, transform_type):
    """Compose the preprocessing pipeline for ``model_name``.

    Pipeline: initial colour transform chosen by ``transform_type``
    ("grayscale", "clahe", "edges", "rgba_to_rgb"; anything else means a
    plain RGB conversion), resize to the model's input size, ToTensor, and
    optionally Normalize with the model's (mean, std).
    """
    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    resize = model_settings["resize"]
    normalize = model_settings.get("normalize", (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    ))

    if transform_type == "grayscale":
        initial_transform = transforms.Grayscale(num_output_channels=3)
    elif transform_type == "clahe":
        initial_transform = CLAHETransform()
    elif transform_type == "edges":
        initial_transform = CannyTransform()
    elif transform_type == "rgba_to_rgb":
        initial_transform = RGBAtoRGBTransform()
    else:
        initial_transform = transforms.Lambda(_to_rgb)

    transform_list = [initial_transform,
                      transforms.Resize(resize),
                      transforms.ToTensor()]
    if apply_normalization:
        transform_list.append(
            transforms.Normalize(mean=normalize[0], std=normalize[1])
        )
    return transforms.Compose(transform_list)


class _ZipImageDataset(Dataset):
    """Dataset that reads and transforms images straight from a ZIP file.

    Each item is ``(image, basename)``; ``image`` is ``None`` when the
    entry could not be decoded or transformed (a warning is logged).
    """

    def __init__(self, zip_file, file_list, transform=None):
        self.zip_file = zip_file
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        entry = self.file_list[idx]
        # The archive is opened per item rather than stored on the instance.
        with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
            with zip_ref.open(entry) as file:
                try:
                    image = Image.open(file)
                    if self.transform:
                        image = self.transform(image)
                    return image, os.path.basename(entry)
                except Exception as e:
                    logging.warning("Skipping %s: %s", entry, e)
                    return None, os.path.basename(entry)


def _collate_skip_failed(batch):
    """Stack a batch, dropping items whose image is ``None``.

    Returns ``(None, None)`` when every item in the batch failed.
    """
    valid = [item for item in batch if item[0] is not None]
    if not valid:
        return None, None
    images, names = zip(*valid)
    return torch.stack(images), names


def _extract_batched(model, device, transform, zip_file, file_list):
    """Run the model over all images through a DataLoader in batches."""
    rows = []
    dataset = _ZipImageDataset(zip_file, file_list, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=16,  # Reduced for lower memory usage
        num_workers=1,  # Reduced to minimize shared memory
        shuffle=False,
        # ``device`` is a torch.device, so compare its type, not the object.
        pin_memory=device.type == "cuda",
        collate_fn=_collate_skip_failed,
    )
    for images, names in dataloader:
        if images is None:
            continue
        images = images.to(device)
        embeddings = model(images).cpu().numpy()
        for name, embedding in zip(names, embeddings):
            rows.append([name] + embedding.tolist())
    return rows


def _extract_sequential(model, device, transform, zip_file, file_list):
    """Run the model over the images one at a time, skipping bad ones."""
    rows = []
    # Open the archive once instead of once per image.
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        for file in file_list:
            with zip_ref.open(file) as img_file:
                try:
                    image = Image.open(img_file)
                    image = transform(image)
                    input_tensor = image.unsqueeze(0).to(device)
                    embedding = model(
                        input_tensor
                    ).squeeze().cpu().numpy()
                    rows.append(
                        [os.path.basename(file)] + embedding.tolist()
                    )
                except Exception as e:
                    logging.warning("Skipping %s: %s", file, e)
    return rows


def extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type="rgb"):
    """Extract embeddings using batch processing or a sequential fallback.

    Args:
        model_name: key of AVAILABLE_MODELS.
        apply_normalization: whether to append the Normalize transform.
        zip_file: path to the ZIP archive containing the images.
        file_list: names of the archive entries to process.
        transform_type: initial colour transform ("rgb", "grayscale",
            "clahe", "edges" or "rgba_to_rgb").

    Returns:
        A list of rows ``[basename, v1, v2, ...]``, one per image that
        could be processed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    transform = _build_transform(
        model_name, apply_normalization, transform_type
    )

    with torch.inference_mode():
        try:
            # Try DataLoader with reduced resource usage
            list_embeddings = _extract_batched(
                model, device, transform, zip_file, file_list
            )
        except RuntimeError as e:
            logging.warning(
                f"DataLoader failed: {e}. "
                "Falling back to sequential processing."
            )
            # The fallback walks the whole file list again, so it starts
            # from a fresh list: batches finished before the failure are
            # discarded rather than duplicated in the output.
            list_embeddings = _extract_sequential(
                model, device, transform, zip_file, file_list
            )

    return list_embeddings


def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Main entry point for processing the zip file and extracting embeddings."""
    file_list = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")

    list_embeddings = extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type
    )
    logging.info("Embeddings extracted")
    write_csv(output_csv, list_embeddings, ludwig_format)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument(
        "--zip_file",
        required=True,
        help="Path to the ZIP file containing images."
    )
    parser.add_argument(
        "--model_name",
        required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model for embedding extraction."
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Whether to apply normalization."
    )
    parser.add_argument(
        "--transform_type",
        required=True,
        help="Image transformation type."
    )
    parser.add_argument(
        "--output_csv",
        required=True,
        help="Path to the output CSV file"
    )
    parser.add_argument(
        "--ludwig_format",
        action="store_true",
        help="Prepare CSV file in Ludwig input format"
    )

    args = parser.parse_args()
    main(
        args.zip_file,
        args.output_csv,
        args.model_name,
        args.normalize,
        args.transform_type,
        args.ludwig_format
    )