annotate pytorch_embedding.py @ 1:84f96c952c2c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
author goeckslab
date Sun, 09 Nov 2025 19:03:21 +0000
parents 38333676a029
children
rev   line source
0
38333676a029 planemo upload for repository https://github.com/goeckslab/gleam.git commit f57ec1ad637e8299db265ee08be0fa9d4d829b93
goeckslab
parents:
diff changeset
"""
This module provides functionality to extract image embeddings using a
specified pretrained model from the torchvision library. It includes
functions to:
- List image files directly from a ZIP file without extraction.
- Apply model-specific preprocessing and transformations.
- Extract embeddings using various models.
- Save the resulting embeddings into a CSV file.

Modules required:
- argparse: For command-line argument parsing.
- os, csv, zipfile: For file handling (ZIP file reading, CSV writing).
- inspect: For inspecting function signatures and models.
- torch, torchvision: For loading and using pretrained models
  to extract embeddings.
- PIL, cv2: For image processing tasks such as resizing, normalization,
  and conversion.
"""

import argparse
import csv
import logging
import os
import zipfile
from inspect import signature

import cv2
import numpy as np
import torch
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

1
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
# GPFM imports
try:
    import requests
    GPFM_AVAILABLE = True
except ImportError as e:
    GPFM_AVAILABLE = False
    # Note: this warning is emitted before logging.basicConfig() runs below,
    # so it goes through the root logger's implicit default configuration.
    logging.warning(f"GPFM dependencies not available: {e}")

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(),  # Console output
        logging.FileHandler("/tmp/ludwig_embeddings.log", mode="a")  # File output
    ]
)

# Create a cache directory in the current working directory
cache_dir = os.path.join(os.getcwd(), 'hf_cache')
try:
    os.makedirs(cache_dir, exist_ok=True)
    logging.info(f"Cache directory created: {cache_dir}, writable: {os.access(cache_dir, os.W_OK)}")
except OSError as e:
    logging.error(f"Failed to create cache directory {cache_dir}: {e}")
    raise

# GPFM DinoVisionTransformer Implementation


class DinoVisionTransformer(torch.nn.Module):
    """Simplified DinoVisionTransformer for GPFM."""

    def __init__(self, img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = embed_dim
        self.patch_size = patch_size

        # Patch embedding
        self.patch_embed = torch.nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2

        # Class token
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer blocks (simplified)
        self.blocks = torch.nn.ModuleList([
            torch.nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=0.0,
                batch_first=True
            ) for _ in range(depth)
        ])

        # Layer norm
        self.norm = torch.nn.LayerNorm(embed_dim)

        # Initialize weights
        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # B, embed_dim, H//patch_size, W//patch_size
        x = x.flatten(2).transpose(1, 2)  # B, num_patches, embed_dim

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embeddings
        x = x + self.pos_embed

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Apply layer norm and return class token
        x = self.norm(x)
        return x[:, 0]  # Return class token features


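# Illustrative aside (not part of the original module): a minimal sketch of
# the token-count arithmetic behind pos_embed in DinoVisionTransformer above.
# Because pos_embed is created with a fixed length of num_patches + 1, the
# addition in forward() only works when the input yields that many tokens.
# The helper names vit_sequence_length and input_matches_pos_embed are
# hypothetical and chosen only for this example; the defaults mirror the
# constructor defaults shown above.

```python
def vit_sequence_length(img_size: int = 224, patch_size: int = 14) -> int:
    """Tokens for a square input: one per patch plus one class token."""
    patches_per_side = img_size // patch_size
    return patches_per_side ** 2 + 1


def input_matches_pos_embed(height: int, width: int,
                            img_size: int = 224, patch_size: int = 14) -> bool:
    """True when an input of (height, width) yields the token count that a
    pos_embed built for img_size expects."""
    # A stride-patch_size convolution with no padding floors each dimension.
    tokens = (height // patch_size) * (width // patch_size) + 1
    return tokens == vit_sequence_length(img_size, patch_size)


print(vit_sequence_length())              # 257
print(input_matches_pos_embed(224, 224))  # True
print(input_matches_pos_embed(256, 256))  # False
```

# With the defaults, a 224x224 image gives 16 x 16 = 256 patches plus the
# class token, so 257 tokens; a 256x256 image would give 325 and would not
# line up with the position embeddings.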
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
124 # GPFM Model Implementation
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
125 class GPFMModel(torch.nn.Module):
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
126 """GPFM (Generalizable Pathology Foundation Model) implementation."""
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
127
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
128 def __init__(self, device='cpu'):
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
129 super().__init__()
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
130 self.device = device
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
131 self.model = None
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
132 self.transformer = None
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
133 self.embed_dim = 1024 # GPFM uses 1024-dimensional embeddings
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
134 self._load_model()
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
135
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
136 def _download_weights(self, url, filepath):
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
137 """Download GPFM weights from the official repository."""
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
138 if os.path.exists(filepath):
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
139 logging.info(f"GPFM weights already exist at {filepath}")
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
140 return True
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
141
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
142 logging.info(f"Downloading GPFM weights from {url}")
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
143 try:
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
144 response = requests.get(url, stream=True, timeout=300)
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
145 response.raise_for_status()
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
146
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
147 os.makedirs(os.path.dirname(filepath), exist_ok=True)
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
148
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
149 # Get file size for progress tracking
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
150 total_size = int(response.headers.get('content-length', 0))
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
151 downloaded = 0

            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 0:
                            progress = (downloaded / total_size) * 100
                            # Log each time a 10MB boundary is crossed. An exact
                            # modulo test only fires when a chunk happens to end
                            # precisely on the boundary.
                            log_step = 1024 * 1024 * 10
                            if downloaded // log_step > (downloaded - len(chunk)) // log_step:
                                logging.info(f"Downloaded {downloaded // (1024 * 1024)}MB / {total_size // (1024 * 1024)}MB ({progress:.1f}%)")

            logging.info(f"GPFM weights downloaded successfully to {filepath}")
            return True

        except Exception as e:
            logging.error(f"Failed to download GPFM weights: {e}")
            if os.path.exists(filepath):
                os.remove(filepath)  # Clean up partial download
            return False

    def _load_model(self):
        """Load GPFM model with pretrained weights."""
        try:
            # Create models directory
            models_dir = os.path.join(cache_dir, 'gpfm_models')
            os.makedirs(models_dir, exist_ok=True)

            # GPFM weights URL from official repository
            weights_url = "https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth"
            weights_path = os.path.join(models_dir, 'GPFM.pth')

            # Create GPFM DinoVisionTransformer architecture
            self.model = DinoVisionTransformer(
                img_size=224,
                patch_size=14,
                embed_dim=1024,
                depth=24,
                num_heads=16
            )

            # Try to download and load GPFM weights
            weights_loaded = False
            if self._download_weights(weights_url, weights_path):
                try:
                    logging.info("Loading GPFM pretrained weights...")
                    checkpoint = torch.load(weights_path, map_location=self.device)

                    # Extract teacher model weights (GPFM format)
                    if 'teacher' in checkpoint:
                        state_dict = checkpoint['teacher']
                        logging.info("Found 'teacher' key in checkpoint")
                    else:
                        state_dict = checkpoint
                        logging.info("Using checkpoint directly")

                    # Rename keys to match our simplified architecture
                    new_state_dict = {}
                    for k, v in state_dict.items():
                        # Remove 'backbone.' prefix if present
                        if k.startswith('backbone.'):
                            k = k[9:]  # Remove 'backbone.'

                        # Map GPFM keys to our simplified architecture
                        if k in ['cls_token', 'pos_embed']:
                            new_state_dict[k] = v
                        elif k.startswith('patch_embed.proj.'):
                            # Map patch embedding
                            new_k = k.replace('patch_embed.proj.', 'patch_embed.')
                            new_state_dict[new_k] = v
                        elif k.startswith('blocks.') and 'norm' in k:
                            # Map layer norms
                            if k.endswith('.norm1.weight') or k.endswith('.norm1.bias'):
                                # Skip intermediate norms for simplified model
                                continue
                            elif k.endswith('.norm2.weight') or k.endswith('.norm2.bias'):
                                continue
                        elif k == 'norm.weight' or k == 'norm.bias':
                            new_state_dict[k] = v

                    # Load compatible weights
                    missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
                    if missing_keys:
                        logging.warning(f"Missing keys: {missing_keys[:5]}...")  # Show first 5
                    if unexpected_keys:
                        logging.warning(f"Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5

                    logging.info("GPFM pretrained weights loaded successfully")
                    weights_loaded = True

                except Exception as e:
                    logging.warning(f"Could not load GPFM weights: {e}")

            if not weights_loaded:
                logging.info("Using randomly initialized GPFM architecture (no pretrained weights)")

            self.model = self.model.to(self.device)
            self.model.eval()

            # GPFM preprocessing (based on official repository)
            self.transformer = transforms.Compose([
                transforms.Lambda(lambda x: x.convert("RGB")),  # Ensure RGB format
                transforms.Resize((224, 224)),  # GPFM uses 224x224 (not 512x512 for features)
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                    std=[0.229, 0.224, 0.225]
                )
            ])

            logging.info(f"GPFM model initialized successfully (embed_dim: {self.embed_dim})")

        except Exception as e:
            logging.error(f"Failed to initialize GPFM model: {e}")
            raise

    def forward(self, x):
        """Forward pass through GPFM model."""
        with torch.no_grad():
            return self.model(x)

    def get_transformer(self, apply_normalization=True):
        """Get the preprocessing transformer for GPFM."""
        if apply_normalization:
            return self.transformer
        else:
            # Return transformer without normalization
            return transforms.Compose([
                transforms.Lambda(lambda x: x.convert("RGB")),
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])


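The download loop in `_download_weights` reports progress in 10MB steps while data arrives in chunks of up to 8192 bytes, and network chunks are not guaranteed to be full-sized. A minimal sketch of a boundary-crossing check that does not rely on a chunk ending exactly on a multiple of the step might look like the following. The helper names `crossed_boundary` and `log_points`, and the sample chunk sizes, are assumptions made for illustration and are not part of this module.

```python
def crossed_boundary(previous_total, new_total, step):
    """Return True if the running total passed a multiple of `step`."""
    return new_total // step > previous_total // step


def log_points(chunk_sizes, step):
    """Yield the running totals at which a progress message would be emitted."""
    total = 0
    for size in chunk_sizes:
        previous = total
        total += size
        if crossed_boundary(previous, total, step):
            yield total


# Hypothetical, deliberately uneven chunk sizes (bytes) with a step of 10.
points = list(log_points([4, 4, 3, 5, 4], step=10))
print(points)
```

With these sample sizes no running total is ever an exact multiple of the step until the last chunk, yet a message is still due each time a boundary is passed.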
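The key renaming performed in `_load_model` can be illustrated on a plain dictionary, without PyTorch. This is a minimal sketch only: `remap_keys` is a hypothetical helper, the sample keys are invented stand-ins for checkpoint entries, and it covers just the `backbone.` prefix removal and the `patch_embed.proj.` renaming, not the full mapping used in this module.

```python
def remap_keys(state_dict, prefix="backbone."):
    """Strip an optional prefix and rename patch-embedding keys."""
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # Slice by the prefix length instead of a hard-coded number.
            key = key[len(prefix):]
        if key.startswith("patch_embed.proj."):
            key = key.replace("patch_embed.proj.", "patch_embed.", 1)
        remapped[key] = value
    return remapped


# Invented sample entries; real checkpoints hold tensors, not integers.
sample = {
    "backbone.cls_token": 1,
    "backbone.patch_embed.proj.weight": 2,
    "norm.bias": 3,
}
remapped = remap_keys(sample)
print(remapped)
```

Slicing with `len(prefix)` keeps the offset in sync with the prefix string if the prefix ever changes.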
# Available models from torchvision
AVAILABLE_MODELS = {
    name: getattr(models, name)
    for name in dir(models)
    if callable(
        getattr(models, name)
    ) and "weights" in signature(getattr(models, name)).parameters
}

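The comprehension above selects every callable in `models` whose signature has a `weights` parameter. The same idea can be sketched against a stand-in namespace so it runs without torchvision. Everything here is illustrative: `builders_accepting`, `fake_models`, and the two sample functions are hypothetical, and the underscore filter and the `try`/`except` guard are additions that the original comprehension does not have.

```python
from inspect import signature
from types import SimpleNamespace


def builders_accepting(namespace, parameter):
    """Map names to callables in `namespace` whose signature has `parameter`."""
    found = {}
    for name in dir(namespace):
        if name.startswith("_"):
            continue  # skip dunder and private attributes
        candidate = getattr(namespace, name)
        if not callable(candidate):
            continue
        try:
            params = signature(candidate).parameters
        except (TypeError, ValueError):
            # Some callables (for example certain built-ins) expose no signature.
            continue
        if parameter in params:
            found[name] = candidate
    return found


def resnet_like(weights=None):  # hypothetical model builder
    return "resnet_like", weights


def helper(x):  # callable, but takes no `weights` argument
    return x


fake_models = SimpleNamespace(resnet_like=resnet_like, helper=helper, VERSION="0.0")
available = builders_accepting(fake_models, "weights")
print(sorted(available))
```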
# Add GPFM model if available
if GPFM_AVAILABLE:
    AVAILABLE_MODELS['gpfm'] = GPFMModel

# Default resize and normalization settings for models
MODEL_DEFAULTS = {
    "default": {"resize": (224, 224), "normalize": (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )},
    "efficientnet_b1": {"resize": (240, 240)},
    "efficientnet_b2": {"resize": (260, 260)},
    "efficientnet_b3": {"resize": (300, 300)},
    "efficientnet_b4": {"resize": (380, 380)},
    "efficientnet_b5": {"resize": (456, 456)},
    "efficientnet_b6": {"resize": (528, 528)},
    "efficientnet_b7": {"resize": (600, 600)},
    "inception_v3": {"resize": (299, 299)},
    "swin_b": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "swin_s": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "swin_t": {"resize": (224, 224), "normalize": (
        [0.5, 0.0, 0.5], [0.5, 0.5, 0.5]
    )},
    "vit_b_16": {"resize": (224, 224), "normalize": (
        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    )},
    "vit_b_32": {"resize": (224, 224), "normalize": (
        [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    )},
    "gpfm": {
        "resize": (224, 224),
        "normalize": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    },
}

# Fall back to the ImageNet mean/std for any entry without "normalize".
for model, settings in MODEL_DEFAULTS.items():
    if "normalize" not in settings:
        settings["normalize"] = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
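
# A minimal sketch of the same fallback on a trimmed-down, hypothetical
# stand-in for MODEL_DEFAULTS. The names `model_defaults` and
# `IMAGENET_STATS` are assumptions made for illustration only; `setdefault`
# is used here as an equivalent of the explicit membership test above.

```python
# ImageNet channel means and standard deviations, as used in the loop above.
IMAGENET_STATS = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Hypothetical, shortened table: one entry with its own statistics, one without.
model_defaults = {
    "with_stats": {
        "resize": (224, 224),
        "normalize": ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    },
    "without_stats": {"resize": (299, 299)},
}

for settings in model_defaults.values():
    # setdefault only writes the key when it is missing, so entries that
    # already define "normalize" keep their own values.
    settings.setdefault("normalize", IMAGENET_STATS)

print(model_defaults["without_stats"]["normalize"])
```

# Note that this variant shares one tuple between all filled-in entries,
# which is only safe while the statistics are treated as read-only.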


# Custom transform classes
class CLAHETransform:
    """Applies CLAHE (contrast-limited adaptive histogram equalization)
    to the grayscale image and returns the result as 3-channel RGB."""

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit,
            tileGridSize=tile_grid_size
        )

    def __call__(self, img):
        img = np.array(img.convert("L"))
        img = self.clahe.apply(img)
        return Image.fromarray(img).convert("RGB")


class CannyTransform:
    """Runs Canny edge detection on the grayscale image and
    returns the edge map as 3-channel RGB."""

    def __init__(self, threshold1=100, threshold2=200):
        self.threshold1 = threshold1
        self.threshold2 = threshold2

    def __call__(self, img):
        img = np.array(img.convert("L"))
        edges = cv2.Canny(img, self.threshold1, self.threshold2)
        return Image.fromarray(edges).convert("RGB")


class RGBAtoRGBTransform:
    """Composites RGBA images over a white background;
    converts any other mode to RGB directly."""

    def __call__(self, img):
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img).convert("RGB")
        else:
            img = img.convert("RGB")
        return img
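
# A rough per-pixel sketch of what compositing over an opaque white
# background does, written without Pillow so it runs anywhere. The helper
# `composite_over_white` is hypothetical and not part of this module, and
# Pillow's integer arithmetic may round differently by one level in places.

```python
def composite_over_white(rgba):
    """Blend one (r, g, b, a) pixel, 0-255 per channel, over white."""
    r, g, b, a = rgba
    alpha = a / 255.0
    # Standard "over" blend against an opaque background:
    # out = foreground * alpha + background * (1 - alpha)
    return tuple(
        round(channel * alpha + 255 * (1.0 - alpha)) for channel in (r, g, b)
    )


opaque = composite_over_white((10, 20, 30, 255))      # colour kept as-is
transparent = composite_over_white((10, 20, 30, 0))   # becomes pure white
half = composite_over_white((0, 0, 0, 128))           # black at ~50% alpha
print(opaque, transparent, half)
```

# Fully transparent pixels therefore end up white, whereas a plain
# convert("RGB") would drop the alpha channel and keep whatever colour
# values happened to be stored underneath.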


def get_image_files_from_zip(zip_file):
    """Returns a list of image file names in the ZIP file."""
    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            file_list = [
                f for f in zip_ref.namelist() if f.lower().endswith(
                    (".png", ".jpg", ".jpeg", ".bmp", ".gif")
                )
            ]
        return file_list
    except zipfile.BadZipFile as exc:
        raise RuntimeError("Invalid ZIP file.") from exc
    except Exception as exc:
        raise RuntimeError("Error reading ZIP file.") from exc
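
# A small, self-contained sketch of the same listing logic against an
# in-memory archive, so nothing is written to disk. The archive contents and
# the names `IMAGE_EXTENSIONS`, `image_names` and `bad_zip_detected` are made
# up for illustration.

```python
import io
import zipfile

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

# Build a throwaway ZIP archive in memory with a mix of file types.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("slides/a.PNG", b"")
    archive.writestr("slides/b.jpeg", b"")
    archive.writestr("notes/readme.txt", b"")
    archive.writestr("c.tiff", b"")

# List members by name only; no member is extracted or decoded.
buffer.seek(0)
with zipfile.ZipFile(buffer, "r") as archive:
    image_names = [
        name for name in archive.namelist()
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]
print(image_names)

# Bytes that are not a ZIP archive raise zipfile.BadZipFile on open.
try:
    zipfile.ZipFile(io.BytesIO(b"not a zip archive"))
    bad_zip_detected = False
except zipfile.BadZipFile:
    bad_zip_detected = True
```

# The match is case-insensitive and purely name-based, so "a.PNG" is kept
# while "c.tiff" is skipped because its extension is not in the tuple.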


def load_model(model_name, device):
    """Loads a specified torchvision model and
    modifies it for feature extraction."""
    if model_name not in AVAILABLE_MODELS:
        # Adjacent string literals avoid the stray indentation that a
        # backslash continuation would embed in the message.
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Available models: {list(AVAILABLE_MODELS.keys())}"
        )
    try:
        # Special handling for GPFM
        if model_name == "gpfm":
            model = AVAILABLE_MODELS[model_name](device=device)
            logging.info("GPFM model loaded")
            return model

        # Standard torchvision models
        if "weights" in signature(
                AVAILABLE_MODELS[model_name]).parameters:
            model = AVAILABLE_MODELS[model_name](weights="DEFAULT").to(device)
        else:
            model = AVAILABLE_MODELS[model_name]().to(device)
        logging.info("Model loaded")
    except Exception as e:
        logging.error(f"Failed to load model {model_name}: {e}")
        raise

    # Replace the classification layer so the model outputs features.
    if hasattr(model, "fc"):
        model.fc = torch.nn.Identity()
    elif hasattr(model, "classifier"):
        model.classifier = torch.nn.Identity()
    elif hasattr(model, "head"):
        model.head = torch.nn.Identity()

    model.eval()
    return model
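
# A torch-free sketch of the signature check used above to decide whether a
# model factory accepts a `weights` argument. The two factories and the
# `build` helper are hypothetical stand-ins, not torchvision functions.

```python
from inspect import signature


def new_style_factory(weights=None):
    """Hypothetical factory that accepts a ``weights`` argument."""
    return {"weights": weights}


def old_style_factory():
    """Hypothetical factory that takes no arguments at all."""
    return {"weights": None}


def build(factory):
    # Pass weights="DEFAULT" only when the factory declares that parameter.
    if "weights" in signature(factory).parameters:
        return factory(weights="DEFAULT")
    return factory()


print(build(new_style_factory))
print(build(old_style_factory))
```

# One caveat of this check: a factory that only declares **kwargs has no
# parameter literally named "weights", so it would take the fallback branch.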


def write_csv(output_csv, list_embeddings, ludwig_format=False):
    """Writes embeddings to a CSV file, optionally in Ludwig format."""
    with open(output_csv, mode="w", encoding="utf-8", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        if list_embeddings:
            if ludwig_format:
                header = ["sample_name", "embedding"]
                formatted_embeddings = []
                for embedding in list_embeddings:
                    sample_name = embedding[0]
                    vector = embedding[1:]
                    embedding_str = " ".join(map(str, vector))
                    formatted_embeddings.append([sample_name, embedding_str])
                csv_writer.writerow(header)
                csv_writer.writerows(formatted_embeddings)
                logging.info("CSV created in Ludwig format")
            else:
                header = ["sample_name"] + [f"vector{i + 1}" for i in range(
                    len(list_embeddings[0]) - 1
                )]
                csv_writer.writerow(header)
                csv_writer.writerows(list_embeddings)
                logging.info("CSV created")
        else:
            csv_writer.writerow(["sample_name"] if not ludwig_format
                                else ["sample_name", "embedding"])
            logging.info("No valid images found. Empty CSV created.")
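
# A compact sketch of the two layouts produced by write_csv, rendered into
# memory instead of a file. The sample rows and the `render` helper are
# invented for illustration; the header and row logic mirrors the function
# above.

```python
import csv
import io

# Each row is [sample_name, value_1, value_2, ...], as write_csv expects.
rows = [["img1.png", 0.1, 0.2, 0.3], ["img2.png", 0.4, 0.5, 0.6]]


def render(rows, ludwig_format):
    """Return the CSV text as a list of lines for easy inspection."""
    out = io.StringIO(newline="")
    writer = csv.writer(out)
    if ludwig_format:
        # Two columns: the name and one space-separated vector string.
        writer.writerow(["sample_name", "embedding"])
        for name, *vector in rows:
            writer.writerow([name, " ".join(map(str, vector))])
    else:
        # One column per vector component: vector1, vector2, ...
        width = len(rows[0]) - 1
        writer.writerow(
            ["sample_name"] + [f"vector{i + 1}" for i in range(width)]
        )
        writer.writerows(rows)
    return out.getvalue().splitlines()


wide = render(rows, ludwig_format=False)
ludwig = render(rows, ludwig_format=True)
print(wide)
print(ludwig)
```

# The wide layout has one column per embedding dimension, while the Ludwig
# layout keeps the whole vector in a single space-separated cell.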


def extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type="rgb"):
    """Extracts embeddings from images
    using batch processing or sequential fallback."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_name, device)
    model_settings = MODEL_DEFAULTS.get(model_name, MODEL_DEFAULTS["default"])
    resize = model_settings["resize"]
    normalize = model_settings.get("normalize", (
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    ))

    # Define transform pipeline
    if transform_type == "grayscale":
        initial_transform = transforms.Grayscale(num_output_channels=3)
    elif transform_type == "clahe":
        initial_transform = CLAHETransform()
    elif transform_type == "edges":
        initial_transform = CannyTransform()
    elif transform_type == "rgba_to_rgb":
        initial_transform = RGBAtoRGBTransform()
    else:
        initial_transform = transforms.Lambda(lambda x: x.convert("RGB"))

    # Handle GPFM separately as it has its own preprocessing
    if model_name == "gpfm":
        # For GPFM, combine initial transform with GPFM's custom transformer
        if transform_type in ["grayscale", "clahe", "edges", "rgba_to_rgb"]:
            transform = transforms.Compose([
                initial_transform,
                model.get_transformer(apply_normalization=apply_normalization)
            ])
        else:
            transform = model.get_transformer(apply_normalization=apply_normalization)
    else:
        # Standard torchvision models
        transform_list = [initial_transform,
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
495 transforms.Resize(resize),
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
496 transforms.ToTensor()]
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
497 if apply_normalization:
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
498 transform_list.append(transforms.Normalize(mean=normalize[0],
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
499 std=normalize[1]))
84f96c952c2c planemo upload for repository https://github.com/goeckslab/gleam.git commit 5b6cd961948137853177b14b0fff80a5d40e8a07
goeckslab
parents: 0
diff changeset
500 transform = transforms.Compose(transform_list)

    class ImageDataset(Dataset):
        def __init__(self, zip_file, file_list, transform=None):
            self.zip_file = zip_file
            self.file_list = file_list
            self.transform = transform

        def __len__(self):
            return len(self.file_list)

        def __getitem__(self, idx):
            with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
                with zip_ref.open(self.file_list[idx]) as file:
                    try:
                        image = Image.open(file)
                        if self.transform:
                            image = self.transform(image)
                        return image, os.path.basename(self.file_list[idx])
                    except Exception as e:
                        logging.warning(
                            "Skipping %s: %s", self.file_list[idx], e
                        )
                        return None, os.path.basename(self.file_list[idx])

    # Custom collate function
    def collate_fn(batch):
        batch = [item for item in batch if item[0] is not None]
        if not batch:
            return None, None
        images, names = zip(*batch)
        return torch.stack(images), names

    list_embeddings = []
    with torch.inference_mode():
        try:
            # Try DataLoader with reduced resource usage
            dataset = ImageDataset(zip_file, file_list, transform=transform)
            dataloader = DataLoader(
                dataset,
                batch_size=16,  # Reduced for lower memory usage
                num_workers=0,  # Fix multiprocessing issues with GPFM
                shuffle=False,
                pin_memory=(device == "cuda"),
                collate_fn=collate_fn,
            )
            for images, names in dataloader:
                if images is None:
                    continue
                images = images.to(device)
                embeddings = model(images).cpu().numpy()
                for name, embedding in zip(names, embeddings):
                    list_embeddings.append([name] + embedding.tolist())
        except RuntimeError as e:
            logging.warning(
                "DataLoader failed: %s. "
                "Falling back to sequential processing.", e
            )
            # Discard partial DataLoader results so no image is listed twice
            list_embeddings.clear()
            # Fallback to sequential processing
            for file in file_list:
                with zipfile.ZipFile(zip_file, "r") as zip_ref:
                    with zip_ref.open(file) as img_file:
                        try:
                            image = Image.open(img_file)
                            image = transform(image)
                            input_tensor = image.unsqueeze(0).to(device)
                            embedding = model(
                                input_tensor
                            ).squeeze().cpu().numpy()
                            list_embeddings.append(
                                [os.path.basename(file)] + embedding.tolist()
                            )
                        except Exception as e:
                            logging.warning("Skipping %s: %s", file, e)

    return list_embeddings
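
The collate_fn in the listing drops samples whose image failed to load (returned as None) and signals an empty batch with (None, None), which the batch loop then skips. A minimal, torch-free sketch of that filtering pattern is given here. The name `filter_failed` is hypothetical and not part of the module, and plain Python lists stand in for tensors, so the `torch.stack` step is omitted and the surviving items are returned as a tuple.

```python
def filter_failed(batch):
    """Drop (item, name) pairs whose item is None.

    Hypothetical stand-in for the collate step: returns (None, None) when
    nothing survives, otherwise a tuple of items and a tuple of names.
    """
    kept = [pair for pair in batch if pair[0] is not None]
    if not kept:
        # Nothing usable in this batch; the caller is expected to skip it.
        return None, None
    items, names = zip(*kept)
    return items, names


if __name__ == "__main__":
    batch = [([0.1, 0.2], "a.png"), (None, "broken.png"), ([0.3, 0.4], "c.png")]
    items, names = filter_failed(batch)
    print(names)
    print(filter_failed([(None, "broken.png")]))
```

Because failed samples are removed before stacking, a batch can come back smaller than the configured batch size, and the names tuple stays aligned with the surviving items.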


def main(zip_file, output_csv, model_name, apply_normalization=False,
         transform_type="rgb", ludwig_format=False):
    """Main entry point for processing the zip file and
    extracting embeddings."""
    file_list = get_image_files_from_zip(zip_file)
    logging.info("Image files listed from ZIP")

    list_embeddings = extract_embeddings(
        model_name,
        apply_normalization,
        zip_file,
        file_list,
        transform_type
    )
    logging.info("Embeddings extracted")
    write_csv(output_csv, list_embeddings, ludwig_format)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument(
        "--zip_file",
        required=True,
        help="Path to the ZIP file containing images."
    )
    parser.add_argument(
        "--model_name",
        required=True,
        choices=AVAILABLE_MODELS.keys(),
        help="Model for embedding extraction."
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Whether to apply normalization."
    )
    parser.add_argument(
        "--transform_type",
        required=True,
        help="Image transformation type."
    )
    parser.add_argument(
        "--output_csv",
        required=True,
        help="Path to the output CSV file"
    )
    parser.add_argument(
        "--ludwig_format",
        action="store_true",
        help="Prepare CSV file in Ludwig input format"
    )

    args = parser.parse_args()
    main(
        args.zip_file,
        args.output_csv,
        args.model_name,
        args.normalize,
        args.transform_type,
        args.ludwig_format
    )
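
The command-line flags defined in the script can be exercised without loading a model. What follows is a minimal sketch under stated assumptions: `build_parser` is a hypothetical helper that is not in the module, `EXAMPLE_MODELS` is a placeholder for `AVAILABLE_MODELS.keys()`, and the file names are made up for illustration.

```python
import argparse

# Placeholder for AVAILABLE_MODELS.keys(); the real list lives in the module.
EXAMPLE_MODELS = ["gpfm", "resnet50"]


def build_parser():
    """Hypothetical helper mirroring the flags defined in the script."""
    parser = argparse.ArgumentParser(description="Extract image embeddings.")
    parser.add_argument("--zip_file", required=True)
    parser.add_argument("--model_name", required=True, choices=EXAMPLE_MODELS)
    parser.add_argument("--normalize", action="store_true")
    parser.add_argument("--transform_type", required=True)
    parser.add_argument("--output_csv", required=True)
    parser.add_argument("--ludwig_format", action="store_true")
    return parser


# Parse an explicit argument list instead of sys.argv (illustrative values).
args = build_parser().parse_args([
    "--zip_file", "images.zip",
    "--model_name", "gpfm",
    "--transform_type", "rgb",
    "--output_csv", "embeddings.csv",
    "--normalize",
])
print(args.normalize, args.ludwig_format)
```

Both `--normalize` and `--ludwig_format` use `store_true`, so they default to False and are enabled only when the flag is present on the command line.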