Mercurial > repos > bgruening > bioimage_inference
comparison main.py @ 0:caea9ee1ffac draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit 57f46739f4365f59cd52c515bdd3fae2e01b734e
author | bgruening |
---|---|
date | Fri, 02 Aug 2024 15:40:35 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:caea9ee1ffac |
---|---|
1 """ | |
2 Predict images using AI models from BioImage.IO | |
3 """ | |
4 | |
5 import argparse | |
6 | |
7 import imageio | |
8 import numpy as np | |
9 import torch | |
10 | |
11 | |
def find_dim_order(user_in_shape, input_image):
    """
    Reconcile the axis order of ``input_image`` with the shape from RDF.yaml.

    For a few models, the input shape listed in the model's RDF.yaml file is
    in reversed order compared to the axes of the loaded image. The RDF shape
    is reversed and its singleton (size-1) entries are dropped; if the first
    remaining size equals the image's last axis but the two shapes differ, the
    image is transposed (all axes reversed) to match.

    :param user_in_shape: comma-separated RDF input shape, e.g. "1,1,256,256".
        Whitespace around entries and empty entries (e.g. a trailing comma)
        are ignored.
    :param input_image: image array. ``.transpose()`` is called without
        arguments, so a NumPy array is expected (TODO confirm for other types).
    :returns: tuple ``(image, correct_order)``. ``image`` is either the
        unchanged ``input_image`` or a ``torch.Tensor`` holding its transpose;
        ``correct_order`` is the reversed RDF shape without size-1 entries.
    :raises ValueError: if an entry of ``user_in_shape`` is not an integer, or
        if every entry is 1 (no dimension to match against).
    """
    image_shape = list(input_image.shape)
    # Parse entries as integers so that " 1" (from "1, 1, 256, 256") is also
    # recognised as a singleton; blank tokens are skipped.
    tokens = (tok.strip() for tok in user_in_shape.split(","))
    dims = [int(tok) for tok in tokens if tok]
    # Reverse the RDF shape and drop singleton dimensions.
    correct_order = [d for d in reversed(dims) if d != 1]
    if not correct_order:
        raise ValueError(
            f"image size {user_in_shape!r} contains no dimension other than 1"
        )
    # A matching last axis with an overall mismatch indicates reversed order.
    if image_shape and correct_order[0] == image_shape[-1] and correct_order != image_shape:
        input_image = torch.tensor(input_image.transpose())
    return input_image, correct_order
29 | |
30 | |
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")

    # get argument values
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]

    # Load all embedded images in the TIF file (index="..."), drop singleton
    # axes, and convert to float32 for the model.
    test_data = imageio.v3.imread(input_image_path, index="...")
    test_data = np.squeeze(test_data).astype(np.float32)

    # Reconcile the image's axis order with the shape given in RDF.yaml.
    input_image_shape = args["image_size"]
    im_test_data, shape_vals = find_dim_order(input_image_shape, test_data)

    # Load the model onto the CPU: the prediction is converted with .numpy()
    # below, which only works for CPU tensors, and map_location lets a model
    # saved from a GPU session load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code - only run this on trusted models. On PyTorch >= 2.6, loading a full
    # pickled nn.Module additionally requires weights_only=False.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    # Heuristic: the rank of the model's first parameter is taken as the rank
    # of the input tensor the model expects (0 if the model has no parameters).
    first_param = next(model.parameters(), None)
    target_dimension = 0 if first_param is None else first_param.dim()
    current_dimension = im_test_data.ndim

    # Crop the leading axes to the sizes from the RDF shape; this does not
    # change the number of dimensions.
    slices = tuple(slice(0, s_val) for s_val in shape_vals)
    im_test_data = im_test_data[slices]
    # as_tensor avoids copy-constructing a tensor when find_dim_order already
    # returned one; contiguous() keeps the dense layout torch.tensor produced.
    exp_test_data = torch.as_tensor(im_test_data).contiguous()

    # Prepend singleton axes until the input has the rank the model expects.
    for i in range(target_dimension - current_dimension):
        exp_test_data = torch.unsqueeze(exp_test_data, i)

    # Make the prediction without recording an autograd graph.
    with torch.no_grad():
        pred_data = model(exp_test_data)
    pred_data_output = pred_data.detach().numpy()

    # save original image matrix
    np.save("output_predicted_image_matrix.npy", pred_data_output)

    # Drop singleton axes so the prediction can be written as a TIF image.
    pred_numpy = np.squeeze(pred_data_output)

    # write predicted TIF image to file
    imageio.v3.imwrite("output_predicted_image.tif", pred_numpy, extension=".tif")