"""MetaFormer stacked CNN encoder (MetaFormer/metaformer_stacked_cnn.py).

Source: https://github.com/goeckslab/gleam.git
(planemo upload, commit 0fe927b618cd4dfc87af7baaa827034cc6813225,
changeset 11:c5150cceab47, author goeckslab, Sat, 18 Oct 2025 03:17:09 +0000)
"""

import logging
import os
import sys
from typing import Dict, List, Optional

import torch
import torch.nn as nn

sys.path.insert(0, os.path.dirname(__file__))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

SUPPORTED_PREFIXES = (
    'identityformer_',
    'randformer_',
    'poolformerv2_',
    'convformer_',
    'caformer_',
)
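# Model names are matched by prefix; the default "identityformer_s12" used below
# is one example of a name that passes this check (illustrative, not exhaustive).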

try:
    from metaformer_models import default_cfgs as META_DEFAULT_CFGS
    META_MODELS_AVAILABLE = True
    logger.info("MetaFormer models imported successfully")
except Exception as e:
    META_MODELS_AVAILABLE = False
    logger.warning(f"MetaFormer models not available: {e}")


def _resolve_metaformer_ctor(model_name: str):
    # Prefer getattr to avoid importing every factory explicitly
    try:
        # Import the module itself for dynamic access
        import metaformer_models
        _factories = metaformer_models.__dict__
        if model_name in _factories and callable(_factories[model_name]):
            return _factories[model_name]
    except Exception:
        pass
    return None


class MetaFormerStackedCNN(nn.Module):
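    """Ludwig-style image encoder wrapping a MetaFormer backbone with an FC head.

    Illustrative usage (a sketch, not executed here; it assumes the
    ``metaformer_models`` module is importable next to this file and that the
    backbone's ``forward_features`` returns pooled ``(batch, feature_dim)``
    features, as this encoder expects)::

        encoder = MetaFormerStackedCNN(
            height=224, width=224, num_channels=3,
            output_size=128, custom_model="identityformer_s12",
            use_pretrained=False,  # skip weight download for this sketch
        )
        out = encoder(torch.randn(2, 3, 224, 224))["encoder_output"]
        # out has shape (2, 128)
    """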
    def __init__(
        self,
        height: int = 224,
        width: int = 224,
        num_channels: int = 3,
        output_size: int = 128,
        custom_model: str = "identityformer_s12",
        use_pretrained: bool = True,
        trainable: bool = True,
        conv_layers: Optional[List[Dict]] = None,
        num_conv_layers: Optional[int] = None,
        conv_activation: str = "relu",
        conv_dropout: float = 0.0,
        conv_norm: Optional[str] = None,
        conv_use_bias: bool = True,
        fc_layers: Optional[List[Dict]] = None,
        num_fc_layers: int = 1,
        fc_activation: str = "relu",
        fc_dropout: float = 0.0,
        fc_norm: Optional[str] = None,
        fc_use_bias: bool = True,
        **kwargs,
    ):
        super().__init__()
        logger.info("MetaFormerStackedCNN encoder instantiated")
        logger.info(f"Using MetaFormer model: {custom_model}")

        try:
            height = int(height)
            width = int(width)
            num_channels = int(num_channels)
        except (TypeError, ValueError) as exc:
            raise ValueError("MetaFormerStackedCNN requires integer height, width, and num_channels.") from exc

        if height <= 0 or width <= 0:
            raise ValueError(f"MetaFormerStackedCNN received non-positive dimensions: {height}x{width}.")
        if num_channels <= 0:
            raise ValueError(f"MetaFormerStackedCNN requires num_channels > 0, received {num_channels}.")

        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.output_size = output_size
        self.custom_model = custom_model
        self.use_pretrained = use_pretrained
        self.trainable = trainable

        # Fall back to defaults if the MetaFormer configs could not be imported.
        cfg = META_DEFAULT_CFGS.get(custom_model, {}) if META_MODELS_AVAILABLE else {}
        input_size = cfg.get('input_size', (3, 224, 224))
        if isinstance(input_size, (list, tuple)) and len(input_size) == 3:
            expected_channels, expected_height, expected_width = input_size
        else:
            expected_channels, expected_height, expected_width = 3, 224, 224

        self.expected_channels = expected_channels
        self.expected_height = expected_height
        self.expected_width = expected_width

        logger.info(f"Initializing MetaFormerStackedCNN with model: {custom_model}")
        logger.info(
            "Input: %sx%sx%s -> Output: %s (expected backbone size: %sx%s)",
            num_channels,
            height,
            width,
            output_size,
            self.expected_height,
            self.expected_width,
        )

        self.channel_adapter: Optional[nn.Conv2d] = None
        if num_channels != self.expected_channels:
            self.channel_adapter = nn.Conv2d(
                num_channels, self.expected_channels, kernel_size=1, stride=1, padding=0
            )
            logger.info(
                "Added channel adapter: %s -> %s channels",
                num_channels,
                self.expected_channels,
            )

        self.size_adapter: Optional[nn.Module] = None
        if height != self.expected_height or width != self.expected_width:
            self.size_adapter = nn.AdaptiveAvgPool2d((height, width))
            logger.info(
                "Configured size adapter to requested input: %sx%s",
                height,
                width,
            )
        self.backbone_adapter: Optional[nn.Module] = None

        self.backbone = self._load_metaformer_backbone()
        self.feature_dim = self._get_feature_dim()

        self.fc_layers = self._create_fc_layers(
            input_dim=self.feature_dim,
            output_dim=output_size,
            num_layers=num_fc_layers,
            activation=fc_activation,
            dropout=fc_dropout,
            norm=fc_norm,
            use_bias=fc_use_bias,
            fc_layers_config=fc_layers,
        )

        if not trainable:
            for param in self.backbone.parameters():
                param.requires_grad = False
            logger.info("MetaFormer backbone frozen (trainable=False)")

        logger.info("MetaFormerStackedCNN initialized successfully")

    def _load_metaformer_backbone(self):
        if not META_MODELS_AVAILABLE:
            raise ImportError("MetaFormer models are not available")

        ctor = _resolve_metaformer_ctor(self.custom_model)
        if ctor is None:
            raise ValueError(f"Unknown MetaFormer model: {self.custom_model}")

        cfg = META_DEFAULT_CFGS.get(self.custom_model, {})
        weights_url = cfg.get('url')
        # Track whether pretrained weights were actually fetched and from where.
        self._pretrained_loaded = False
        self._loaded_weights_url: Optional[str] = None
        if self.use_pretrained and weights_url:
            print(f"LOADING MetaFormer pretrained weights from: {weights_url}")
            logger.info(f"Loading pretrained weights from: {weights_url}")
        # Ensure we log whenever the factories call torch.hub.load_state_dict_from_url
        orig_loader = getattr(torch.hub, 'load_state_dict_from_url', None)

        def _wrapped_loader(url, *args, **kwargs):
            print(f"DOWNLOADING weights from: {url}")
            logger.info(f"DOWNLOADING weights from: {url}")
            result = orig_loader(url, *args, **kwargs)
            # Mark success only after the download has actually completed.
            self._pretrained_loaded = True
            self._loaded_weights_url = url
            print(f"WEIGHTS DOWNLOADED successfully from: {url}")
            return result
        try:
            if self.use_pretrained and orig_loader is not None:
                torch.hub.load_state_dict_from_url = _wrapped_loader  # type: ignore[attr-defined]
            print(f"CREATING MetaFormer model: {self.custom_model} (pretrained={self.use_pretrained})")
            try:
                model = ctor(pretrained=self.use_pretrained, num_classes=1000)
                print(f"MetaFormer model CREATED: {self.custom_model}")
            except Exception as model_error:
                if self.use_pretrained:
                    print(f"⚠ Warning: Failed to load {self.custom_model} with pretrained weights: {model_error}")
                    print("Attempting to load without pretrained weights as fallback...")
                    logger.warning(f"Failed to load {self.custom_model} with pretrained weights: {model_error}")
                    model = ctor(pretrained=False, num_classes=1000)
                    print(f"✓ Successfully loaded {self.custom_model} without pretrained weights")
                    self.use_pretrained = False  # Update state to reflect actual loading
                else:
                    raise model_error
        finally:
            if orig_loader is not None:
                torch.hub.load_state_dict_from_url = orig_loader  # type: ignore[attr-defined]
        self._metaformer_weights_url = weights_url
        if self.use_pretrained:
            if self._pretrained_loaded:
                print(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
                logger.info(f"MetaFormer: pretrained weights loaded from {self._loaded_weights_url}")
            else:
                # Warn but don't fail - weights may have failed to load but model creation succeeded
                print("⚠ Warning: MetaFormer pretrained weights were requested but not confirmed as loaded")
                logger.warning("MetaFormer: pretrained weights were requested but not confirmed as loaded")
        else:
            print(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
            logger.info(f"MetaFormer: using randomly initialized weights for {self.custom_model}")
        logger.info(f"Loaded MetaFormer backbone: {self.custom_model} (pretrained={self.use_pretrained})")
        return model

    def _get_feature_dim(self):
        # Probe the backbone with a dummy input at its expected resolution
        # rather than a hard-coded 3x224x224 tensor.
        with torch.no_grad():
            dummy_input = torch.randn(
                1, self.expected_channels, self.expected_height, self.expected_width
            )
            features = self.backbone.forward_features(dummy_input)
            feature_dim = features.shape[-1]
        logger.info(f"MetaFormer feature dimension: {feature_dim}")
        return feature_dim

    def _create_fc_layers(
        self, input_dim, output_dim, num_layers, activation,
        dropout, norm, use_bias, fc_layers_config,
    ):
        layers = []
        if fc_layers_config:
            current_dim = input_dim
            for i, layer_config in enumerate(fc_layers_config):
                layer_output_dim = layer_config.get('output_size', output_dim if i == len(fc_layers_config) - 1 else current_dim)
                layers.append(nn.Linear(current_dim, layer_output_dim, bias=use_bias))
                if i < len(fc_layers_config) - 1:
                    if activation == "relu":
                        layers.append(nn.ReLU())
                    elif activation == "tanh":
                        layers.append(nn.Tanh())
                    elif activation == "sigmoid":
                        layers.append(nn.Sigmoid())
                    elif activation == "leaky_relu":
                        layers.append(nn.LeakyReLU())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
                if norm == "batch":
                    layers.append(nn.BatchNorm1d(layer_output_dim))
                elif norm == "layer":
                    layers.append(nn.LayerNorm(layer_output_dim))
                current_dim = layer_output_dim
        else:
            if num_layers == 1:
                layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            else:
                intermediate_dims = [input_dim]
                for i in range(num_layers - 1):
                    intermediate_dim = int(input_dim * (0.5 ** (i + 1)))
                    intermediate_dim = max(intermediate_dim, output_dim)
                    intermediate_dims.append(intermediate_dim)
                intermediate_dims.append(output_dim)
                for i in range(num_layers):
                    layers.append(nn.Linear(intermediate_dims[i], intermediate_dims[i + 1], bias=use_bias))
                    if i < num_layers - 1:
                        if activation == "relu":
                            layers.append(nn.ReLU())
                        elif activation == "tanh":
                            layers.append(nn.Tanh())
                        elif activation == "sigmoid":
                            layers.append(nn.Sigmoid())
                        elif activation == "leaky_relu":
                            layers.append(nn.LeakyReLU())
                    if dropout > 0:
                        layers.append(nn.Dropout(dropout))
                    if norm == "batch":
                        layers.append(nn.BatchNorm1d(intermediate_dims[i + 1]))
                    elif norm == "layer":
                        layers.append(nn.LayerNorm(intermediate_dims[i + 1]))
        return nn.Sequential(*layers)
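
    # Illustrative example (not executed): with input_dim=512, output_dim=128,
    # num_layers=2, activation="relu", dropout=0.1, norm=None and no explicit
    # fc_layers_config, the branch above halves the width per hidden layer and
    # builds roughly:
    #   Linear(512, 256) -> ReLU -> Dropout(0.1) -> Linear(256, 128) -> Dropout(0.1)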

    def forward(self, x):
        if x.shape[1] != self.expected_channels:
            if (
                self.channel_adapter is None
                or self.channel_adapter.in_channels != x.shape[1]
                or self.channel_adapter.out_channels != self.expected_channels
            ):
                self.channel_adapter = nn.Conv2d(
                    x.shape[1],
                    self.expected_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ).to(x.device)
                logger.info(
                    "Created dynamic channel adapter: %s -> %s channels",
                    x.shape[1],
                    self.expected_channels,
                )
            x = self.channel_adapter(x)

        target_height, target_width = self.height, self.width
        if x.shape[2] != target_height or x.shape[3] != target_width:
            if (
                self.size_adapter is None
                or getattr(self.size_adapter, "output_size", None)
                != (target_height, target_width)
            ):
                self.size_adapter = nn.AdaptiveAvgPool2d(
                    (target_height, target_width)
                ).to(x.device)
                logger.info(
                    "Created size adapter: %sx%s -> %sx%s",
                    x.shape[2],
                    x.shape[3],
                    target_height,
                    target_width,
                )
            x = self.size_adapter(x)

        if target_height != self.expected_height or target_width != self.expected_width:
            if (
                self.backbone_adapter is None
                or getattr(self.backbone_adapter, "output_size", None)
                != (self.expected_height, self.expected_width)
            ):
                self.backbone_adapter = nn.AdaptiveAvgPool2d(
                    (self.expected_height, self.expected_width)
                ).to(x.device)
                logger.info(
                    "Aligning to MetaFormer backbone size: %sx%s",
                    self.expected_height,
                    self.expected_width,
                )
            x = self.backbone_adapter(x)

        features = self.backbone.forward_features(x)
        output = self.fc_layers(features)
        return {'encoder_output': output}

    @property
    def output_shape(self):
        return [self.output_size]


def create_metaformer_stacked_cnn(model_name: str, **kwargs) -> MetaFormerStackedCNN:
    return MetaFormerStackedCNN(custom_model=model_name, **kwargs)


def patch_ludwig_stacked_cnn():
    # Only patch Ludwig if MetaFormer models are available in this runtime
    if not META_MODELS_AVAILABLE:
        logger.warning("MetaFormer models unavailable; skipping Ludwig patch for stacked_cnn.")
        return False
    return patch_ludwig_direct()


def _is_supported_metaformer(custom_model: Optional[str]) -> bool:
    return bool(custom_model) and custom_model.startswith(SUPPORTED_PREFIXES)


def patch_ludwig_direct():
    try:
        from ludwig.encoders.image.base import Stacked2DCNN
        original_stacked_cnn_init = Stacked2DCNN.__init__

        def patched_stacked_cnn_init(self, *args, **kwargs):
            custom_model = kwargs.pop("custom_model", None)
            if custom_model is None:
                custom_model = getattr(patch_ludwig_direct, '_metaformer_model', None)

            try:
                if META_MODELS_AVAILABLE and _is_supported_metaformer(custom_model):
                    print(f"DETECTED MetaFormer model: {custom_model}")
                    print("MetaFormer encoder is being loaded and used.")
                    # Initialize base class to keep Ludwig internals intact
                    original_stacked_cnn_init(self, *args, **kwargs)
                    # Create our MetaFormer encoder and graft behavior
                    mf_encoder = create_metaformer_stacked_cnn(custom_model, **kwargs)
                    # ensure base attributes won't be used accidentally
                    for attr in ("conv_layers", "fc_layers", "combiner", "output_shape", "reduce_output"):
                        if hasattr(self, attr):
                            try:
                                setattr(self, attr, getattr(mf_encoder, attr, None))
                            except Exception:
                                pass
                    self.forward = mf_encoder.forward
                    if hasattr(mf_encoder, 'backbone'):
                        self.backbone = mf_encoder.backbone
                    if hasattr(mf_encoder, 'fc_layers'):
                        self.fc_layers = mf_encoder.fc_layers
                    if hasattr(mf_encoder, 'custom_model'):
                        self.custom_model = mf_encoder.custom_model
                    # explicit confirmation logs
                    try:
                        url_info = getattr(mf_encoder, '_loaded_weights_url', None)
                        loaded_flag = getattr(mf_encoder, '_pretrained_loaded', False)
                        if loaded_flag and url_info:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using pretrained weights from: {url_info}")
                        else:
                            print(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights (no pretrained)")
                            logger.info(f"CONFIRMED: MetaFormer '{custom_model}' using randomly initialized weights")
                    except Exception:
                        pass
                else:
                    original_stacked_cnn_init(self, *args, **kwargs)
            finally:
                if hasattr(patch_ludwig_direct, '_metaformer_model'):
                    patch_ludwig_direct._metaformer_model = None

        Stacked2DCNN.__init__ = patched_stacked_cnn_init
        return True
    except Exception as e:
        logger.error(f"Failed to apply MetaFormer direct patch: {e}")
        return False


def set_current_metaformer_model(model_name: str):
    """Store the current MetaFormer model name for the patch to use."""
    setattr(patch_ludwig_direct, '_metaformer_model', model_name)


def clear_current_metaformer_model():
    """Remove any cached MetaFormer model hint."""
    if hasattr(patch_ludwig_direct, '_metaformer_model'):
        delattr(patch_ludwig_direct, '_metaformer_model')
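

# Minimal manual smoke test (a sketch, not part of the tool's execution path).
# It assumes metaformer_models is importable next to this file; pretrained
# weights are skipped so no download is attempted.
if __name__ == "__main__":
    if META_MODELS_AVAILABLE:
        encoder = create_metaformer_stacked_cnn(
            "identityformer_s12",
            height=224,
            width=224,
            num_channels=3,
            output_size=128,
            use_pretrained=False,  # avoid downloading weights for a quick check
        )
        dummy = torch.randn(2, 3, 224, 224)
        out = encoder(dummy)["encoder_output"]
        logger.info("Smoke test output shape: %s", tuple(out.shape))
    else:
        logger.warning("metaformer_models not importable; skipping smoke test.")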