Mercurial > repos > imgteam > image_registration_affine
comparison image_registration_affine.py @ 1:fa769715b6b0 draft
"planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tools/image_registration_affine/ commit e82400162e337b36c29d6e79fb2deb9871475397"
author | imgteam |
---|---|
date | Thu, 20 Jan 2022 00:45:30 +0000 |
parents | e34222a620d4 |
children | 77dc68af2b40 |
comparison
equal
deleted
inserted
replaced
0:e34222a620d4 | 1:fa769715b6b0 |
---|---|
1 """ | |
2 Copyright 2021-2022 Biomedical Computer Vision Group, Heidelberg University. | |
3 | |
4 Distributed under the MIT license. | |
5 See file LICENSE for detail or copy at https://opensource.org/licenses/MIT | |
6 | |
7 """ | |
8 import argparse | |
9 | |
10 import numpy as np | |
11 import pandas as pd | |
1 import skimage.io | 12 import skimage.io |
2 from skimage.transform import ProjectiveTransform | |
3 from skimage.filters import gaussian | |
4 from scipy.ndimage import map_coordinates | 13 from scipy.ndimage import map_coordinates |
5 from scipy.optimize import least_squares | 14 from scipy.optimize import least_squares |
6 import numpy as np | 15 from scipy.signal import convolve2d |
7 import pandas as pd | 16 from skimage.color import rgb2gray, rgba2rgb |
8 import argparse | 17 from skimage.filters import gaussian |
9 | 18 from skimage.transform import ProjectiveTransform |
10 | 19 |
11 | 20 |
12 def _stackcopy(a, b): | 21 def _stackcopy(a, b): |
13 if a.ndim == 3: | 22 if a.ndim == 3: |
14 a[:] = b[:, :, np.newaxis] | 23 a[:] = b[:, :, np.newaxis] |
15 else: | 24 else: |
16 a[:] = b | 25 a[:] = b |
17 | |
18 | 26 |
19 | 27 |
def warp_coords_batch(coord_map, shape, dtype=np.float64, batch_size=1000000):
    """Build the coordinate array for warping an image of `shape` via `coord_map`.

    Parameters
    ----------
    coord_map : callable
        Maps an (N, 2) array of (col, row) output-image coordinates to the
        corresponding (col, row) coordinates in the input image.
    shape : tuple
        (rows, cols) or (rows, cols, channels) of the output image.
    dtype : numpy dtype
        dtype of the returned coordinate array.
    batch_size : int
        Number of coordinates passed to `coord_map` per call, bounding peak
        memory for very large images.

    Returns
    -------
    ndarray
        Shape (2, rows, cols) or (3, rows, cols, channels); axis 0 holds row
        coordinates, axis 1 column coordinates, as expected by
        `scipy.ndimage.map_coordinates`.
    """
    rows, cols = shape[0], shape[1]
    coords_shape = [len(shape), rows, cols]
    # Only multichannel (3-D) outputs carry a channel axis; appending
    # shape[2] unconditionally would raise IndexError for 2-D shapes.
    if len(shape) == 3:
        coords_shape.append(shape[2])
    coords = np.empty(coords_shape, dtype=dtype)

    # (col, row) pairs for every output pixel, transformed batch-by-batch.
    tf_coords = np.indices((cols, rows), dtype=dtype).reshape(2, -1).T
    for i in range(0, (tf_coords.shape[0] // batch_size + 1)):
        tf_coords[batch_size * i:batch_size * (i + 1)] = coord_map(tf_coords[batch_size * i:batch_size * (i + 1)])
    tf_coords = tf_coords.T.reshape((-1, cols, rows)).swapaxes(1, 2)

    # map_coordinates expects (row, col) ordering, so swap axes on copy.
    _stackcopy(coords[1, ...], tf_coords[0, ...])
    _stackcopy(coords[0, ...], tf_coords[1, ...])
    if len(shape) == 3:
        coords[2, ...] = range(shape[2])

    return coords
39 | 47 |
40 | 48 |
def affine_registration(params, moving, fixed, metric='mae'):
    """Residual vector of the dissimilarity between `fixed` and the affinely warped `moving`.

    Parameters
    ----------
    params : array-like, shape (6,)
        First two rows of the 3x3 affine matrix, flattened row-major; the
        third row stays (0, 0, 1).
    moving, fixed : 2-D ndarray
        Grayscale images of identical shape.
    metric : str
        'mse' -> squared intensity differences;
        'mae' -> raw intensity differences (note: `least_squares` squares the
        residuals, so this effectively optimizes a sum-of-squares cost);
        'lcc' -> 1 minus the local (9x9 window) squared correlation coefficient.

    Returns
    -------
    ndarray
        Flattened residuals for `scipy.optimize.least_squares`.

    Raises
    ------
    ValueError
        If `metric` is not one of 'mse', 'mae', 'lcc' (previously this fell
        through every branch and raised UnboundLocalError on `err`).
    """
    tmat = np.eye(3)
    tmat[0, :] = params.take([0, 1, 2])
    tmat[1, :] = params.take([3, 4, 5])

    trans = ProjectiveTransform(matrix=tmat)
    warped_coords = warp_coords_batch(trans, fixed.shape)
    t = map_coordinates(moving, warped_coords, mode='nearest')
    f = fixed

    if metric == 'mse':
        err = (t - f) ** 2

    elif metric == 'mae':
        err = (t - f)

    elif metric == 'lcc':
        # Local correlation computed with 9x9 box filters (win_size = 81 px).
        sum_filt = np.ones((9, 9))
        win_size = 81

        t_sum = convolve2d(t, sum_filt, mode='same', boundary='symm')
        f_sum = convolve2d(f, sum_filt, mode='same', boundary='symm')
        t2_sum = convolve2d(t * t, sum_filt, mode='same', boundary='symm')
        f2_sum = convolve2d(f * f, sum_filt, mode='same', boundary='symm')
        tf_sum = convolve2d(t * f, sum_filt, mode='same', boundary='symm')

        cross = tf_sum - f_sum * t_sum / win_size
        t_var = t2_sum - t_sum * t_sum / win_size
        f_var = f2_sum - f_sum * f_sum / win_size
        # Small epsilon avoids division by zero in flat (zero-variance) regions.
        cc = cross * cross / (t_var * f_var + 1e-5)
        err = 1 - cc

    else:
        raise ValueError("unsupported metric: %r (expected 'mse', 'mae' or 'lcc')" % metric)

    return err.flatten()
53 | 82 |
54 | 83 |
def read_img_as_gray(fn):
    """Read an image file as a single-channel float image scaled by its maximum.

    RGB input is converted to grayscale and RGBA is first flattened to RGB;
    any other multichannel layout is rejected.

    Parameters
    ----------
    fn : str
        Path of the image file to read.

    Returns
    -------
    ndarray
        2-D float image divided by its maximum intensity.

    Raises
    ------
    ValueError
        For unsupported image layouts, or an all-zero image (dividing by a
        zero maximum would silently produce NaNs). Validation previously used
        `assert`, which is stripped under `python -O`.
    """
    im = skimage.io.imread(fn)
    n_dims = im.ndim
    if n_dims not in (2, 3):
        raise ValueError('this tool does not support multichannel images')
    if n_dims == 3:
        if im.shape[-1] not in (3, 4):
            raise ValueError('this tool does not support multichannel images')
        if im.shape[-1] == 4:
            im = rgba2rgb(im)
        im = rgb2gray(im)
    im = im.astype(float)
    vmax = np.max(im)
    if vmax == 0:
        raise ValueError('image is entirely zero-valued')
    return im / vmax
55 | 96 |
def image_registration(fn_moving, fn_fixed, fn_out, smooth_sigma=3, metric='lcc'):
    """Estimate the affine transform registering the moving image to the fixed one.

    Both images are loaded as normalized grayscale, smoothed with a Gaussian
    of width `smooth_sigma`, and the six affine parameters are fitted by
    least-squares minimization of the chosen similarity `metric`. The
    resulting 3x3 matrix is written tab-separated to `fn_out`.
    """
    img_mov = gaussian(read_img_as_gray(fn_moving), sigma=smooth_sigma)
    img_fix = gaussian(read_img_as_gray(fn_fixed), sigma=smooth_sigma)

    # Start the optimization from the identity transform (first two rows, row-major).
    init_params = np.array([1, 0, 0, 0, 1, 0], dtype='float64')
    fit = least_squares(affine_registration, init_params, args=(img_mov, img_fix, metric))

    tmat = np.eye(3)
    tmat[0, :] = fit.x.take([0, 1, 2])
    tmat[1, :] = fit.x.take([3, 4, 5])

    pd.DataFrame(tmat).to_csv(fn_out, header=None, index=False, sep="\t")
71 | |
72 | 113 |
73 | 114 |
if __name__ == "__main__":
    # Command-line entry point: estimate and save an affine transformation matrix.
    parser = argparse.ArgumentParser(description="Estimate the transformation matrix")
    parser.add_argument("fn_moving", help="Path to the moving image")
    parser.add_argument("fn_fixed", help="Path to the fixed (reference) image")
    parser.add_argument("fn_tmat", help="Path to the output file for saving the transformation matrix")
    parser.add_argument("sigma", type=float, help="Sigma of Gaussian filter for smoothing input images")
    parser.add_argument("metric", help="Image similarity metric")
    args = parser.parse_args()

    image_registration(args.fn_moving, args.fn_fixed, args.fn_tmat, args.sigma, args.metric)