Mercurial > repos > goeckslab > image_learner
comparison ludwig_backend.py @ 16:8729f69e9207 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
| author | goeckslab |
|---|---|
| date | Wed, 03 Dec 2025 01:28:52 +0000 |
| parents | d17e3a1b8659 |
| children |
comparison
equal
deleted
inserted
replaced
| 15:d17e3a1b8659 | 16:8729f69e9207 |
|---|---|
| 1 import json | 1 import json |
| 2 import logging | 2 import logging |
| | 3 import os |
| 3 from pathlib import Path | 4 from pathlib import Path |
| 4 from typing import Any, Dict, Optional, Protocol, Tuple | 5 from typing import Any, Dict, Optional, Protocol, Tuple |
| 5 | 6 |
| 6 import pandas as pd | 7 import pandas as pd |
| 7 import pandas.api.types as ptypes | 8 import pandas.api.types as ptypes |
| 160 custom_model = raw_encoder["custom_model"] | 161 custom_model = raw_encoder["custom_model"] |
| 161 else: | 162 else: |
| 162 custom_model = model_name | 163 custom_model = model_name |
| 163 | 164 |
| 164 logger.info(f"DETECTED MetaFormer model: {custom_model}") | 165 logger.info(f"DETECTED MetaFormer model: {custom_model}") |
| | 166 # Stash the model name for patched Stacked2DCNN in case Ludwig drops custom_model from kwargs |
| | 167 try: |
| | 168 from MetaFormer.metaformer_stacked_cnn import set_current_metaformer_model |
| | 169 |
| | 170 set_current_metaformer_model(custom_model) |
| | 171 except Exception: |
| | 172 logger.debug("Could not set current MetaFormer model hint; proceeding without global override") |
| | 173 # Also pass via environment to survive process boundaries (e.g., Ray workers) |
| | 174 os.environ["GLEAM_META_FORMER_MODEL"] = custom_model |
| 165 cfg_channels, cfg_height, cfg_width = 3, 224, 224 | 175 cfg_channels, cfg_height, cfg_width = 3, 224, 224 |
| | 176 model_cfg = {} |
| 166 if META_DEFAULT_CFGS: | 177 if META_DEFAULT_CFGS: |
| 167 model_cfg = META_DEFAULT_CFGS.get(custom_model, {}) | 178 model_cfg = META_DEFAULT_CFGS.get(custom_model, {}) |
| 168 input_size = model_cfg.get("input_size") | 179 input_size = model_cfg.get("input_size") |
| 169 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: | 180 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: |
| 170 cfg_channels, cfg_height, cfg_width = ( | 181 cfg_channels, cfg_height, cfg_width = ( |
| 171 int(input_size[0]), | 182 int(input_size[0]), |
| 172 int(input_size[1]), | 183 int(input_size[1]), |
| 173 int(input_size[2]), | 184 int(input_size[2]), |
| 174 ) | 185 ) |
| 175 | 186 |
| 176 target_height, target_width = cfg_height, cfg_width | 187 weights_url = None |
| | 188 if isinstance(model_cfg, dict): |
| | 189 weights_url = model_cfg.get("url") |
| | 190 logger.info( |
| | 191 "MetaFormer cfg lookup: model=%s has_cfg=%s url=%s use_pretrained=%s", |
| | 192 custom_model, |
| | 193 bool(model_cfg), |
| | 194 weights_url, |
| | 195 use_pretrained, |
| | 196 ) |
| | 197 if use_pretrained and not weights_url: |
| | 198 logger.warning( |
| | 199 "MetaFormer pretrained requested for %s but no URL found in default cfgs; model will be randomly initialized", |
| | 200 custom_model, |
| | 201 ) |
| | 202 |
| 177 resize_value = config_params.get("image_resize") | 203 resize_value = config_params.get("image_resize") |
| 178 if resize_value and resize_value != "original": | 204 if resize_value and resize_value != "original": |
| 179 try: | 205 try: |
| 180 dimensions = resize_value.split("x") | 206 dimensions = resize_value.split("x") |
| 181 if len(dimensions) == 2: | 207 if len(dimensions) == 2: |
| 196 ) | 222 ) |
| 197 target_height, target_width = cfg_height, cfg_width | 223 target_height, target_width = cfg_height, cfg_width |
| 198 else: | 224 else: |
| 199 image_zip_path = config_params.get("image_zip", "") | 225 image_zip_path = config_params.get("image_zip", "") |
| 200 detected_height, detected_width = self._detect_image_dimensions(image_zip_path) | 226 detected_height, detected_width = self._detect_image_dimensions(image_zip_path) |
| 201 if use_pretrained: | 227 target_height, target_width = detected_height, detected_width |
| 202 if (detected_height, detected_width) != (cfg_height, cfg_width): | 228 if use_pretrained and (detected_height, detected_width) != (cfg_height, cfg_width): |
| 203 logger.info( | 229 logger.info( |
| 204 "MetaFormer pretrained weights expect %sx%s; resizing from detected %sx%s", | 230 "MetaFormer pretrained weights expect %sx%s; proceeding with detected %sx%s", |
| 205 cfg_height, | 231 cfg_height, |
| 206 cfg_width, | 232 cfg_width, |
| 207 detected_height, | 233 detected_height, |
| 208 detected_width, | 234 detected_width, |
| 209 ) | 235 ) |
| 210 else: | |
| 211 target_height, target_width = detected_height, detected_width | |
| 212 if target_height <= 0 or target_width <= 0: | 236 if target_height <= 0 or target_width <= 0: |
| 213 raise ValueError( | 237 raise ValueError( |
| 214 f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}." | 238 f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}." |
| 215 ) | 239 ) |
| 216 | 240 |
