goeckslab / extract_embeddings: pytorch_embedding.py @ 0:38333676a029 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
author: goeckslab
date:   Thu, 19 Jun 2025 23:33:23 +0000
1 """ | |
2 This module provides functionality to extract image embeddings | |
3 using a specified | |
4 pretrained model from the torchvision library. It includes functions to: | |
5 - List image files directly from a ZIP file without extraction. | |
6 - Apply model-specific preprocessing and transformations. | |
7 - Extract embeddings using various models. | |
8 - Save the resulting embeddings into a CSV file. | |
9 Modules required: | |
10 - argparse: For command-line argument parsing. | |
11 - os, csv, zipfile: For file handling (ZIP file reading, CSV writing). | |
12 - inspect: For inspecting function signatures and models. | |
13 - torch, torchvision: For loading and using pretrained models | |
14 to extract embeddings. | |
15 - PIL, cv2: For image processing tasks such as resizing, normalization, | |
16 and conversion. | |
17 """ | |

import argparse
import csv
import inspect
import logging
import os
import zipfile
from inspect import signature

import cv2
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Configure logging
logging.basicConfig(
    filename="/tmp/ludwig_embeddings.log",
    filemode="a",
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)

# Create a cache directory in the current working directory
cache_dir = os.path.join(os.getcwd(), 'hf_cache')
try:
    os.makedirs(cache_dir, exist_ok=True)
    logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}")
except OSError as e:
    logging.error(f"Failed to create cache directory {cache_dir}: {e}")
    raise

# Available models from torchvision
AVAILABLE_MODELS = {
    name: getattr(models, name)
    for name in dir(models)
    if callable(
        getattr(models, name)
    ) and "weights" in signature(getattr(models, name)).parameters
}
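
# The comprehension above keeps only the callables whose signature accepts a
# ``weights`` keyword; with a recent torchvision this typically includes
# constructors such as "resnet50", "efficientnet_b0" or "vit_b_16", e.g.:
#   AVAILABLE_MODELS["resnet50"]  # -> torchvision.models.resnet50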

# Default resize and normalization settings for models
MODEL_DEFAULTS = {
    "default": {"resize": (224, 224), "normalize": (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )},
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
74 "swin_b": {"resize": (224, 224), "normalize": ( | |
75 [0.5, 0.0, 0.5], [0.5, 0.5, 0.5] | |
76 )}, | |
77 "swin_s": {"resize": (224, 224), "normalize": ( | |
78 [0.5, 0.0, 0.5], [0.5, 0.5, 0.5] | |
79 )}, | |
80 "swin_t": {"resize": (224, 224), "normalize": ( | |
81 [0.5, 0.0, 0.5], [0.5, 0.5, 0.5] | |
82 )}, | |
83 "vit_b_16": {"resize": (224, 224), "normalize": ( | |
84 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] | |
85 )}, | |
86 "vit_b_32": {"resize": (224, 224), "normalize": ( | |
87 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] | |
88 )}, | |
89 } | |

for model, settings in MODEL_DEFAULTS.items():
    if "normalize" not in settings:
        settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

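# After this loop every entry carries both "resize" and "normalize"; the
# lookup used later, MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"]),
# therefore yields e.g. for "efficientnet_b3":
#   {"resize": (300, 300),
#    "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])}
# and falls back to the 224x224 ImageNet defaults for unlisted models.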

# Custom transform classes
class CLAHETransform:
    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit,
            tileGridSize=tile_grid_size
        )

    def __call__(self, img):
        img = np.array(img.convert("L"))
        img = self.clahe.apply(img)
        return Image.fromarray(img).convert("RGB")


class CannyTransform:
    def __init__(self, threshold1=100, threshold2=200):
        self.threshold1 = threshold1
        self.threshold2 = threshold2

    def __call__(self, img):
        img = np.array(img.convert("L"))
        edges = cv2.Canny(img, self.threshold1, self.threshold2)
        return Image.fromarray(edges).convert("RGB")


class RGBAtoRGBTransform:
    def __call__(self, img):
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img).convert("RGB")
        else:
            img = img.convert("RGB")
        return img

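# The three transform classes above are plain PIL-to-PIL callables, so they
# can be composed with standard torchvision transforms. A minimal sketch (the
# settings and file name are illustrative only):
#   preview = transforms.Compose([
#       CLAHETransform(clip_limit=2.0),
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#   ])
#   tensor = preview(Image.open("some_image.png"))
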

def get_image_files_from_zip(zip_file):
    """Returns a list of image file names in the ZIP file."""
    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            file_list = [
                f for f in zip_ref.namelist() if f.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
                )
            ]
        return file_list
    except zipfile.BadZipFile as exc:
        raise RuntimeError("Invalid ZIP file.") from exc
    except Exception as exc:
        raise RuntimeError("Error reading ZIP file.") from exc

def load_model(model_name, device):
    """Loads a specified torchvision model and
    modifies it for feature extraction."""
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Available models: {list(AVAILABLE_MODELS.keys())}"
        )
    try:
        if "weights" in inspect.signature(
                AVAILABLE_MODELS[model_name]).parameters:
            model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
        else:
            model = AVAILABLE_MODELS[model_name]().to(device)
        logging.info("Model loaded")
    except Exception as e:
        logging.error(f"Failed to load model {model_name}: {e}")
        raise

    # Strip the classification layer so the model outputs pooled features.
    if hasattr(model, "fc"):
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):
        model.classifier = torch.nn.Identity()
    elif hasattr(model, "heads"):
        # torchvision ViT models expose their classifier as ``heads``.
        model.heads = torch.nn.Identity()
    elif hasattr(model, "head"):
        model.head = torch.nn.Identity()

    model.eval()
    return model

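# Note: with the classification layer replaced by Identity above, load_model
# returns the backbone's pooled features; their length depends on the model
# (e.g. 2048 values for resnet50 and 768 for vit_b_16 with current
# torchvision weights).
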
def write_csv(output_csv, list_embeddings, ludwig_format=False):
    """Writes embeddings to a CSV file, optionally in Ludwig format."""
    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        if list_embeddings:
            if ludwig_format:
                header = ["sample_name", "embedding"]
                formatted_embeddings = []
                for embedding in list_embeddings:
                    sample_name = embedding[0]
                    vector = embedding[1:]
                    embedding_str = " ".join(map(str, vector))
                    formatted_embeddings.append([sample_name, embedding_str])
                csv_writer.writerow(header)
                csv_writer.writerows(formatted_embeddings)
                logging.info("CSV created in Ludwig format")
            else:
                header = ["sample_name"] + [f"vector{i + 1}" for i in range(
                    len(list_embeddings[0]) - 1
                )]
                csv_writer.writerow(header)
                csv_writer.writerows(list_embeddings)
                logging.info("CSV created")
        else:
            csv_writer.writerow(["sample_name"] if not ludwig_format
                                else ["sample_name", "embedding"])
            logging.info("No valid images found. Empty CSV created.")

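# Resulting CSV layouts (sample names and values below are illustrative):
#   default format:   sample_name,vector1,vector2,...
#                     img_001.png,0.12,-0.56,...
#   ludwig_format:    sample_name,embedding
#                     img_001.png,0.12 -0.56 ...
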
def extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type="rgb"):
    """Extracts embeddings from images
    using batch processing or sequential fallback."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    resize = model_settings["resize"]
    normalize = model_settings.get("normalize", (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    ))

    # Define transform pipeline
    if transform_type == "grayscale":
        initial_transform = transforms.Grayscale(num_output_channels=3)
    elif transform_type == "clahe":
        initial_transform = CLAHETransform()
    elif transform_type == "edges":
        initial_transform = CannyTransform()
    elif transform_type == "rgba_to_rgb":
        initial_transform = RGBAtoRGBTransform()
    else:
        initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))

    transform_list = [initial_transform,
                      transforms.Resize(resize),
                      transforms.ToTensor()]
    if apply_normalization:
        transform_list.append(transforms.Normalize(mean=normalize[0],
                                                   std=normalize[1]))
    transform = transforms.Compose(transform_list)

    class ImageDataset(Dataset):
        def __init__(self, zip_file, file_list, transform=None):
            self.zip_file = zip_file
            self.file_list = file_list
            self.transform = transform

        def __len__(self):
            return len(self.file_list)

        def __getitem__(self, idx):
            with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
                with zip_ref.open(self.file_list[idx]) as file:
                    try:
                        image = Image.open(file)
                        if self.transform:
                            image = self.transform(image)
                        return image, os.path.basename(self.file_list[idx])
                    except Exception as e:
                        logging.warning(
                            "Skipping %s: %s", self.file_list[idx], e
                        )
                        return None, os.path.basename(self.file_list[idx])

    # Custom collate function
    def collate_fn(batch):
        batch = [item for item in batch if item[0] is not None]
        if not batch:
            return None, None
        images, names = zip(*batch)
        return torch.stack(images), names

    list_embeddings = []
    with torch.inference_mode():
        try:
            # Try DataLoader with reduced resource usage
            dataset = ImageDataset(zip_file, file_list, transform=transform)
            dataloader = DataLoader(
                dataset,
                batch_size=16,  # Reduced for lower memory usage
                num_workers=1,  # Reduced to minimize shared memory
                shuffle=False,
                pin_memory=(device.type == "cuda"),
                collate_fn=collate_fn,
            )
            for images, names in dataloader:
                if images is None:
                    continue
                images = images.to(device)
                embeddings = model(images).cpu().numpy()
                for name, embedding in zip(names, embeddings):
                    list_embeddings.append([name] + embedding.tolist())
        except RuntimeError as e:
            logging.warning(
                f"DataLoader failed: {e}. "
                "Falling back to sequential processing."
            )
            # Fallback to sequential processing
            for file in file_list:
                with zipfile.ZipFile(zip_file, "r") as zip_ref:
                    with zip_ref.open(file) as img_file:
                        try:
                            image = Image.open(img_file)
                            image = transform(image)
                            input_tensor = image.unsqueeze(0).to(device)
                            embedding = model(
                                input_tensor
                            ).squeeze().cpu().numpy()
                            list_embeddings.append(
                                [os.path.basename(file)] + embedding.tolist()
                            )
                        except Exception as e:
                            logging.warning("Skipping %s: %s", file, e)

    return list_embeddings

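# main() glues the steps together: extract_embeddings returns rows shaped like
#   ["img_001.png", 0.12, -0.56, ...]
# (file basename followed by the feature values; the name here is
# illustrative), and write_csv writes them out in the chosen layout.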
def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Main entry point for processing the zip file and
    extracting embeddings."""
    file_list = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")

    list_embeddings = extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type
    )
    logging.info("Embeddings extracted")
    write_csv(output_csv, list_embeddings, ludwig_format)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument(
        "--zip_file",
        required=True,
        help="Path to the ZIP file containing images."
    )
    parser.add_argument(
        "--model_name",
        required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model for embedding extraction."
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Whether to apply normalization."
    )
    parser.add_argument(
        "--transform_type",
        required=True,
        help="Image transformation type: grayscale, clahe, edges, "
             "rgba_to_rgb, or rgb."
    )
    parser.add_argument(
        "--output_csv",
        required=True,
        help="Path to the output CSV file"
    )
    parser.add_argument(
        "--ludwig_format",
        action="store_true",
        help="Prepare CSV file in Ludwig input format"
    )

    args = parser.parse_args()
    main(
        args.zip_file,
        args.output_csv,
        args.model_name,
        args.normalize,
        args.transform_type,
        args.ludwig_format
    )