diff main.py @ 3:000f171afabb draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit e08711c242a340a1671dfca35f52d3724086e968
author bgruening
date Wed, 26 Feb 2025 10:27:36 +0000
parents caea9ee1ffac
children 4fd6e8b051e9
--- a/main.py	Tue Oct 15 12:57:42 2024 +0000
+++ b/main.py	Wed Feb 26 10:27:36 2025 +0000
@@ -7,70 +7,128 @@
 import imageio
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 
-def find_dim_order(user_in_shape, input_image):
+def dynamic_resize(image: torch.Tensor, target_shape: tuple):
     """
-    Find the correct order of input image's
-    shape. For a few models, the order of input size
-    mentioned in the RDF.yaml file is reversed compared
-    to the input image's original size. If it is reversed,
-    transpose the image to find correct order of image's
-    dimensions.
+    Resize an input tensor dynamically to the target shape.
+
+    Parameters:
+    - image: Input tensor of shape (C, D1, D2, ..., DN), with any number of spatial dims
+    - target_shape: Tuple specifying the target shape (C', D1', D2', ..., DN')
+
+    Returns:
+    - Tensor resized (and channel-adjusted) to target_shape.
     """
-    image_shape = list(input_image.shape)
-    # reverse the input shape provided from RDF.yaml file
-    correct_order = user_in_shape.split(",")[::-1]
-    # remove 1s from the original dimensions
-    correct_order = [int(i) for i in correct_order if i != "1"]
-    if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape):
-        input_image = torch.tensor(input_image.transpose())
-    return input_image, correct_order
+    # Extract input shape
+    input_shape = image.shape
+    num_dims = len(input_shape)  # Includes channels and spatial dimensions
+
+    # Ensure target shape matches the number of dimensions
+    if len(target_shape) != num_dims:
+        raise ValueError(
+            f"Target shape {target_shape} must have {num_dims} dimensions to match the input"
+        )
+
+    # Extract target channels and spatial sizes
+    target_channels = target_shape[0]  # First element is the target channel count
+    target_spatial_size = target_shape[1:]  # Remaining elements are spatial dimensions
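+    # e.g. (hypothetical) target_shape (3, 256, 256) splits into
+    # target_channels == 3 and target_spatial_size == (256, 256)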
+
+    # Add batch dim (N=1) for resizing
+    image = image.unsqueeze(0)
+
+    # Choose an interpolation mode based on dimensionality. With the
+    # batch dimension added, a num_dims == 2 input is a 3D tensor, for
+    # which only "linear" interpolation is supported ("bicubic" needs 4D)
+    if num_dims == 4:
+        interp_mode = "trilinear"
+    elif num_dims == 3:
+        interp_mode = "bilinear"
+    elif num_dims == 2:
+        interp_mode = "linear"
+    else:
+        interp_mode = "nearest"
+
+    # Resize spatial dimensions dynamically; align_corners is only
+    # accepted by the interpolating modes, not by "nearest"
+    image = F.interpolate(
+        image,
+        size=target_spatial_size,
+        mode=interp_mode,
+        align_corners=False if interp_mode != "nearest" else None,
+    )
+
+    # Adjust channels if necessary
+    current_channels = image.shape[1]
+
+    if target_channels > current_channels:
+        # Expand channels by repeating existing ones
+        expand_factor = target_channels // current_channels
+        remainder = target_channels % current_channels
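+        # e.g. (hypothetical) 2 -> 5 channels: expand_factor == 2 yields
+        # 4 channels and remainder == 1 reuses the first channel below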
+        image = image.repeat(1, expand_factor, *[1] * (num_dims - 1))
+
+        if remainder > 0:
+            extra_channels = image[
+                :, :remainder, ...
+            ]  # Take the first few channels to match target
+            image = torch.cat([image, extra_channels], dim=1)
+
+    elif target_channels < current_channels:
+        # Reduce channels by keeping only the first target_channels
+        image = image[:, :target_channels, ...]  # slice off the extra channels
+    return image.squeeze(0)  # Remove batch dimension before returning
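+
+
+# A minimal usage sketch (hypothetical shapes, for illustration only):
+#   dynamic_resize(torch.rand(3, 64, 64), (1, 128, 128))
+# returns a tensor of shape (1, 128, 128): the spatial dims are
+# interpolated and the extra channels sliced away.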
 
 
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
-    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
-    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
-    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")
+    arg_parser.add_argument(
+        "-im", "--imaging_model", required=True, help="Input BioImage model"
+    )
+    arg_parser.add_argument(
+        "-ii", "--image_file", required=True, help="Input image file"
+    )
+    arg_parser.add_argument(
+        "-is", "--image_size", required=True, help="Input image file's size"
+    )
+    arg_parser.add_argument(
+        "-ia", "--image_axes", required=True, help="Input image file's axes"
+    )
 
     # get argument values
     args = vars(arg_parser.parse_args())
     model_path = args["imaging_model"]
     input_image_path = args["image_file"]
+    input_size = args["image_size"]
 
     # load all embedded images in TIF file
     test_data = imageio.v3.imread(input_image_path, index="...")
+    test_data = test_data.astype(np.float32)
     test_data = np.squeeze(test_data)
-    test_data = test_data.astype(np.float32)
+
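+    # The input shape from the RDF.yaml file may be listed in reverse
+    # order relative to the image, so reverse it and drop singleton
+    # dims; e.g. (hypothetical value) "1,1,512,512" -> (512, 512)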
+    target_image_dim = input_size.split(",")[::-1]
+    target_image_dim = [int(i) for i in target_image_dim if i != "1"]
+    target_image_dim = tuple(target_image_dim)
 
-    # assess the correct dimensions of TIF input image
-    input_image_shape = args["image_size"]
-    im_test_data, shape_vals = find_dim_order(input_image_shape, test_data)
+    exp_test_data = torch.tensor(test_data)
+    # check if image dimensions are reversed
+    reversed_order = list(reversed(range(exp_test_data.dim())))
+    exp_test_data_T = exp_test_data.permute(*reversed_order)
+    if exp_test_data_T.shape == target_image_dim:
+        exp_test_data = exp_test_data_T
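+    # e.g. (hypothetical) a (512, 256) image whose reversed shape equals
+    # the (256, 512) target is transposed instead of resized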
+    if exp_test_data.shape != target_image_dim:
+        for i in range(len(target_image_dim) - exp_test_data.dim()):
+            exp_test_data = exp_test_data.unsqueeze(i)
+        try:
+            exp_test_data = dynamic_resize(exp_test_data, target_image_dim)
+        except Exception as e:
+            raise RuntimeError(f"Error during resizing: {e}") from e
+
+    current_dimension = len(exp_test_data.shape)
+    input_axes = args["image_axes"]
+    target_dimension = len(input_axes)
+    # expand input image based on the number of target dimensions
+    for i in range(target_dimension - current_dimension):
+        exp_test_data = torch.unsqueeze(exp_test_data, i)
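+    # e.g. (hypothetical) axes "bcyx" with a (512, 512) tensor prepends
+    # two singleton dims, giving shape (1, 1, 512, 512)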
 
     # load model
     model = torch.load(model_path)
     model.eval()
 
-    # find the number of dimensions required by the model
-    target_dimension = 0
-    for param in model.named_parameters():
-        target_dimension = len(param[1].shape)
-        break
-    current_dimension = len(list(im_test_data.shape))
-
-    # update the dimensions of input image if the required image by
-    # the model is smaller
-    slices = tuple(slice(0, s_val) for s_val in shape_vals)
-
-    # apply the slices to the reshaped_input
-    im_test_data = im_test_data[slices]
-    exp_test_data = torch.tensor(im_test_data)
-
-    # expand input image's dimensions
-    for i in range(target_dimension - current_dimension):
-        exp_test_data = torch.unsqueeze(exp_test_data, i)
-
     # make prediction
     pred_data = model(exp_test_data)
     pred_data_output = pred_data.detach().numpy()