Mercurial > repos > goeckslab > extract_embeddings
changeset 0:38333676a029 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
author    goeckslab
date      Thu, 19 Jun 2025 23:33:23 +0000
parents
children
files     Docker/Dockerfile README.md pytorch_embedding.py pytorch_embedding.xml test-data/1_digit.zip
diffstat  5 files changed, 534 insertions(+), 0 deletions(-)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Docker/Dockerfile Thu Jun 19 23:33:23 2025 +0000
@@ -0,0 +1,21 @@
+# Use a lightweight Python 3.9 base image
+FROM python:3.9-slim
+
+# Install system dependencies for OpenCV and other libraries in one layer
+RUN apt-get update && apt-get install -y \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Upgrade pip to the latest version
+RUN pip install --upgrade pip
+
+# Install PyTorch and torchvision CPU-only versions
+RUN pip install --no-cache-dir torch==2.0.0 torchvision==0.15.1 \
+    -f https://download.pytorch.org/whl/cpu/torch_stable.html
+
+# Install NumPy with a version compatible with PyTorch 2.0.0
+RUN pip install --no-cache-dir numpy==1.24.4
+
+# Install remaining Python dependencies (argparse, logging and multiprocessing are standard-library modules, not pip packages)
+RUN pip install --no-cache-dir Pillow opencv-python pandas fastparquet
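Since the image pins torch 2.0.0 against torchvision 0.15.1 and numpy 1.24.4, a quick import check catches wheel mismatches before the Galaxy tool ever runs. A minimal sketch, assuming the built image is tagged extract-embeddings:cpu (both the tag and the script name check_env.py are hypothetical):

# check_env.py - sanity-check the pinned, CPU-only stack inside the image
import cv2
import numpy as np
import torch
import torchvision

# CPU-only wheels were requested, so CUDA should be unavailable
print("torch:", torch.__version__, "| cuda:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("numpy:", np.__version__)
print("opencv:", cv2.__version__)

Running it with docker run --rm -v "$PWD":/work extract-embeddings:cpu python /work/check_env.py should print the pinned versions and cuda: False.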
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/README.md Thu Jun 19 23:33:23 2025 +0000
@@ -0,0 +1,2 @@
+# Galaxy-Embedding_extractor
+Tool to extract and save learned feature vectors (embeddings) from pre-trained models for downstream tasks.
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/pytorch_embedding.py Thu Jun 19 23:33:23 2025 +0000
@@ -0,0 +1,378 @@
+"""
+This module provides functionality to extract image embeddings
+using a specified pretrained model from the torchvision library.
+It includes functions to:
+- List image files directly from a ZIP file without extraction.
+- Apply model-specific preprocessing and transformations.
+- Extract embeddings using various models.
+- Save the resulting embeddings into a CSV file.
+Modules required:
+- argparse: For command-line argument parsing.
+- os, csv, zipfile: For file handling (ZIP file reading, CSV writing).
+- inspect: For inspecting function signatures and models.
+- torch, torchvision: For loading and using pretrained models
+  to extract embeddings.
+- PIL, cv2: For image processing tasks such as resizing, normalization,
+  and conversion.
+"""
+
+import argparse
+import csv
+import inspect
+import logging
+import os
+import zipfile
+from inspect import signature
+
+import cv2
+import numpy as np
+import torch
+import torchvision.models as models
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+# Configure logging
+logging.basicConfig(
+    filename="/tmp/ludwig_embeddings.log",
+    filemode="a",
+    format="%(asctime)s - %(levelname)s - %(message)s",
+    level=logging.DEBUG,
+)
+
+# Create a cache directory in the current working directory
+cache_dir = os.path.join(os.getcwd(), 'hf_cache')
+try:
+    os.makedirs(cache_dir, exist_ok=True)
+    logging.info(
+        f"Cache directory created: {cache_dir}, "
+        f"writable: {os.access(cache_dir, os.W_OK)}"
+    )
+except OSError as e:
+    logging.error(f"Failed to create cache directory {cache_dir}: {e}")
+    raise
+
+# Available models from torchvision
+AVAILABLE_MODELS = {
+    name: getattr(models, name)
+    for name in dir(models)
+    if callable(
+        getattr(models, name)
+    ) and "weights" in signature(getattr(models, name)).parameters
+}
+
+# Default resize and normalization settings for models
+MODEL_DEFAULTS = {
+    "default": {"resize": (224, 224), "normalize": (
+        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+    )},
+    "efficientnet_b1": {"resize": (240, 240)},
+    "efficientnet_b2": {"resize": (260, 260)},
+    "efficientnet_b3": {"resize": (300, 300)},
+    "efficientnet_b4": {"resize": (380, 380)},
+    "efficientnet_b5": {"resize": (456, 456)},
+    "efficientnet_b6": {"resize": (528, 528)},
+    "efficientnet_b7": {"resize": (600, 600)},
+    "inception_v3": {"resize": (299, 299)},
+    "swin_b": {"resize": (224, 224), "normalize": (
+        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+    )},
+    "swin_s": {"resize": (224, 224), "normalize": (
+        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+    )},
+    "swin_t": {"resize": (224, 224), "normalize": (
+        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+    )},
+    "vit_b_16": {"resize": (224, 224), "normalize": (
+        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+    )},
+    "vit_b_32": {"resize": (224, 224), "normalize": (
+        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+    )},
+}
+
+for model, settings in MODEL_DEFAULTS.items():
+    if "normalize" not in settings:
+        settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+
+
+# Custom transform classes
+class CLAHETransform:
+    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
+        self.clahe = cv2.createCLAHE(
+            clipLimit=clip_limit,
+            tileGridSize=tile_grid_size
+        )
+
+    def __call__(self, img):
+        img = np.array(img.convert("L"))
+        img = self.clahe.apply(img)
+        return Image.fromarray(img).convert("RGB")
+
+
+class CannyTransform:
+    def __init__(self, threshold1=100, threshold2=200):
+        self.threshold1 = threshold1
+        self.threshold2 = threshold2
+
+    def __call__(self, img):
+        img = np.array(img.convert("L"))
+        edges = cv2.Canny(img, self.threshold1, self.threshold2)
+        return Image.fromarray(edges).convert("RGB")
+
+
+class RGBAtoRGBTransform:
+    def __call__(self, img):
+        if img.mode == "RGBA":
+            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+            img = Image.alpha_composite(background, img).convert("RGB")
+        else:
+            img = img.convert("RGB")
+        return img
+
+
+def get_image_files_from_zip(zip_file):
+    """Returns a list of image file names in the ZIP file."""
+    try:
+        with zipfile.ZipFile(zip_file, "r") as zip_ref:
+            file_list = [
+                f for f in zip_ref.namelist() if f.lower().endswith(
+                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
+                )
+            ]
+            return file_list
+    except zipfile.BadZipFile as exc:
+        raise RuntimeError("Invalid ZIP file.") from exc
+    except Exception as exc:
+        raise RuntimeError("Error reading ZIP file.") from exc
+
+
+def load_model(model_name, device):
+    """Loads a specified torchvision model and
+    modifies it for feature extraction."""
+    if model_name not in AVAILABLE_MODELS:
+        raise ValueError(
+            f"Unsupported model: {model_name}. "
+            f"Available models: {list(AVAILABLE_MODELS.keys())}"
+        )
+    try:
+        if "weights" in inspect.signature(
+                AVAILABLE_MODELS[model_name]).parameters:
+            model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
+        else:
+            model = AVAILABLE_MODELS[model_name]().to(device)
+        logging.info("Model loaded")
+    except Exception as e:
+        logging.error(f"Failed to load model {model_name}: {e}")
+        raise
+
+    # Strip the classification layer so the forward pass returns features
+    if hasattr(model, "fc"):
+        model.fc = torch.nn.Identity()
+    elif hasattr(model, "classifier"):
+        model.classifier = torch.nn.Identity()
+    elif hasattr(model, "head"):
+        model.head = torch.nn.Identity()
+
+    model.eval()
+    return model
+
+
+def write_csv(output_csv, list_embeddings, ludwig_format=False):
+    """Writes embeddings to a CSV file, optionally in Ludwig format."""
+    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
+        csv_writer = csv.writer(csv_file)
+        if list_embeddings:
+            if ludwig_format:
+                header = ["sample_name", "embedding"]
+                formatted_embeddings = []
+                for embedding in list_embeddings:
+                    sample_name = embedding[0]
+                    vector = embedding[1:]
+                    embedding_str = " ".join(map(str, vector))
+                    formatted_embeddings.append([sample_name, embedding_str])
+                csv_writer.writerow(header)
+                csv_writer.writerows(formatted_embeddings)
+                logging.info("CSV created in Ludwig format")
+            else:
+                header = ["sample_name"] + [f"vector{i + 1}" for i in range(
+                    len(list_embeddings[0]) - 1
+                )]
+                csv_writer.writerow(header)
+                csv_writer.writerows(list_embeddings)
+                logging.info("CSV created")
+        else:
+            csv_writer.writerow(["sample_name"] if not ludwig_format
+                                else ["sample_name", "embedding"])
+            logging.info("No valid images found. Empty CSV created.")
+
+
+def extract_embeddings(
+        model_name,
+        apply_normalization,
+        zip_file,
+        file_list,
+        transform_type="rgb"):
+    """Extracts embeddings from images
+    using batch processing or sequential fallback."""
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = load_model(model_name, device)
+    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
+    resize = model_settings["resize"]
+    normalize = model_settings.get("normalize", (
+        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+    ))
+
+    # Define transform pipeline
+    if transform_type == "grayscale":
+        initial_transform = transforms.Grayscale(num_output_channels=3)
+    elif transform_type == "clahe":
+        initial_transform = CLAHETransform()
+    elif transform_type == "edges":
+        initial_transform = CannyTransform()
+    elif transform_type == "rgba_to_rgb":
+        initial_transform = RGBAtoRGBTransform()
+    else:
+        initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))
+
+    transform_list = [initial_transform,
+                      transforms.Resize(resize),
+                      transforms.ToTensor()]
+    if apply_normalization:
+        transform_list.append(transforms.Normalize(mean=normalize[0],
+                                                   std=normalize[1]))
+    transform = transforms.Compose(transform_list)
+
+    class ImageDataset(Dataset):
+        def __init__(self, zip_file, file_list, transform=None):
+            self.zip_file = zip_file
+            self.file_list = file_list
+            self.transform = transform
+
+        def __len__(self):
+            return len(self.file_list)
+
+        def __getitem__(self, idx):
+            with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
+                with zip_ref.open(self.file_list[idx]) as file:
+                    try:
+                        image = Image.open(file)
+                        if self.transform:
+                            image = self.transform(image)
+                        return image, os.path.basename(self.file_list[idx])
+                    except Exception as e:
+                        logging.warning(
+                            "Skipping %s: %s", self.file_list[idx], e
+                        )
+                        return None, os.path.basename(self.file_list[idx])
+
+    # Custom collate function that drops unreadable images
+    def collate_fn(batch):
+        batch = [item for item in batch if item[0] is not None]
+        if not batch:
+            return None, None
+        images, names = zip(*batch)
+        return torch.stack(images), names
+
+    list_embeddings = []
+    with torch.inference_mode():
+        try:
+            # Try DataLoader with reduced resource usage
+            dataset = ImageDataset(zip_file, file_list, transform=transform)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=16,  # Reduced for lower memory usage
+                num_workers=1,  # Reduced to minimize shared memory
+                shuffle=False,
+                pin_memory=(device.type == "cuda"),
+                collate_fn=collate_fn,
+            )
+            for images, names in dataloader:
+                if images is None:
+                    continue
+                images = images.to(device)
+                embeddings = model(images).cpu().numpy()
+                for name, embedding in zip(names, embeddings):
+                    list_embeddings.append([name] + embedding.tolist())
+        except RuntimeError as e:
+            logging.warning(
+                f"DataLoader failed: {e}. "
+                "Falling back to sequential processing."
+            )
+            # Fallback to sequential processing
+            for file in file_list:
+                with zipfile.ZipFile(zip_file, "r") as zip_ref:
+                    with zip_ref.open(file) as img_file:
+                        try:
+                            image = Image.open(img_file)
+                            image = transform(image)
+                            input_tensor = image.unsqueeze(0).to(device)
+                            embedding = model(
+                                input_tensor
+                            ).squeeze().cpu().numpy()
+                            list_embeddings.append(
+                                [os.path.basename(file)] + embedding.tolist()
+                            )
+                        except Exception as e:
+                            logging.warning("Skipping %s: %s", file, e)
+
+    return list_embeddings
+
+
+def main(zip_file, output_csv, model_name, apply_normalization=False,
+         transform_type="rgb", ludwig_format=False):
+    """Main entry point for processing the zip file and
+    extracting embeddings."""
+    file_list = get_image_files_from_zip(zip_file)
+    logging.info("Image files listed from ZIP")
+
+    list_embeddings = extract_embeddings(
+        model_name,
+        apply_normalization,
+        zip_file,
+        file_list,
+        transform_type
+    )
+    logging.info("Embeddings extracted")
+    write_csv(output_csv, list_embeddings, ludwig_format)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Extract image embeddings.")
+    parser.add_argument(
+        "--zip_file",
+        required=True,
+        help="Path to the ZIP file containing images."
+    )
+    parser.add_argument(
+        "--model_name",
+        required=True,
+        choices=AVAILABLE_MODELS.keys(),
+        help="Model for embedding extraction."
+    )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Whether to apply normalization."
+    )
+    parser.add_argument(
+        "--transform_type",
+        required=True,
+        help="Image transformation type."
+    )
+    parser.add_argument(
+        "--output_csv",
+        required=True,
+        help="Path to the output CSV file"
+    )
+    parser.add_argument(
+        "--ludwig_format",
+        action="store_true",
+        help="Prepare CSV file in Ludwig input format"
+    )
+
+    args = parser.parse_args()
+    main(
+        args.zip_file,
+        args.output_csv,
+        args.model_name,
+        args.normalize,
+        args.transform_type,
+        args.ludwig_format
+    )
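Downstream of write_csv, the two output layouts load differently: the default layout stores one float per vectorN column, while the Ludwig layout packs each vector into a single space-separated string. A minimal pandas sketch for both (the file names embeddings.csv and embeddings_ludwig.csv are hypothetical; pandas is not used by the script itself):

import numpy as np
import pandas as pd

# Default layout: sample_name, vector1, ..., vectorN
df = pd.read_csv("embeddings.csv")
matrix = df.drop(columns=["sample_name"]).to_numpy(dtype=np.float32)
print(matrix.shape)  # (n_images, embedding_dim), e.g. 2048 for resnet50

# Ludwig layout: sample_name, embedding (one space-separated string per row)
df_ludwig = pd.read_csv("embeddings_ludwig.csv")
matrix_ludwig = np.stack([
    np.array(row.split(), dtype=np.float32)
    for row in df_ludwig["embedding"]
])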
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/pytorch_embedding.xml Thu Jun 19 23:33:23 2025 +0000
@@ -0,0 +1,133 @@
+<tool id="extract_embeddings" name="Image Embedding Extraction" version="1.0.0">
+    <description>Extract image embeddings using a deep learning model</description>
+
+    <requirements>
+        <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:0.10.1</container>
+    </requirements>
+    <stdio>
+        <exit_code range="137" level="fatal_oom" description="Out of Memory" />
+        <exit_code range="1:" level="fatal" description="Error occurred. Please check Tool Standard Error" />
+    </stdio>
+    <command><![CDATA[
+        mkdir -p "./hf_cache" &&
+        export HF_HOME="./hf_cache" &&
+        export TORCH_HOME="./hf_cache" &&
+        python $__tool_directory__/pytorch_embedding.py
+            --zip_file "$input_zip"
+            --output_csv "$output_csv"
+            --model_name "$model_name"
+            #if $apply_normalization
+                --normalize
+            #end if
+            #if $ludwig_format
+                --ludwig_format
+            #end if
+            --transform_type "$transform_type"
+    ]]></command>
+    <configfiles>
+        <inputs name="inputs" />
+    </configfiles>
+    <inputs>
+        <param argument="input_zip" type="data" format="zip" label="Input Zip File (Images)" help="Provide a zip file containing images to process." />
+        <param argument="model_name" type="select" label="Model for Embedding Extraction" help="Select the model to use for embedding extraction.">
+            <option value="alexnet">AlexNet</option>
+            <option value="convnext_tiny">ConvNeXt-Tiny</option>
+            <option value="convnext_small">ConvNeXt-Small</option>
+            <option value="convnext_base">ConvNeXt-Base</option>
+            <option value="convnext_large">ConvNeXt-Large</option>
+            <option value="densenet121">DenseNet121</option>
+            <option value="densenet161">DenseNet161</option>
+            <option value="densenet169">DenseNet169</option>
+            <option value="densenet201">DenseNet201</option>
+            <option value="efficientnet_b0">EfficientNet-B0</option>
+            <option value="efficientnet_b1">EfficientNet-B1</option>
+            <option value="efficientnet_b2">EfficientNet-B2</option>
+            <option value="efficientnet_b3">EfficientNet-B3</option>
+            <option value="efficientnet_b4">EfficientNet-B4</option>
+            <option value="efficientnet_b5">EfficientNet-B5</option>
+            <option value="efficientnet_b6">EfficientNet-B6</option>
+            <option value="efficientnet_b7">EfficientNet-B7</option>
+            <option value="efficientnet_v2_s">EfficientNetV2-S</option>
+            <option value="efficientnet_v2_m">EfficientNetV2-M</option>
+            <option value="efficientnet_v2_l">EfficientNetV2-L</option>
+            <option value="googlenet">GoogLeNet</option>
+            <option value="inception_v3">Inception-V3</option>
+            <option value="mnasnet0_5">MNASNet-0.5</option>
+            <option value="mnasnet0_75">MNASNet-0.75</option>
+            <option value="mnasnet1_0">MNASNet-1.0</option>
+            <option value="mnasnet1_3">MNASNet-1.3</option>
+            <option value="mobilenet_v2">MobileNetV2</option>
+            <option value="mobilenet_v3_large">MobileNetV3-Large</option>
+            <option value="mobilenet_v3_small">MobileNetV3-Small</option>
+            <option value="regnet_x_400mf">RegNet-X-400MF</option>
+            <option value="regnet_x_800mf">RegNet-X-800MF</option>
+            <option value="regnet_x_1_6gf">RegNet-X-1.6GF</option>
+            <option value="regnet_x_3_2gf">RegNet-X-3.2GF</option>
+            <option value="regnet_x_8gf">RegNet-X-8GF</option>
+            <option value="resnet18">ResNet-18</option>
+            <option value="resnet34">ResNet-34</option>
+            <option value="resnet50" selected="true">ResNet-50</option>
+            <option value="resnet101">ResNet-101</option>
+            <option value="resnet152">ResNet-152</option>
+            <option value="resnext50_32x4d">ResNeXt-50-32x4d</option>
+            <option value="resnext101_32x8d">ResNeXt-101-32x8d</option>
+            <option value="shufflenet_v2_x0_5">ShuffleNetV2-0.5x</option>
+            <option value="shufflenet_v2_x1_0">ShuffleNetV2-1.0x</option>
+            <option value="squeezenet1_0">SqueezeNet1.0</option>
+            <option value="squeezenet1_1">SqueezeNet1.1</option>
+            <option value="swin_b">Swin-B</option>
+            <option value="swin_s">Swin-S</option>
+            <option value="swin_t">Swin-T</option>
+            <option value="vgg11">VGG-11</option>
+            <option value="vgg13">VGG-13</option>
+            <option value="vgg16">VGG-16</option>
+            <option value="vgg19">VGG-19</option>
+            <option value="vit_b_16">ViT-B-16</option>
+            <option value="vit_b_32">ViT-B-32</option>
+            <option value="wide_resnet50_2">Wide-ResNet50-2</option>
+            <option value="wide_resnet101_2">Wide-ResNet101-2</option>
+        </param>
+        <param argument="apply_normalization" type="boolean" label="Apply Normalization" help="Enable or disable normalization of embeddings." checked="true"/>
+        <param argument="transform_type" type="select" label="Image Transformation Type" help="Choose the transformation type to apply before extraction.">
+            <option value="RGB" selected="true">RGB</option>
+            <option value="grayscale">Grayscale</option>
+            <option value="rgba_to_rgb">RGBA to RGB</option>
+            <option value="clahe">CLAHE (Contrast Limited Adaptive Histogram Equalization)</option>
+            <option value="edges">Edge Detection</option>
+        </param>
+        <param name="ludwig_format" type="boolean" optional="true" label="Convert vectors (stored as columns) into a single string column (Ludwig format)?"/>
+    </inputs>
+    <outputs>
+        <data name="output_csv" format="csv" label="Extracted Embeddings" />
+    </outputs>
+
+    <tests>
+        <test>
+            <param name="input_zip" value="1_digit.zip" ftype="zip" />
+            <param name="model_name" value="resnet50" />
+            <param name="apply_normalization" value="true" />
+            <param name="transform_type" value="RGB" />
+            <output name="output_csv">
+                <assert_contents>
+                    <has_text text="sample_name" />
+                    <has_n_columns min="1" />
+                </assert_contents>
+            </output>
+        </test>
+    </tests>
+    <help>
+        <![CDATA[
+        **What it does**
+        This tool extracts image embeddings using a selected deep learning model.
+
+        **Inputs**
+        - A zip file containing images to process.
+        - A model selection for embedding extraction.
+        - An option to apply normalization to the extracted embeddings.
+        - A choice of image transformation type before processing.
+
+        **Outputs**
+        - A CSV file containing embeddings. Each row corresponds to an image, with the file name in the first column and the embedding vector in the subsequent columns.
+        ]]>
+    </help>
+</tool>