annotate pytorch_embedding.py @ 1:84f96c952c2c draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
| author | goeckslab |
|---|---|
| date | Sun, 09 Nov 2025 19:03:21 +0000 |
| parents | 38333676a029 |
| children |
Revision 0 (38333676a029), goeckslab: planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
"""
This module provides functionality to extract image embeddings using a
specified pretrained model from the torchvision library. It includes
functions to:
- List image files directly from a ZIP file without extraction.
- Apply model-specific preprocessing and transformations.
- Extract embeddings using various models.
- Save the resulting embeddings into a CSV file.
Modules required:
- argparse: For command-line argument parsing.
- os, csv, zipfile: For file handling (ZIP file reading, CSV writing).
- inspect: For inspecting function signatures and models.
- torch, torchvision: For loading and using pretrained models
  to extract embeddings.
- PIL, cv2: For image processing tasks such as resizing, normalization,
  and conversion.
"""

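The first item in the docstring, listing images in a ZIP file without extracting it, can be sketched with the standard `zipfile` module. This is a minimal illustration, not the module's own implementation: the helper name `list_images_in_zip` and the `IMAGE_EXTENSIONS` tuple are hypothetical, and the extensions the real tool accepts are not shown in this section.

```python
import io
import zipfile

# Hypothetical extension set; the module's real list is not shown here.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")


def list_images_in_zip(zip_source):
    """Return names of image members in a ZIP archive without extracting it."""
    with zipfile.ZipFile(zip_source) as archive:
        return [
            name
            for name in archive.namelist()
            # Skip directory entries and anything without an image extension.
            if not name.endswith("/") and name.lower().endswith(IMAGE_EXTENSIONS)
        ]


# Build a small in-memory archive to exercise the helper.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("images/a.png", b"fake")
    archive.writestr("images/b.JPG", b"fake")
    archive.writestr("notes/readme.txt", b"text")
buffer.seek(0)

image_names = list_images_in_zip(buffer)
print(image_names)
```

Only the archive's central directory is read here, so the cost does not depend on the size of the images themselves.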
import argparse
import csv
import logging
import os
import zipfile
from inspect import signature

import cv2
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# GPFM imports
try:
    import requests
    GPFM_AVAILABLE = True
except ImportError as e:
    GPFM_AVAILABLE = False
    logging.warning(f"GPFM dependencies not available: {e}")

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(),  # Console output
        logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a")  # File output
    ],
    # force=True (Python 3.8+) replaces any handler installed implicitly by an
    # earlier logging call, such as the warning in the GPFM import guard.
    force=True,
)

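The GPFM import guard can call `logging.warning` before `logging.basicConfig` runs. The standard library then configures the root logger implicitly, and a later `basicConfig` call without `force=True` is silently ignored. The sketch below demonstrates that behavior in isolation; it is an illustration of the `logging` module's documented semantics, not code from this repository.

```python
import logging

root = logging.getLogger()

# Start from a clean root logger so the demonstration is reproducible.
for handler in list(root.handlers):
    root.removeHandler(handler)
root.setLevel(logging.WARNING)

# Logging through the module-level helper before any configuration makes
# the logging module call basicConfig() implicitly with default settings.
logging.warning("early message, logged before configuration")
handlers_after_warning = len(root.handlers)

# The root logger now has a handler, so this call changes nothing.
logging.basicConfig(level=logging.INFO)
level_without_force = root.level

# force=True (Python 3.8+) removes the existing handlers and reconfigures.
logging.basicConfig(level=logging.INFO, force=True)
level_with_force = root.level

print(handlers_after_warning, level_without_force, level_with_force)
```

An alternative with the same effect is to configure logging before the first statement that can log; which option fits better depends on how the script is launched.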
# Create a cache directory in the current working directory
cache_dir = os.path.join(os.getcwd(), 'hf_cache')
try:
    os.makedirs(cache_dir, exist_ok=True)
    logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}")
except OSError as e:
    logging.error(f"Failed to create cache directory {cache_dir}: {e}")
    raise

# GPFM DinoVisionTransformer Implementation


class DinoVisionTransformer(torch.nn.Module):
    """Simplified DinoVisionTransformer for GPFM."""

    def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = embed_dim
        self.patch_size = patch_size

        # Patch embedding
        self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2

        # Class token
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer blocks (simplified)
        self.blocks = torch.nn.ModuleList([
            torch.nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=0.0,
                batch_first=True
            ) for _ in range(depth)
        ])

        # Layer norm
        self.norm = torch.nn.LayerNorm(embed_dim)

        # Initialize weights
        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # B, embed_dim, H//patch_size, W//patch_size
        x = x.flatten(2).transpose(1, 2)  # B, num_patches, embed_dim

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embeddings
        x = x + self.pos_embed

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Apply layer norm and return class token
        x = self.norm(x)
        return x[:, 0]  # Return class token features


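The shape bookkeeping in `DinoVisionTransformer.forward` follows directly from the constructor defaults. The sketch below reproduces that arithmetic in plain Python so it can be checked without PyTorch; the helper name `vit_token_shapes` and the batch size of 2 are illustrative assumptions, not part of the module.

```python
def vit_token_shapes(img_size=224, patch_size=14, embed_dim=1024, batch=2):
    """Trace tensor shapes through the patch, class-token and output steps."""
    grid = img_size // patch_size   # patches per image side
    num_patches = grid ** 2         # total patches per image
    return {
        # After the strided Conv2d patch embedding
        "patch_grid": (batch, embed_dim, grid, grid),
        # After flatten(2).transpose(1, 2)
        "patch_tokens": (batch, num_patches, embed_dim),
        # After prepending the class token; must match pos_embed's length
        "with_cls": (batch, num_patches + 1, embed_dim),
        # After selecting the class token with x[:, 0]
        "output": (batch, embed_dim),
    }


shapes = vit_token_shapes()
for step, shape in shapes.items():
    print(step, shape)
```

Because `pos_embed` is created for exactly `num_patches + 1` tokens, inputs of a different spatial size than `img_size` would fail at the position-embedding addition unless they are resized first.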
# GPFM Model Implementation
class GPFMModel(torch.nn.Module):
    """GPFM (Generalizable Pathology Foundation Model) implementation."""

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.model = None
        self.transformer = None
        self.embed_dim = 1024  # GPFM uses 1024-dimensional embeddings
        self._load_model()

    def _download_weights(self, url, filepath):
        """Download GPFM weights from the official repository."""
        if os.path.exists(filepath):
            logging.info(f"GPFM weights already exist at {filepath}")
            return True

        logging.info(f"Downloading GPFM weights from {url}")
        try:
            response = requests.get(url, stream=True, timeout=300)
            response.raise_for_status()

            os.makedirs(os.path.dirname(filepath), exist_ok=True)

            # Get file size for progress tracking
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            log_step = 1024 * 1024 * 10  # Log every 10MB
            next_log = log_step

            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        # Compare against a threshold: chunks may be shorter than
                        # 8192 bytes, so an exact-multiple check can miss every log.
                        if total_size > 0 and downloaded >= next_log:
                            progress = (downloaded / total_size) * 100
                            logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)")
                            next_log += log_step

            logging.info(f"GPFM weights downloaded successfully to {filepath}")
            return True

        except Exception as e:
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
167 logging.error(f"Failed to download GPFM weights: {e}") |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
168 if os.path.exists(filepath): |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
169 os.remove(filepath) # Clean up partial download |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
170 return False |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
171 |
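The progress check `downloaded % (1024 * 1024 * 10) == 0` only logs when the running byte count lands exactly on a 10 MB boundary. A minimal sketch of a threshold-based alternative is shown below. The names `log_progress_points` and `LOG_INTERVAL`, and the list of byte chunks standing in for `response.iter_content(...)`, are hypothetical and not part of `pytorch_embedding.py`.

```python
from typing import Iterable, List

LOG_INTERVAL = 10 * 1024 * 1024  # hypothetical: report roughly every 10 MB


def log_progress_points(chunks: Iterable[bytes], interval: int = LOG_INTERVAL) -> List[int]:
    """Return the byte counts at which a progress message would be emitted.

    Instead of testing `downloaded % interval == 0`, keep a moving threshold
    and report whenever the running total crosses it, so chunk sizes need
    not divide the interval evenly.
    """
    downloaded = 0
    next_report = interval
    reported = []
    for chunk in chunks:
        if not chunk:
            continue  # skip empty chunks, as the original loop does
        downloaded += len(chunk)
        if downloaded >= next_report:
            reported.append(downloaded)
            # Advance past the current total so each interval reports once.
            while next_report <= downloaded:
                next_report += interval
    return reported


# Irregular chunk sizes: the running total never equals a multiple of 100.
points = log_progress_points([b"x" * 70, b"x" * 70, b"", b"x" * 70, b"x" * 95], interval=100)
print(points)
```

In a real download loop the `reported.append(...)` line would be replaced by the `logging.info(...)` call.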
    def _load_model(self):
        """Load GPFM model with pretrained weights."""
        try:
            # Create models directory
            models_dir = os.path.join(cache_dir, 'gpfm_models')
            os.makedirs(models_dir, exist_ok=True)

            # GPFM weights URL from official repository
            weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth"
            weights_path = os.path.join(models_dir, 'GPFM.pth')

            # Create GPFM DinoVisionTransformer architecture
            self.model = DinoVisionTransformer(
                img_size=224,
                patch_size=14,
                embed_dim=1024,
                depth=24,
                num_heads=16
            )

            # Try to download and load GPFM weights
            weights_loaded = False
            if self._download_weights(weights_url, weights_path):
                try:
                    logging.info("Loading GPFM pretrained weights...")
                    checkpoint = torch.load(weights_path, map_location=self.device)

                    # Extract teacher model weights (GPFM format)
                    if 'teacher' in checkpoint:
                        state_dict = checkpoint['teacher']
                        logging.info("Found 'teacher' key in checkpoint")
                    else:
                        state_dict = checkpoint
                        logging.info("Using checkpoint directly")

                    # Rename keys to match our simplified architecture.
                    # Keys that match no branch below (e.g. attention and MLP
                    # weights) are not copied into new_state_dict.
                    new_state_dict = {}
                    for k, v in state_dict.items():
                        # Remove 'backbone.' prefix if present
                        if k.startswith('backbone.'):
                            k = k[9:]  # Remove 'backbone.'

                        # Map GPFM keys to our simplified architecture
                        if k in ['cls_token', 'pos_embed']:
                            new_state_dict[k] = v
                        elif k.startswith('patch_embed.proj.'):
                            # Map patch embedding
                            new_k = k.replace('patch_embed.proj.', 'patch_embed.')
                            new_state_dict[new_k] = v
                        elif k.startswith('blocks.') and 'norm' in k:
                            # Map layer norms
                            if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'):
                                # Skip intermediate norms for simplified model
                                continue
                            elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'):
                                continue
                        elif k == 'norm.weight' or k == 'norm.bias':
                            new_state_dict[k] = v

                    # Load compatible weights
                    missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
                    if missing_keys:
                        logging.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
                    if unexpected_keys:
                        logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5

                    logging.info("GPFM pretrained weights loaded successfully")
                    weights_loaded = True

                except Exception as e:
                    logging.warning(f"Could not load GPFM weights: {e}")

            if not weights_loaded:
                logging.info("Using randomly initialized GPFM architecture (no pretrained weights)")

            self.model = self.model.to(self.device)
            self.model.eval()

            # GPFM preprocessing (based on official repository)
            self.transformer = transforms.Compose([
                transforms.Lambda(lambda x: x.convert("RGB")),  # Ensure RGB format
                transforms.Resize((224, 224)),  # GPFM uses 224x224 (not 512x512 for features)
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                    std=[0.229, 0.224, 0.225]
                )
            ])

            logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})")

        except Exception as e:
            logging.error(f"Failed to initialize GPFM model: {e}")
            raise

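The key-renaming rules in `_load_model` can be read as a small pure function. The sketch below is an illustration only: it assumes the final `norm.weight` / `norm.bias` branch sits at the same level as the `blocks.` branch (the indentation is not recoverable from this listing), and `remap_gpfm_keys` plus the sample keys are hypothetical, with plain integers standing in for tensors.

```python
from typing import Any, Dict


def remap_gpfm_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror the renaming rules of the _load_model loop on a plain dict."""
    prefix = "backbone."
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # same effect as k[9:]
        if key in ("cls_token", "pos_embed"):
            remapped[key] = value
        elif key.startswith("patch_embed.proj."):
            remapped[key.replace("patch_embed.proj.", "patch_embed.")] = value
        elif key in ("norm.weight", "norm.bias"):
            remapped[key] = value
        # Everything else (per-block norms, attention, MLP weights) is dropped.
    return remapped


sample = {
    "backbone.cls_token": 1,
    "backbone.patch_embed.proj.weight": 2,
    "backbone.blocks.0.norm1.weight": 3,
    "backbone.blocks.0.attn.qkv.weight": 4,
    "norm.bias": 5,
}
kept = remap_gpfm_keys(sample)
print(sorted(kept))
```

Under these rules only the class token, positional embedding, patch projection and final norm are carried over, which is presumably why `load_state_dict` is called with `strict=False`.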
    def forward(self, x):
        """Forward pass through GPFM model."""
        with torch.no_grad():
            return self.model(x)

    def get_transformer(self, apply_normalization=True):
        """Get the preprocessing transformer for GPFM."""
        if apply_normalization:
            return self.transformer
        else:
            # Return transformer without normalization
            return transforms.Compose([
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])


|
0
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
# Available models from torchvision
AVAILABLE_MODELS = {
    name: getattr(models, name)
    for name in dir(models)
    if callable(
        getattr(models, name)
    ) and "weights" in signature(getattr(models, name)).parameters
}

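The `AVAILABLE_MODELS` comprehension keeps every callable in `models` whose signature has a `weights` parameter. The sketch below reproduces that filter on a stand-in namespace so it runs without torchvision. `find_weighted_builders` and `fake_models` are hypothetical names, and the `try`/`except` around `signature` is an added safeguard that the original comprehension does not have.

```python
import types
from inspect import signature


def find_weighted_builders(namespace) -> dict:
    """Collect callables in `namespace` whose signature has a `weights` parameter."""
    found = {}
    for name in dir(namespace):
        obj = getattr(namespace, name)
        if not callable(obj):
            continue
        try:
            params = signature(obj).parameters
        except (TypeError, ValueError):
            continue  # some callables expose no introspectable signature
        if "weights" in params:
            found[name] = obj
    return found


# Hypothetical stand-in for `torchvision.models`.
fake_models = types.ModuleType("fake_models")
fake_models.resnet_like = lambda weights=None, progress=True: "resnet_like"
fake_models.helper = lambda name: name  # callable, but no `weights` parameter
fake_models.VERSION = "0.0"  # not callable

builders = find_weighted_builders(fake_models)
print(sorted(builders))
```

Only `resnet_like` passes the filter here; `helper` is callable but has no `weights` parameter, and `VERSION` is not callable.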
|
1
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
# Add GPFM model if available
if GPFM_AVAILABLE:
    AVAILABLE_MODELS['gpfm'] = GPFMModel

|
0
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
# Default resize and normalization settings for models
MODEL_DEFAULTS = {
    "default": {"resize": (224, 224), "normalize": (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )},
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
    "swin_b": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "swin_s": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "swin_t": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "vit_b_16": {"resize": (224, 224), "normalize": (
        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    )},
|
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
323 "vit_b_32": {"resize": (224, 224), "normalize": ( |
|
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
324 [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] |
|
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
325 )}, |
|
1
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
326 "gpfm": { |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
327 "resize": (224, 224), |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
328 "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) |
|
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
329 }, |
|
0
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
330 } |
|
for model, settings in MODEL_DEFAULTS.items():
    if "normalize" not in settings:
        settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
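
The default-filling loop above can be sketched in isolation. This is a minimal illustration, assuming a trimmed-down stand-in dictionary (`model_defaults` and `IMAGENET_NORM` are hypothetical names, not part of the module); it uses `dict.setdefault`, which writes only when the key is missing and so behaves like the `if "normalize" not in settings` check.

```python
# Hypothetical stand-in for MODEL_DEFAULTS, for illustration only.
# Tuples are used so the shared default cannot be mutated by accident.
IMAGENET_NORM = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

model_defaults = {
    "efficientnet_b1": {"resize": (240, 240)},
    "vit_b_16": {
        "resize": (224, 224),
        "normalize": ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    },
}

for settings in model_defaults.values():
    # Entries that already define "normalize" are left untouched.
    settings.setdefault("normalize", IMAGENET_NORM)
```

After the loop every entry carries a `normalize` pair, so later lookups do not need their own fallback.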
|

# Custom transform classes
class CLAHETransform:
    """Applies CLAHE to the grayscale image and returns an RGB image."""

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit,
            tileGridSize=tile_grid_size
        )

    def __call__(self, img):
        img = np.array(img.convert("L"))
        img = self.clahe.apply(img)
        return Image.fromarray(img).convert("RGB")
|

class CannyTransform:
    """Runs Canny edge detection on the grayscale image and returns an RGB image."""

    def __init__(self, threshold1=100, threshold2=200):
        self.threshold1 = threshold1
        self.threshold2 = threshold2

    def __call__(self, img):
        img = np.array(img.convert("L"))
        edges = cv2.Canny(img, self.threshold1, self.threshold2)
        return Image.fromarray(edges).convert("RGB")
|

class RGBAtoRGBTransform:
    """Composites RGBA images onto a white background; converts other modes to RGB."""

    def __call__(self, img):
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img).convert("RGB")
        else:
            img = img.convert("RGB")
        return img
|

def get_image_files_from_zip(zip_file):
    """Returns a list of image file names in the ZIP file."""
    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            file_list = [
                f for f in zip_ref.namelist() if f.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
                )
            ]
        return file_list
    except zipfile.BadZipFile as exc:
        raise RuntimeError("Invalid ZIP file.") from exc
    except Exception as exc:
        raise RuntimeError("Error reading ZIP file.") from exc
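
Listing archive members without extracting them can be tried on an in-memory archive. The sketch below is an assumption-laden illustration, not the module's own code: `list_images`, `IMAGE_EXTENSIONS`, and the member names are hypothetical, and only the standard-library `zipfile` and `io` modules are used.

```python
import io
import zipfile

# Same extensions as the filter above; the constant name is hypothetical.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")


def list_images(zip_source):
    """Return member names whose extension marks them as images."""
    with zipfile.ZipFile(zip_source, "r") as zip_ref:
        return [
            name for name in zip_ref.namelist()
            if name.lower().endswith(IMAGE_EXTENSIONS)
        ]


# Build a small archive in memory with made-up member names.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("a/cell_01.PNG", b"")
    zf.writestr("notes.txt", b"")
    zf.writestr("b/cell_02.jpeg", b"")

image_names = list_images(buffer)
```

Because the names are lowercased before the suffix check, `cell_01.PNG` is kept while `notes.txt` is skipped; nothing is written to disk.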
|

def load_model(model_name, device):
    """Loads a specified torchvision model and
    modifies it for feature extraction."""
    if model_name not in AVAILABLE_MODELS:
        # Adjacent f-strings avoid embedding source indentation in the message.
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Available models: {list(AVAILABLE_MODELS.keys())}")
    try:
        # Special handling for GPFM
        if model_name == "gpfm":
            model = AVAILABLE_MODELS[model_name](device=device)
            logging.info("GPFM model loaded")
            return model

        # Standard torchvision models
        if "weights" in signature(
                AVAILABLE_MODELS[model_name]).parameters:
            model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
        else:
            model = AVAILABLE_MODELS[model_name]().to(device)
        logging.info("Model loaded")
    except Exception as e:
        logging.error(f"Failed to load model {model_name}: {e}")
        raise

    # Replace the classification layer with an identity so the model returns features.
    if hasattr(model, "fc"):
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):
        model.classifier = torch.nn.Identity()
    elif hasattr(model, "head"):
        model.head = torch.nn.Identity()

    model.eval()
    return model
|

def write_csv(output_csv, list_embeddings, ludwig_format=False):
    """Writes embeddings to a CSV file, optionally in Ludwig format."""
    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        if list_embeddings:
            if ludwig_format:
                header = ["sample_name", "embedding"]
                formatted_embeddings = []
                for embedding in list_embeddings:
                    sample_name = embedding[0]
                    vector = embedding[1:]
                    embedding_str = " ".join(map(str, vector))
                    formatted_embeddings.append([sample_name, embedding_str])
                csv_writer.writerow(header)
                csv_writer.writerows(formatted_embeddings)
                logging.info("CSV created in Ludwig format")
            else:
                header = ["sample_name"] + [f"vector{i + 1}" for i in range(
                    len(list_embeddings[0]) - 1
                )]
                csv_writer.writerow(header)
                csv_writer.writerows(list_embeddings)
                logging.info("CSV created")
        else:
            csv_writer.writerow(["sample_name"] if not ludwig_format
                                else ["sample_name", "embedding"])
            logging.info("No valid images found. Empty CSV created.")
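
The two output layouts differ only in how the vector is laid out: one column per dimension, or a single space-separated `embedding` column. The following is a minimal sketch under stated assumptions: `to_csv_text` and the sample rows are hypothetical, the text is written to an in-memory buffer instead of a file, and the rows are assumed to be non-empty.

```python
import csv
import io

# Hypothetical rows: a sample name followed by a 3-dimensional embedding.
rows = [["img_a.png", 0.1, 0.2, 0.3], ["img_b.png", 0.4, 0.5, 0.6]]


def to_csv_text(rows, ludwig_format=False):
    """Render rows as CSV text in the wide or the Ludwig-style layout."""
    out = io.StringIO(newline="")
    writer = csv.writer(out)
    if ludwig_format:
        # One "embedding" column holding the space-separated vector.
        writer.writerow(["sample_name", "embedding"])
        for name, *vector in rows:
            writer.writerow([name, " ".join(map(str, vector))])
    else:
        # One "vectorN" column per embedding dimension.
        width = len(rows[0]) - 1
        writer.writerow(
            ["sample_name"] + [f"vector{i + 1}" for i in range(width)]
        )
        writer.writerows(rows)
    return out.getvalue()


wide_text = to_csv_text(rows)
ludwig_text = to_csv_text(rows, ludwig_format=True)
```

In the wide layout the header grows with the embedding size, while the Ludwig-style layout always has two columns regardless of dimensionality.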
|

def extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type="rgb"):
    """Extracts embeddings from images
    using batch processing or sequential fallback."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    resize = model_settings["resize"]
    normalize = model_settings.get("normalize", (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    ))

    # Define transform pipeline
    if transform_type == "grayscale":
        initial_transform = transforms.Grayscale(num_output_channels=3)
    elif transform_type == "clahe":
        initial_transform = CLAHETransform()
    elif transform_type == "edges":
        initial_transform = CannyTransform()
    elif transform_type == "rgba_to_rgb":
        initial_transform = RGBAtoRGBTransform()
    else:
        initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))
|
38333676a029
planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff
changeset
|
481 |
|
1
84f96c952c2c
planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents:
0
diff
changeset
|
    # Handle GPFM separately as it has its own preprocessing
    if model_name == "gpfm":
        # For GPFM, combine initial transform with GPFM's custom transformer
        if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]:
            transform = transforms.Compose([
                initial_transform,
                model.get_transformer(apply_normalization=apply_normalization)
            ])
        else:
            transform = model.get_transformer(apply_normalization=apply_normalization)
    else:
        # Standard torchvision models
        transform_list = [initial_transform,
                          transforms.Resize(resize),
                          transforms.ToTensor()]
        if apply_normalization:
            transform_list.append(transforms.Normalize(mean=normalize[0],
                                                       std=normalize[1]))
        transform = transforms.Compose(transform_list)

    class ImageDataset(Dataset):
        def __init__(self, zip_file, file_list, transform=None):
            self.zip_file = zip_file
            self.file_list = file_list
            self.transform = transform

        def __len__(self):
            return len(self.file_list)

        def __getitem__(self, idx):
            with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
                with zip_ref.open(self.file_list[idx]) as file:
                    try:
                        image = Image.open(file)
                        if self.transform:
                            image = self.transform(image)
                        return image, os.path.basename(self.file_list[idx])
                    except Exception as e:
                        logging.warning(
                            "Skipping %s: %s", self.file_list[idx], e
                        )
                        return None, os.path.basename(self.file_list[idx])

    # Custom collate function
    def collate_fn(batch):
        batch = [item for item in batch if item[0] is not None]
        if not batch:
            return None, None
        images, names = zip(*batch)
        return torch.stack(images), names

    list_embeddings = []
    with torch.inference_mode():
        try:
            # Try DataLoader with reduced resource usage
            dataset = ImageDataset(zip_file, file_list, transform=transform)
            dataloader = DataLoader(
                dataset,
                batch_size=16,  # Reduced for lower memory usage
                num_workers=0,  # Fix multiprocessing issues with GPFM
                shuffle=False,
                # Compare the device type; device is a torch.device object,
                # not a string.
                pin_memory=device.type == "cuda",
                collate_fn=collate_fn,
            )
            for images, names in dataloader:
                if images is None:
                    continue
                images = images.to(device)
                embeddings = model(images).cpu().numpy()
                for name, embedding in zip(names, embeddings):
                    list_embeddings.append([name] + embedding.tolist())
        except RuntimeError as e:
            logging.warning(
                "DataLoader failed: %s. "
                "Falling back to sequential processing.", e
            )
            # Discard partial results so the fallback does not emit
            # duplicate rows, then process the files one at a time.
            list_embeddings.clear()
            for file in file_list:
                with zipfile.ZipFile(zip_file, "r") as zip_ref:
                    with zip_ref.open(file) as img_file:
                        try:
                            image = Image.open(img_file)
                            image = transform(image)
                            input_tensor = image.unsqueeze(0).to(device)
                            embedding = model(
                                input_tensor
                            ).squeeze().cpu().numpy()
                            list_embeddings.append(
                                [os.path.basename(file)] + embedding.tolist()
                            )
                        except Exception as e:
                            logging.warning("Skipping %s: %s", file, e)

    return list_embeddings


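The collate function used by the DataLoader drops samples whose image failed to load before the batch is stacked. A minimal, torch-free sketch of that filtering step follows; `filter_failed` is a hypothetical name, and plain lists stand in for the tensors that `torch.stack` combines in the real code.

```python
def filter_failed(batch):
    """Drop (image, name) pairs whose image is None.

    Returns (None, None) when nothing survives, mirroring the
    empty-batch convention that the embedding loop checks for.
    """
    kept = [item for item in batch if item[0] is not None]
    if not kept:
        return None, None
    images, names = zip(*kept)
    # The real collate function calls torch.stack(images) here; this
    # sketch returns a plain list so it runs without PyTorch.
    return list(images), names


batch = [([0.1, 0.2], "a.png"), (None, "broken.png"), ([0.3, 0.4], "b.png")]
images, names = filter_failed(batch)
print(images, names)
print(filter_failed([(None, "broken.png")]))
```

Returning `(None, None)` for an all-failed batch lets the caller skip it with a single `if images is None` check instead of handling an empty stack.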
def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Main entry point for processing the zip file and
    extracting embeddings."""
    file_list = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")

    list_embeddings = extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type
    )
    logging.info("Embeddings extracted")
    write_csv(output_csv, list_embeddings, ludwig_format)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument(
        "--zip_file",
        required=True,
        help="Path to the ZIP file containing images."
    )
    parser.add_argument(
        "--model_name",
        required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model for embedding extraction."
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Whether to apply normalization."
    )
    parser.add_argument(
        "--transform_type",
        required=True,
        help="Image transformation type."
    )
    parser.add_argument(
        "--output_csv",
        required=True,
        help="Path to the output CSV file."
    )
    parser.add_argument(
        "--ludwig_format",
        action="store_true",
        help="Prepare CSV file in Ludwig input format."
    )

    args = parser.parse_args()
    main(
        args.zip_file,
        args.output_csv,
        args.model_name,
        args.normalize,
        args.transform_type,
        args.ludwig_format
    )
