comparison image_registration_affine.py @ 1:fa769715b6b0 draft

"planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tools/image_registration_affine/ commit e82400162e337b36c29d6e79fb2deb9871475397"
author imgteam
date Thu, 20 Jan 2022 00:45:30 +0000
parents e34222a620d4
children 77dc68af2b40
comparison
equal deleted inserted replaced
0:e34222a620d4 1:fa769715b6b0
1 """
2 Copyright 2021-2022 Biomedical Computer Vision Group, Heidelberg University.
3
4 Distributed under the MIT license.
5 See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
6
7 """
8 import argparse
9
10 import numpy as np
11 import pandas as pd
1 import skimage.io 12 import skimage.io
2 from skimage.transform import ProjectiveTransform
3 from skimage.filters import gaussian
4 from scipy.ndimage import map_coordinates 13 from scipy.ndimage import map_coordinates
5 from scipy.optimize import least_squares 14 from scipy.optimize import least_squares
6 import numpy as np 15 from scipy.signal import convolve2d
7 import pandas as pd 16 from skimage.color import rgb2gray, rgba2rgb
8 import argparse 17 from skimage.filters import gaussian
9 18 from skimage.transform import ProjectiveTransform
10 19
11 20
def _stackcopy(a, b):
    """Copy ``b`` into ``a`` in place.

    When ``a`` is 3-D, the 2-D array ``b`` is broadcast along ``a``'s last
    axis, so every plane ``a[:, :, k]`` receives the same values.
    """
    if a.ndim == 3:
        a[:] = b[:, :, np.newaxis]
    else:
        a[:] = b


def warp_coords_batch(coord_map, shape, dtype=np.float64, batch_size=1000000):
    """Build a coordinate array for ``scipy.ndimage.map_coordinates``.

    Every output pixel of an image with shape ``shape`` is passed through
    ``coord_map`` in batches of at most ``batch_size`` points. This bounds the
    size of each call.

    Parameters
    ----------
    coord_map : callable
        Maps an ``(N, 2)`` array of ``(x, y)`` = ``(col, row)`` pairs to an
        array of the same shape.
    shape : tuple of int
        Output shape, either ``(rows, cols)`` or ``(rows, cols, channels)``.
    dtype : numpy dtype, optional
        Dtype of the returned coordinates.
    batch_size : int, optional
        Maximum number of points handed to ``coord_map`` per call.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(shape),) + shape``:

        - ``coords[0]`` holds the mapped row coordinates.
        - ``coords[1]`` holds the mapped column coordinates.
        - For 3-D shapes, ``coords[2]`` holds the channel index.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        # A zero or negative step would otherwise crash or silently skip the mapping.
        raise ValueError("batch_size must be a positive integer")

    rows, cols = shape[0], shape[1]
    coords_shape = [len(shape), rows, cols]
    if len(shape) == 3:
        coords_shape.append(shape[2])
    coords = np.empty(coords_shape, dtype=dtype)

    # One (x, y) = (col, row) pair per pixel, in column-major pixel order.
    tf_coords = np.indices((cols, rows), dtype=dtype).reshape(2, -1).T

    # Step through the points exactly; unlike `n // batch_size + 1` iterations,
    # this never calls coord_map on an empty trailing slice.
    for start in range(0, tf_coords.shape[0], batch_size):
        stop = start + batch_size
        tf_coords[start:stop] = coord_map(tf_coords[start:stop])
    tf_coords = tf_coords.T.reshape((-1, cols, rows)).swapaxes(1, 2)

    # map_coordinates expects (row, col) order, so x goes to axis 1 and y to axis 0.
    _stackcopy(coords[1, ...], tf_coords[0, ...])
    _stackcopy(coords[0, ...], tf_coords[1, ...])
    if len(shape) == 3:
        coords[2, ...] = range(shape[2])

    return coords
39 47
40 48
def affine_registration(params, moving, fixed, metric='mae', lcc_window=9):
    """Residual function for ``scipy.optimize.least_squares``.

    The six parameters define an affine matrix that maps fixed-image
    coordinates ``(x, y)`` to moving-image coordinates. ``moving`` is
    resampled on the grid of ``fixed`` through that matrix. The per-pixel
    dissimilarity between the warped and fixed images is returned.

    Parameters
    ----------
    params : numpy.ndarray
        Six values. ``params[0:3]`` and ``params[3:6]`` become the first two
        rows of a 3x3 homogeneous matrix whose last row is ``[0, 0, 1]``.
    moving : numpy.ndarray
        Image that is resampled under the transformation.
    fixed : numpy.ndarray
        Reference image; its shape defines the sampling grid.
    metric : {'mse', 'mae', 'lcc'}, optional
        Residual definition:

        - ``'mse'`` returns squared differences. ``least_squares`` squares
          the residuals again, so the effective cost is a sum of fourth
          powers.
        - ``'mae'`` returns signed differences. Under ``least_squares``'
          default linear loss, this yields a sum-of-squared-differences
          objective, not a mean absolute error.
        - ``'lcc'`` returns ``1 - r**2``, where ``r`` is the local normalized
          cross-correlation over ``lcc_window`` x ``lcc_window`` windows.
          This metric uses ``convolve2d``, so it requires 2-D images.
    lcc_window : int, optional
        Side length of the square window used by the ``'lcc'`` metric.

    Returns
    -------
    numpy.ndarray
        Flattened per-pixel residuals.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    # Validate before the expensive warp. Previously an unknown metric
    # surfaced as an UnboundLocalError on `err`.
    if metric not in ('mse', 'mae', 'lcc'):
        raise ValueError(
            "unknown metric {!r}; expected 'mse', 'mae' or 'lcc'".format(metric))

    tmat = np.eye(3)
    tmat[0, :] = params.take([0, 1, 2])
    tmat[1, :] = params.take([3, 4, 5])

    trans = ProjectiveTransform(matrix=tmat)
    warped_coords = warp_coords_batch(trans, fixed.shape)
    t = map_coordinates(moving, warped_coords, mode='nearest')
    f = fixed

    if metric == 'mse':
        err = (t - f) ** 2
    elif metric == 'mae':
        err = t - f
    else:  # metric == 'lcc'
        # Box filter for local window sums; win_size is derived from the
        # filter so the two can never disagree.
        sum_filt = np.ones((lcc_window, lcc_window))
        win_size = sum_filt.size

        # 'symm' mirrors the image at the border instead of zero-padding.
        t_sum = convolve2d(t, sum_filt, mode='same', boundary='symm')
        f_sum = convolve2d(f, sum_filt, mode='same', boundary='symm')
        t2_sum = convolve2d(t * t, sum_filt, mode='same', boundary='symm')
        f2_sum = convolve2d(f * f, sum_filt, mode='same', boundary='symm')
        tf_sum = convolve2d(t * f, sum_filt, mode='same', boundary='symm')

        cross = tf_sum - f_sum * t_sum / win_size
        t_var = t2_sum - t_sum * t_sum / win_size
        f_var = f2_sum - f_sum * f_sum / win_size
        # Squared local NCC; the 1e-5 term avoids division by zero in flat regions.
        cc = cross * cross / (t_var * f_var + 1e-5)
        err = 1 - cc

    return err.flatten()
53 82
54 83
def read_img_as_gray(fn):
    """Load an image file as a 2-D float grayscale array normalized by its maximum.

    RGBA images are first converted to RGB, and RGB images to grayscale.
    2-D images are used as-is. The result is divided by its maximum value,
    unless that maximum is 0, in which case the zero image is returned
    unscaled.

    Parameters
    ----------
    fn : str
        Path of the image to read.

    Returns
    -------
    numpy.ndarray
        2-D float array.

    Raises
    ------
    ValueError
        If the image is not 2-D, or not 3-D with 3 (RGB) or 4 (RGBA) channels.
    """
    im = skimage.io.imread(fn)
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if im.ndim not in (2, 3) or (im.ndim == 3 and im.shape[-1] not in (3, 4)):
        raise ValueError('this tool does not support multichannel images')
    if im.ndim == 3:
        if im.shape[-1] == 4:
            im = rgba2rgb(im)
        im = rgb2gray(im)
    im = im.astype(float)
    peak = np.max(im)
    # An all-zero image would otherwise become all-NaN (0 / 0) and poison the optimizer.
    if peak != 0:
        im = im / peak
    return im
55 96
def image_registration(fn_moving, fn_fixed, fn_out, smooth_sigma=3, metric='lcc'):
    """Estimate the affine matrix aligning ``fn_moving`` to ``fn_fixed`` and save it.

    Both images are read as grayscale and smoothed with a Gaussian of standard
    deviation ``smooth_sigma``. The six affine parameters are then fitted with
    ``least_squares``, starting from the identity transform. The residuals
    come from ``affine_registration`` using ``metric``. The resulting 3x3
    matrix is written to ``fn_out`` as tab-separated values, with no header
    and no index.
    """
    moving, fixed = (
        gaussian(read_img_as_gray(path), sigma=smooth_sigma)
        for path in (fn_moving, fn_fixed)
    )

    # Identity transform: rows [1, 0, 0] and [0, 1, 0].
    identity_guess = np.array([1, 0, 0, 0, 1, 0], dtype='float64')
    fit = least_squares(affine_registration, identity_guess, args=(moving, fixed, metric))

    # Stack the two fitted rows on top of the homogeneous row [0, 0, 1].
    homogeneous = np.vstack([fit.x.reshape(2, 3), [0.0, 0.0, 1.0]])

    pd.DataFrame(homogeneous).to_csv(fn_out, header=None, index=False, sep="\t")
71
72 113
73 114
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Estimate the transformation matrix")
    parser.add_argument("fn_moving", help="Path to the moving image")
    parser.add_argument("fn_fixed", help="Path to the fixed (reference) image")
    parser.add_argument("fn_tmat", help="Path to the output file for saving the transformation matrix")
    parser.add_argument("sigma", type=float, help="Sigma of Gaussian filter for smoothing input images")
    # Restrict to the metrics affine_registration implements, so a typo is
    # rejected here rather than failing deep inside the optimizer.
    parser.add_argument("metric", choices=["mse", "mae", "lcc"], help="Image similarity metric")
    args = parser.parse_args()

    image_registration(args.fn_moving, args.fn_fixed, args.fn_tmat, args.sigma, args.metric)