pytorch_embedding.py @ 0:38333676a029 (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
author: goeckslab
date: Thu, 19 Jun 2025 23:33:23 +0000
1 """
2 This module provides functionality to extract image embeddings
3 using a specified
4 pretrained model from the torchvision library. It includes functions to:
5 - List image files directly from a ZIP file without extraction.
6 - Apply model-specific preprocessing and transformations.
7 - Extract embeddings using various models.
8 - Save the resulting embeddings into a CSV file.
9 Modules required:
10 - argparse: For command-line argument parsing.
11 - os, csv, zipfile: For file handling (ZIP file reading, CSV writing).
12 - inspect: For inspecting function signatures and models.
13 - torch, torchvision: For loading and using pretrained models
14 to extract embeddings.
15 - PIL, cv2: For image processing tasks such as resizing, normalization,
16 and conversion.
17 """

import argparse
import csv
import inspect
import logging
import os
import zipfile
from inspect import signature

import cv2
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Configure logging
logging.basicConfig(
    filename="/tmp/ludwig_embeddings.log",
    filemode="a",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)

# Create a cache directory in the current working directory
cache_dir = os.path.join(os.getcwd(), "hf_cache")
try:
    os.makedirs(cache_dir, exist_ok=True)
    logging.info(
        f"Cache directory created: {cache_dir}, "
        f"writable: {os.access(cache_dir, os.W_OK)}"
    )
except OSError as e:
    logging.error(f"Failed to create cache directory {cache_dir}: {e}")
    raise

# Available models from torchvision
AVAILABLE_MODELS = {
    name: getattr(models, name)
    for name in dir(models)
    if callable(getattr(models, name))
    and "weights" in signature(getattr(models, name)).parameters
}
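# Examples of names this discovers (the exact set depends on the installed
# torchvision version): resnet50, efficientnet_b3, vit_b_16, swin_t.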

# Default resize and normalization settings for models
MODEL_DEFAULTS = {
    "default": {
        "resize": (224, 224),
        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    },
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
    "swin_b": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    },
    "swin_s": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    },
    "swin_t": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    },
    "vit_b_16": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    },
    "vit_b_32": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    },
}

# Fill in the ImageNet defaults for any model that does not override them.
for settings in MODEL_DEFAULTS.values():
    settings.setdefault(
        "normalize", ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )
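# With the defaults filled in, e.g. MODEL_DEFAULTS["efficientnet_b3"]
# resolves to {"resize": (300, 300),
#              "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}.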

# Custom transform classes
class CLAHETransform:
    """Apply CLAHE (adaptive histogram equalization) to the grayscale image."""

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit,
            tileGridSize=tile_grid_size,
        )

    def __call__(self, img):
        img = np.array(img.convert("L"))
        img = self.clahe.apply(img)
        return Image.fromarray(img).convert("RGB")


class CannyTransform:
    """Replace the image with its Canny edge map."""

    def __init__(self, threshold1=100, threshold2=200):
        self.threshold1 = threshold1
        self.threshold2 = threshold2

    def __call__(self, img):
        img = np.array(img.convert("L"))
        edges = cv2.Canny(img, self.threshold1, self.threshold2)
        return Image.fromarray(edges).convert("RGB")


class RGBAtoRGBTransform:
    """Flatten RGBA images onto a white background; pass others through as RGB."""

    def __call__(self, img):
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img).convert("RGB")
        else:
            img = img.convert("RGB")
        return img
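
# Illustrative sketch (not executed here): extract_embeddings composes these
# with torchvision transforms; e.g. for --transform_type clahe the pipeline is
# equivalent to:
#
#     pipeline = transforms.Compose([
#         CLAHETransform(),
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#     ])
#     tensor = pipeline(Image.open("example.png"))  # "example.png" is a placeholder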


def get_image_files_from_zip(zip_file):
    """Returns a list of image file names in the ZIP file."""
    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            file_list = [
                f for f in zip_ref.namelist()
                if f.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
                )
            ]
        return file_list
    except zipfile.BadZipFile as exc:
        raise RuntimeError("Invalid ZIP file.") from exc
    except Exception as exc:
        raise RuntimeError("Error reading ZIP file.") from exc


def load_model(model_name, device):
    """Load a torchvision model and strip its classification head
    so it can be used for feature extraction."""
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Available models: {list(AVAILABLE_MODELS.keys())}"
        )
    try:
        if "weights" in inspect.signature(
                AVAILABLE_MODELS[model_name]).parameters:
            model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
        else:
            model = AVAILABLE_MODELS[model_name]().to(device)
        logging.info("Model loaded")
    except Exception as e:
        logging.error(f"Failed to load model {model_name}: {e}")
        raise

    # Replace the classification head with an identity so the forward pass
    # returns the penultimate-layer features.
    if hasattr(model, "fc"):
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):
        model.classifier = torch.nn.Identity()
    elif hasattr(model, "head"):
        model.head = torch.nn.Identity()

    model.eval()
    return model


def write_csv(output_csv, list_embeddings, ludwig_format=False):
    """Write embeddings to a CSV file, optionally in Ludwig input format."""
    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        if list_embeddings:
            if ludwig_format:
                # One row per sample: name plus a space-separated vector string.
                header = ["sample_name", "embedding"]
                formatted_embeddings = []
                for embedding in list_embeddings:
                    sample_name = embedding[0]
                    vector = embedding[1:]
                    embedding_str = " ".join(map(str, vector))
                    formatted_embeddings.append([sample_name, embedding_str])
                csv_writer.writerow(header)
                csv_writer.writerows(formatted_embeddings)
                logging.info("CSV created in Ludwig format")
            else:
                # One column per vector component.
                header = ["sample_name"] + [
                    f"vector{i + 1}"
                    for i in range(len(list_embeddings[0]) - 1)
                ]
                csv_writer.writerow(header)
                csv_writer.writerows(list_embeddings)
                logging.info("CSV created")
        else:
            csv_writer.writerow(
                ["sample_name", "embedding"] if ludwig_format
                else ["sample_name"]
            )
            logging.info("No valid images found. Empty CSV created.")
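
# Illustrative output (values are made up):
#   default layout:   sample_name,vector1,vector2,...
#                     img1.png,0.12,-0.34,...
#   --ludwig_format:  sample_name,embedding
#                     img1.png,0.12 -0.34 ...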


def extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type="rgb"):
    """Extract embeddings from images using batch processing,
    with a sequential fallback."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    resize = model_settings["resize"]
    normalize = model_settings.get(
        "normalize", ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )

    # Define the transform pipeline
    if transform_type == "grayscale":
        initial_transform = transforms.Grayscale(num_output_channels=3)
    elif transform_type == "clahe":
        initial_transform = CLAHETransform()
    elif transform_type == "edges":
        initial_transform = CannyTransform()
    elif transform_type == "rgba_to_rgb":
        initial_transform = RGBAtoRGBTransform()
    else:
        initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))

    transform_list = [
        initial_transform,
        transforms.Resize(resize),
        transforms.ToTensor(),
    ]
    if apply_normalization:
        transform_list.append(
            transforms.Normalize(mean=normalize[0], std=normalize[1])
        )
    transform = transforms.Compose(transform_list)

    class ImageDataset(Dataset):
        """Dataset that reads images straight out of the ZIP archive."""

        def __init__(self, zip_file, file_list, transform=None):
            self.zip_file = zip_file
            self.file_list = file_list
            self.transform = transform

        def __len__(self):
            return len(self.file_list)

        def __getitem__(self, idx):
            with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
                with zip_ref.open(self.file_list[idx]) as file:
                    try:
                        image = Image.open(file)
                        if self.transform:
                            image = self.transform(image)
                        return image, os.path.basename(self.file_list[idx])
                    except Exception as e:
                        logging.warning(
                            "Skipping %s: %s", self.file_list[idx], e
                        )
                        return None, os.path.basename(self.file_list[idx])

    # Custom collate function
    def collate_fn(batch):
        batch = [item for item in batch if item[0] is not None]
        if not batch:
            return None, None
        images, names = zip(*batch)
        return torch.stack(images), names
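
    # collate_fn drops samples whose image failed to load (returned as None),
    # so one unreadable file is skipped instead of aborting the whole batch.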

    list_embeddings = []
    with torch.inference_mode():
        try:
            # Try the DataLoader path first, with modest resource usage.
            dataset = ImageDataset(zip_file, file_list, transform=transform)
            dataloader = DataLoader(
                dataset,
                batch_size=16,  # Reduced for lower memory usage
                num_workers=1,  # Reduced to minimize shared memory
                shuffle=False,
                pin_memory=(device.type == "cuda"),
                collate_fn=collate_fn,
            )
            for images, names in dataloader:
                if images is None:
                    continue
                images = images.to(device)
                embeddings = model(images).cpu().numpy()
                for name, embedding in zip(names, embeddings):
                    list_embeddings.append([name] + embedding.tolist())
        except RuntimeError as e:
            logging.warning(
                "DataLoader failed: %s. Falling back to sequential processing.",
                e,
            )
            # Fallback to sequential processing
            for file in file_list:
                with zipfile.ZipFile(zip_file, "r") as zip_ref:
                    with zip_ref.open(file) as img_file:
                        try:
                            image = Image.open(img_file)
                            image = transform(image)
                            input_tensor = image.unsqueeze(0).to(device)
                            embedding = (
                                model(input_tensor).squeeze().cpu().numpy()
                            )
                            list_embeddings.append(
                                [os.path.basename(file)] + embedding.tolist()
                            )
                        except Exception as e:
                            logging.warning("Skipping %s: %s", file, e)

    return list_embeddings

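# Note: extract_embeddings selects CUDA automatically when available, batches
# 16 images per forward pass, and falls back to one-image-at-a-time processing
# if the DataLoader raises a RuntimeError (e.g. shared-memory limits).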


def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Main entry point: list images in the ZIP, extract embeddings,
    and write them to CSV."""
    file_list = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")

    list_embeddings = extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type,
    )
    logging.info("Embeddings extracted")
    write_csv(output_csv, list_embeddings, ludwig_format)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument(
        "--zip_file",
        required=True,
        help="Path to the ZIP file containing images.",
    )
    parser.add_argument(
        "--model_name",
        required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model for embedding extraction.",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Apply model-specific normalization.",
    )
    parser.add_argument(
        "--transform_type",
        required=True,
        help="Image transformation type: grayscale, clahe, edges, "
             "rgba_to_rgb, or rgb (any other value is treated as rgb).",
    )
    parser.add_argument(
        "--output_csv",
        required=True,
        help="Path to the output CSV file.",
    )
    parser.add_argument(
        "--ludwig_format",
        action="store_true",
        help="Prepare the CSV file in Ludwig input format.",
    )

    args = parser.parse_args()
    main(
        args.zip_file,
        args.output_csv,
        args.model_name,
        args.normalize,
        args.transform_type,
        args.ludwig_format,
    )
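
# Example invocation (file names are illustrative):
#   python pytorch_embedding.py \
#       --zip_file images.zip \
#       --model_name resnet50 \
#       --transform_type rgb \
#       --output_csv embeddings.csv \
#       --normalize \
#       --ludwig_format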