Mercurial > repos > bgruening > bioimage_inference
comparison main.py @ 0:caea9ee1ffac draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/bioimaging commit 57f46739f4365f59cd52c515bdd3fae2e01b734e
| author | bgruening |
|---|---|
| date | Fri, 02 Aug 2024 15:40:35 +0000 |
| parents | |
| children | 000f171afabb |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:caea9ee1ffac |
|---|---|
| 1 """ | |
| 2 Predict images using AI models from BioImage.IO | |
| 3 """ | |
| 4 | |
| 5 import argparse | |
| 6 | |
| 7 import imageio | |
| 8 import numpy as np | |
| 9 import torch | |
| 10 | |
| 11 | |
def find_dim_order(user_in_shape, input_image):
    """
    Find the correct order of the input image's dimensions.

    For a few models, the order of the input size given in the RDF.yaml
    file is reversed compared to the input image's original shape. If the
    image appears to be stored in that reversed order, transpose it (reverse
    all of its axes) so its dimensions line up with the RDF shape.

    :param user_in_shape: comma-separated input shape taken from the model's
        RDF.yaml, e.g. ``"1,1,256,256"``. Whitespace around entries and
        empty entries (e.g. from a trailing comma) are ignored.
    :param input_image: image array exposing ``.shape`` and ``.transpose()``
        (a NumPy array in this script).
    :returns: tuple ``(image, correct_order)``. ``correct_order`` is the
        reversed RDF shape with every size-1 entry removed, as a list of
        ints. ``image`` is the unchanged ``input_image``, or a
        ``torch.Tensor`` of the transposed data when a transposition was
        applied.
    :raises ValueError: if a non-empty entry of ``user_in_shape`` is not an
        integer.
    """
    image_shape = list(input_image.shape)
    # reverse the input shape provided by the RDF.yaml file; strip whitespace
    # so "1, 1, 256, 256" is parsed the same way as "1,1,256,256"
    tokens = [tok.strip() for tok in user_in_shape.split(",")[::-1]]
    dims = [int(tok) for tok in tokens if tok]
    # remove singleton dimensions; compare integer values rather than raw
    # strings so entries like " 1" or "01" are removed too
    correct_order = [dim for dim in dims if dim != 1]
    # nothing to compare against (all-singleton RDF shape or a 0-d image):
    # indexing [0] / [-1] below would raise IndexError
    if not correct_order or not image_shape:
        return input_image, correct_order
    # the first reversed RDF dimension matching the image's last axis (while
    # the orders differ) indicates the image axes are stored reversed
    if (correct_order[0] == image_shape[-1]) and (correct_order != image_shape):
        input_image = torch.tensor(input_image.transpose())
    return input_image, correct_order
| 29 | |
| 30 | |
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-im", "--imaging_model", required=True, help="Input BioImage model")
    arg_parser.add_argument("-ii", "--image_file", required=True, help="Input image file")
    arg_parser.add_argument("-is", "--image_size", required=True, help="Input image file's size")

    # get argument values
    args = vars(arg_parser.parse_args())
    model_path = args["imaging_model"]
    input_image_path = args["image_file"]

    # load all embedded images in TIF file (index="..." reads every page),
    # drop singleton axes and convert to float32 for the model
    test_data = imageio.v3.imread(input_image_path, index="...")
    test_data = np.squeeze(test_data).astype(np.float32)

    # assess the correct dimensions of TIF input image; the result is the
    # NumPy array itself or, if it had to be transposed, a torch.Tensor
    im_test_data, shape_vals = find_dim_order(args["image_size"], test_data)

    # load the model onto the CPU: the prediction is converted with
    # .numpy() below, which requires CPU tensors, and map_location lets
    # models saved from a GPU load on CPU-only hosts.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (it is required for full models / TorchScript archives, which
    # PyTorch >= 2.6 rejects by default); only load trusted model files.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    model.eval()

    # heuristic: the rank of the model's first parameter (e.g. a 4-D conv
    # weight) is taken as the rank the model expects for its input; a model
    # without parameters leaves the input rank unchanged
    first_param = next(model.parameters(), None)
    target_dimension = 0 if first_param is None else first_param.dim()
    current_dimension = im_test_data.ndim

    # crop each axis of the input image to the size required by the model
    slices = tuple(slice(0, s_val) for s_val in shape_vals)
    im_test_data = im_test_data[slices]
    # as_tensor accepts both NumPy arrays and tensors; torch.tensor() would
    # warn when find_dim_order already returned a tensor
    exp_test_data = torch.as_tensor(im_test_data)

    # prepend singleton axes until the input has the rank the model expects
    for i in range(target_dimension - current_dimension):
        exp_test_data = torch.unsqueeze(exp_test_data, i)

    # make prediction; no autograd graph is needed for inference
    with torch.no_grad():
        pred_data = model(exp_test_data)
    pred_data_output = pred_data.detach().numpy()

    # save original image matrix
    np.save("output_predicted_image_matrix.npy", pred_data_output)

    # post process predicted file to correctly save as TIF file
    pred_numpy = np.squeeze(pred_data_output)

    # write predicted TIF image to file
    imageio.v3.imwrite("output_predicted_image.tif", pred_numpy, extension=".tif")
