comparison main.py @ 0:caea9ee1ffac draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit 57f46739f4365f59cd52c515bdd3fae2e01b734e
author bgruening
date Fri, 02 Aug 2024 15:40:35 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:caea9ee1ffac
1 """
2 Predict images using AI models from BioImage.IO
3 """
4
5 import argparse
6
7 import imageio
8 import numpy as np
9 import torch
10
11
def find_dim_order(user_in_shape, input_image):
    """
    Find the correct order of the input image's dimensions.

    For a few models, the order of the input size listed in the RDF.yaml
    file is reversed compared to the input image's original shape. If it
    is reversed, the image is transposed so its dimensions follow the
    order given by the RDF.

    :param user_in_shape: comma-separated axis sizes from the RDF.yaml
        file, e.g. "1,1,256,256". Whitespace around entries and empty
        entries (e.g. from a trailing comma) are ignored.
    :param input_image: array exposing ``.shape`` and ``.transpose()``
        (the script passes a squeezed numpy array).
    :returns: tuple ``(image, sizes)``. ``sizes`` is the reversed RDF
        shape with every size-1 axis removed. ``image`` is a
        ``torch.Tensor`` holding the fully transposed image when a
        transpose was needed; otherwise it is ``input_image`` unchanged.
    :raises ValueError: if a non-empty entry of ``user_in_shape`` is not
        an integer.
    """
    image_shape = list(input_image.shape)
    # Reverse the RDF shape and drop singleton axes. The entries are
    # parsed before filtering so that " 1" (from "1, 1, ...") is dropped
    # as well, instead of surviving as a spurious extra axis.
    tokens = [tok.strip() for tok in user_in_shape.split(",")]
    correct_order = [int(tok) for tok in reversed(tokens) if tok]
    correct_order = [size for size in correct_order if size != 1]
    # Guard both lists: an all-singleton RDF shape or a 0-d image leaves
    # nothing to compare, and indexing [0] / [-1] would raise IndexError.
    if (
        correct_order
        and image_shape
        and correct_order[0] == image_shape[-1]
        and correct_order != image_shape
    ):
        # .transpose() with no arguments reverses all axes.
        input_image = torch.tensor(input_image.transpose())
    return input_image, correct_order
29
30
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")

    # get argument values
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]
    input_image_shape = args["image_size"]

    # Load all images embedded in the input file. imageio v3 documents the
    # Ellipsis object (not the string "...") as the value of `index` that
    # reads every ndimage and stacks them along a new batch axis.
    test_data = imageio.v3.imread(input_image_path, index=...)
    test_data = np.squeeze(test_data).astype(np.float32)

    # assess the correct dimensions of the input image
    im_test_data, shape_vals = find_dim_order(input_image_shape, test_data)

    # Load the model onto the CPU. All later work happens on CPU tensors
    # (.numpy() below), so map GPU-saved weights here too.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # trusted model files. torch >= 2.6 defaults to weights_only=True,
    # which rejects full nn.Module pickles; if the pinned torch is
    # upgraded, this call needs weights_only=False.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    # Heuristic: take the rank of the model's first parameter (e.g. a 4-D
    # Conv2d weight) as the input rank the model expects; 0 if the model
    # has no parameters. TODO confirm for models whose first layer is not
    # a convolution.
    first_param = next(iter(model.parameters()), None)
    target_dimension = 0 if first_param is None else len(first_param.shape)

    # Crop the leading axes to the sizes taken from the RDF shape; axes
    # beyond len(shape_vals) are kept whole.
    slices = tuple(slice(0, s_val) for s_val in shape_vals)
    im_test_data = im_test_data[slices]

    # find_dim_order returns a numpy array or, after a transpose, a
    # torch.Tensor. as_tensor accepts both without the copy-construct
    # warning that torch.tensor(<tensor>) emits.
    exp_test_data = torch.as_tensor(im_test_data).contiguous()
    current_dimension = exp_test_data.dim()

    # Prepend singleton axes until the input has the rank the model expects.
    for axis in range(target_dimension - current_dimension):
        exp_test_data = torch.unsqueeze(exp_test_data, axis)

    # Make the prediction. Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        pred_data = model(exp_test_data)

    # save original (unsqueezed) prediction matrix
    np.save("output_predicted_image_matrix.npy", pred_data.detach().numpy())

    # Drop singleton axes so the prediction can be written as a TIF image.
    pred_numpy = torch.squeeze(pred_data).detach().numpy()
    imageio.v3.imwrite("output_predicted_image.tif", pred_numpy, extension=".tif")