Mercurial > repos > goeckslab > multimodal_learner
comparison_utils.py @ 2:b708d0e210e6 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit ffd47c4881aaa9fc33e7d3993a8fdf4bd82f3792
| author | goeckslab |
|---|---|
| date | Sat, 10 Jan 2026 16:13:19 +0000 |
| parents | 375c36923da1 |
| children |
comparison
equal
deleted
inserted
replaced
| 1:a92f200d296e | 2:b708d0e210e6 |
|---|---|
| 1 import errno | |
| 1 import json | 2 import json |
| 2 import logging | 3 import logging |
| 3 import os | 4 import os |
| 4 import random | 5 import random |
| 5 import sys | 6 import sys |
| 6 import tempfile | 7 import tempfile |
| 7 import zipfile | 8 import zipfile |
| 9 from collections import OrderedDict | |
| 8 from pathlib import Path | 10 from pathlib import Path |
| 9 from typing import List, Optional | 11 from typing import List, Optional |
| 10 | 12 |
| 11 import numpy as np | 13 import numpy as np |
| 12 import pandas as pd | 14 import pandas as pd |
| 13 import torch | 15 import torch |
| 14 | 16 |
| 15 LOG = logging.getLogger(__name__) | 17 LOG = logging.getLogger(__name__) |
| 18 _IMAGE_EXTENSIONS = { | |
| 19 ".jpg", | |
| 20 ".jpeg", | |
| 21 ".png", | |
| 22 ".bmp", | |
| 23 ".gif", | |
| 24 ".tif", | |
| 25 ".tiff", | |
| 26 ".webp", | |
| 27 ".svs", | |
| 28 } | |
| 29 _MAX_PATH_COMPONENT = 255 | |
| 30 _MAX_EXTRACTED_INDEX_CACHE_SIZE = 2 | |
| 31 _MAX_EXTRACTED_INDEX_FILES = 100000 | |
| 32 _EXTRACTED_INDEX_CACHE = OrderedDict() | |
| 16 | 33 |
| 17 | 34 |
| 18 def str2bool(val) -> bool: | 35 def str2bool(val) -> bool: |
| 19 """Parse common truthy strings to bool.""" | 36 """Parse common truthy strings to bool.""" |
| 20 return str(val).strip().lower() in ("1", "true", "yes", "y") | 37 return str(val).strip().lower() in ("1", "true", "yes", "y") |
| 87 if not path.exists(): | 104 if not path.exists(): |
| 88 raise FileNotFoundError(f"Dataset not found: {path}") | 105 raise FileNotFoundError(f"Dataset not found: {path}") |
| 89 return pd.read_csv(path, sep=None, engine="python") | 106 return pd.read_csv(path, sep=None, engine="python") |
| 90 | 107 |
| 91 | 108 |
| 109 def _normalize_path_value(val: object) -> Optional[str]: | |
| 110 if val is None: | |
| 111 return None | |
| 112 s = str(val).strip().strip('"').strip("'") | |
| 113 return s if s else None | |
| 114 | |
| 115 | |
| 116 def _warn_if_long_component(path_str: str) -> None: | |
| 117 for part in path_str.replace("\\", "/").split("/"): | |
| 118 if len(part) > _MAX_PATH_COMPONENT: | |
| 119 LOG.warning( | |
| 120 "Path component exceeds %d chars; resolution may fail: %s", | |
| 121 _MAX_PATH_COMPONENT, | |
| 122 path_str, | |
| 123 ) | |
| 124 return | |
| 125 | |
| 126 | |
| 127 def _build_extracted_index(extracted_root: Optional[Path]) -> set: | |
| 128 if extracted_root is None: | |
| 129 return set() | |
| 130 index = set() | |
| 131 for root, _dirs, files in os.walk(extracted_root): | |
| 132 rel_root = os.path.relpath(root, extracted_root) | |
| 133 for fname in files: | |
| 134 ext = os.path.splitext(fname)[1].lower() | |
| 135 if ext not in _IMAGE_EXTENSIONS: | |
| 136 continue | |
| 137 rel_path = fname if rel_root == "." else os.path.join(rel_root, fname) | |
| 138 index.add(rel_path.replace("\\", "/")) | |
| 139 index.add(fname) | |
| 140 return index | |
| 141 | |
| 142 | |
| 143 def _get_cached_extracted_index(extracted_root: Optional[Path]) -> set: | |
| 144 if extracted_root is None: | |
| 145 return set() | |
| 146 try: | |
| 147 root = extracted_root.resolve() | |
| 148 except Exception: | |
| 149 root = extracted_root | |
| 150 cache_key = str(root) | |
| 151 try: | |
| 152 mtime_ns = root.stat().st_mtime_ns | |
| 153 except OSError: | |
| 154 _EXTRACTED_INDEX_CACHE.pop(cache_key, None) | |
| 155 return _build_extracted_index(root) | |
| 156 cached = _EXTRACTED_INDEX_CACHE.get(cache_key) | |
| 157 if cached: | |
| 158 cached_mtime, cached_index = cached | |
| 159 if cached_mtime == mtime_ns: | |
| 160 _EXTRACTED_INDEX_CACHE.move_to_end(cache_key) | |
| 161 LOG.debug("Using cached extracted index for %s (%d entries)", root, len(cached_index)) | |
| 162 return cached_index | |
| 163 _EXTRACTED_INDEX_CACHE.pop(cache_key, None) | |
| 164 LOG.debug("Invalidated extracted index cache for %s (mtime changed)", root) | |
| 165 else: | |
| 166 LOG.debug("No extracted index cache for %s; building", root) | |
| 167 index = _build_extracted_index(root) | |
| 168 if len(index) <= _MAX_EXTRACTED_INDEX_FILES: | |
| 169 _EXTRACTED_INDEX_CACHE[cache_key] = (mtime_ns, index) | |
| 170 _EXTRACTED_INDEX_CACHE.move_to_end(cache_key) | |
| 171 while len(_EXTRACTED_INDEX_CACHE) > _MAX_EXTRACTED_INDEX_CACHE_SIZE: | |
| 172 _EXTRACTED_INDEX_CACHE.popitem(last=False) | |
| 173 else: | |
| 174 LOG.debug("Extracted index has %d entries; skipping cache for %s", len(index), root) | |
| 175 return index | |
| 176 | |
| 177 | |
| 92 def prepare_image_search_dirs(args) -> Optional[Path]: | 178 def prepare_image_search_dirs(args) -> Optional[Path]: |
| 93 if not args.images_zip: | 179 if not args.images_zip: |
| 94 return None | 180 return None |
| 95 | 181 |
| 96 root = Path(tempfile.mkdtemp(prefix="autogluon_images_")) | 182 root = Path(tempfile.mkdtemp(prefix="autogluon_images_")) |
| 115 """ | 201 """ |
| 116 if df is None or df.empty: | 202 if df is None or df.empty: |
| 117 return [] | 203 return [] |
| 118 | 204 |
| 119 image_columns = [c for c in (image_columns or []) if c in df.columns] | 205 image_columns = [c for c in (image_columns or []) if c in df.columns] |
| 206 extracted_index = None | |
| 207 | |
| 208 def get_extracted_index() -> set: | |
| 209 nonlocal extracted_index | |
| 210 if extracted_index is None: | |
| 211 extracted_index = _get_cached_extracted_index(extracted_root) | |
| 212 return extracted_index | |
| 120 | 213 |
| 121 def resolve(p): | 214 def resolve(p): |
| 122 if pd.isna(p): | 215 if pd.isna(p): |
| 123 return None | 216 return None |
| 124 orig = Path(str(p).strip()) | 217 raw = _normalize_path_value(p) |
| 218 if not raw: | |
| 219 return None | |
| 220 _warn_if_long_component(raw) | |
| 221 orig = Path(raw) | |
| 125 candidates = [] | 222 candidates = [] |
| 126 if orig.is_absolute(): | 223 if orig.is_absolute(): |
| 127 candidates.append(orig) | 224 candidates.append(orig) |
| 128 if extracted_root is not None: | 225 if extracted_root is not None: |
| 129 candidates.extend([extracted_root / orig, extracted_root / orig.name]) | 226 candidates.extend([extracted_root / orig, extracted_root / orig.name]) |
| 130 for cand in candidates: | 227 for cand in candidates: |
| 131 if cand.exists(): | 228 try: |
| 132 return str(cand.resolve()) | 229 if cand.exists(): |
| 133 return None | 230 return str(cand.resolve()) |
| 231 except OSError as e: | |
| 232 if e.errno == errno.ENAMETOOLONG: | |
| 233 LOG.warning("Path too long for filesystem: %s", cand) | |
| 234 continue | |
| 235 return None | |
| 236 | |
| 237 def matches_extracted(p) -> bool: | |
| 238 if pd.isna(p): | |
| 239 return False | |
| 240 raw = _normalize_path_value(p) | |
| 241 if not raw: | |
| 242 return False | |
| 243 _warn_if_long_component(raw) | |
| 244 index = get_extracted_index() | |
| 245 if not index: | |
| 246 return False | |
| 247 norm = raw.replace("\\", "/").lstrip("./") | |
| 248 return norm in index | |
| 134 | 249 |
| 135 # Infer image columns if none were provided | 250 # Infer image columns if none were provided |
| 136 if not image_columns: | 251 if not image_columns: |
| 137 obj_cols = [c for c in df.columns if str(df[c].dtype) == "object"] | 252 obj_cols = [c for c in df.columns if str(df[c].dtype) == "object"] |
| 138 inferred = [] | 253 inferred = [] |
| 139 for col in obj_cols: | 254 for col in obj_cols: |
| 140 sample = df[col].dropna().head(50) | 255 sample = df[col].dropna().head(50) |
| 141 if sample.empty: | 256 if sample.empty: |
| 142 continue | 257 continue |
| 258 if extracted_root is not None: | |
| 259 index = get_extracted_index() | |
| 260 else: | |
| 261 index = set() | |
| 262 if index: | |
| 263 matched = sample.apply(matches_extracted) | |
| 264 if matched.any(): | |
| 265 inferred.append(col) | |
| 266 continue | |
| 143 resolved_sample = sample.apply(resolve) | 267 resolved_sample = sample.apply(resolve) |
| 144 if resolved_sample.notna().any(): | 268 if resolved_sample.notna().any(): |
| 145 inferred.append(col) | 269 inferred.append(col) |
| 146 image_columns = inferred | 270 image_columns = inferred |
| 147 if image_columns: | 271 if image_columns: |
