view main.py @ 2:9b8fc55cb9eb draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit c994ad3cbfbc134d39f485514b0c1a9c3d2d235e
author bgruening
date Tue, 15 Oct 2024 12:57:42 +0000
parents caea9ee1ffac
children
line wrap: on
line source

"""
Predict images using AI models from BioImage.IO
"""

import argparse

import imageio
import numpy as np
import torch


def find_dim_order(user_in_shape, input_image):
    """
    Align the input image with the input shape declared in the RDF.yaml file.

    For a few models, the order of the input size mentioned in the RDF.yaml
    file is reversed compared to the input image's original size. The
    declared shape is therefore reversed and its singleton (size 1) axes are
    dropped. If the first reversed size equals the image's last axis while
    the reversed shape as a whole differs from the image's shape, the image
    is transposed (all axes reversed).

    Args:
        user_in_shape: Comma-separated input shape, e.g. "1,1,256,256".
            Whitespace around the values and empty entries (such as a
            trailing comma) are ignored.
        input_image: Array exposing ``shape`` and ``transpose()``; the
            script passes a NumPy array.

    Returns:
        A ``(image, correct_order)`` tuple. ``image`` is a ``torch.Tensor``
        holding the transposed data when a transpose was needed, otherwise
        ``input_image`` itself, unchanged. ``correct_order`` is the list of
        reversed, non-singleton sizes as integers (possibly empty).

    Raises:
        ValueError: If a non-empty entry of ``user_in_shape`` is not an
            integer.
    """
    image_shape = list(input_image.shape)

    # Strip each entry so that "1, 256, 256" is parsed like "1,256,256".
    tokens = [token.strip() for token in user_in_shape.split(",")]
    declared_sizes = [int(token) for token in tokens if token]

    # Reverse the declared shape and remove the singleton axes. Filtering on
    # the integer value (not the raw text) also catches " 1" and "01".
    correct_order = [size for size in reversed(declared_sizes) if size != 1]

    # Guard against an all-singleton declaration or a 0-d image, either of
    # which would make the index accesses below fail.
    is_reversed = (
        bool(correct_order)
        and bool(image_shape)
        and correct_order[0] == image_shape[-1]
        and correct_order != image_shape
    )
    if is_reversed:
        input_image = torch.tensor(input_image.transpose())
    return input_image, correct_order


if __name__ == "__main__":
    # Command line: model file, input image and the model's declared input
    # shape (comma-separated, as given in the RDF.yaml file).
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")

    # get argument values
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]

    # load all embedded images in TIF file, drop singleton axes and convert
    # to float32
    test_data = imageio.v3.imread(input_image_path, index="...")
    test_data = np.squeeze(test_data).astype(np.float32)

    # assess the correct dimensions of TIF input image; the returned image
    # is either the NumPy array itself or a transposed torch tensor
    im_test_data, shape_vals = find_dim_order(args["image_size"], test_data)

    # load model
    # NOTE(review): torch.load unpickles the file, so only trusted model
    # files should be passed here. Recent PyTorch releases changed the
    # default of `weights_only`; confirm the pinned torch version still
    # loads whole pickled modules this way.
    # The input tensor is built on the CPU and the prediction is converted
    # with .numpy(), so the model is mapped to the CPU as well; this also
    # lets checkpoints saved on a GPU load on CPU-only machines.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    # find the number of dimensions required by the model, taken from the
    # rank of its first parameter (0 if the model has no parameters)
    first_param = next(iter(model.parameters()), None)
    target_dimension = first_param.dim() if first_param is not None else 0

    # crop the input image to the declared sizes if the image required by
    # the model is smaller
    slices = tuple(slice(0, s_val) for s_val in shape_vals)

    # as_tensor accepts both a NumPy array and an existing tensor, so an
    # already transposed tensor is not copy-constructed a second time
    exp_test_data = torch.as_tensor(im_test_data[slices])

    # expand input image's dimensions by prepending singleton axes until
    # the rank expected by the model is reached
    for _ in range(target_dimension - exp_test_data.dim()):
        exp_test_data = exp_test_data.unsqueeze(0)

    # make prediction; no gradients are needed for inference
    with torch.no_grad():
        pred_data = model(exp_test_data)

    # save original image matrix
    np.save("output_predicted_image_matrix.npy", pred_data.detach().numpy())

    # post process predicted file to correctly save as TIF file
    pred_numpy = torch.squeeze(pred_data).detach().numpy()

    # write predicted TIF image to file
    imageio.v3.imwrite("output_predicted_image.tif", pred_numpy, extension=".tif")