diff image_registration_affine.py @ 1:fa769715b6b0 draft

"planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tools/image_registration_affine/ commit e82400162e337b36c29d6e79fb2deb9871475397"
author imgteam
date Thu, 20 Jan 2022 00:45:30 +0000
parents e34222a620d4
children 77dc68af2b40
line wrap: on
line diff
--- a/image_registration_affine.py	Wed Dec 30 20:24:35 2020 +0000
+++ b/image_registration_affine.py	Thu Jan 20 00:45:30 2022 +0000
@@ -1,12 +1,21 @@
+"""
+Copyright 2021-2022 Biomedical Computer Vision Group, Heidelberg University.
+
+Distributed under the MIT license.
+See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
+
+"""
+import argparse
+
+import numpy as np
+import pandas as pd
 import skimage.io
-from skimage.transform import ProjectiveTransform
-from skimage.filters import gaussian
 from scipy.ndimage import map_coordinates
 from scipy.optimize import least_squares
-import numpy as np
-import pandas as pd
-import argparse
-
+from scipy.signal import convolve2d
+from skimage.color import rgb2gray, rgba2rgb
+from skimage.filters import gaussian
+from skimage.transform import ProjectiveTransform
 
 
 def _stackcopy(a, b):
@@ -16,7 +25,6 @@
         a[:] = b
 
 
-
 def warp_coords_batch(coord_map, shape, dtype=np.float64, batch_size=1000000):
     rows, cols = shape[0], shape[1]
     coords_shape = [len(shape), rows, cols]
@@ -26,8 +34,8 @@
 
     tf_coords = np.indices((cols, rows), dtype=dtype).reshape(2, -1).T
 
-    for i in range(0, (tf_coords.shape[0]//batch_size+1)):
-        tf_coords[batch_size*i:batch_size*(i+1)] = coord_map(tf_coords[batch_size*i:batch_size*(i+1)])
+    for i in range(0, (tf_coords.shape[0] // batch_size + 1)):
+        tf_coords[batch_size * i:batch_size * (i + 1)] = coord_map(tf_coords[batch_size * i:batch_size * (i + 1)])
     tf_coords = tf_coords.T.reshape((-1, cols, rows)).swapaxes(1, 2)
 
     _stackcopy(coords[1, ...], tf_coords[0, ...])
@@ -38,45 +46,80 @@
     return coords
 
 
-
-def affine_registration(params,moving,fixed):
+def affine_registration(params, moving, fixed, metric='mae'):
     tmat = np.eye(3)
-    tmat[0,:] = params.take([0,1,2])
-    tmat[1,:] = params.take([3,4,5])
-    
+    tmat[0, :] = params.take([0, 1, 2])
+    tmat[1, :] = params.take([3, 4, 5])
+
     trans = ProjectiveTransform(matrix=tmat)
     warped_coords = warp_coords_batch(trans, fixed.shape)
-    t = map_coordinates(moving, warped_coords, mode='reflect')
-    
-    eI = (t - fixed)**2
-    return eI.flatten()
+    t = map_coordinates(moving, warped_coords, mode='nearest')
+    f = fixed
+
+    if metric == 'mse':
+        err = (t - f) ** 2
+
+    elif metric == 'mae':
+        err = (t - f)
+
+    elif metric == 'lcc':
+        sum_filt = np.ones((9, 9))
+        win_size = 81
+
+        t_sum = convolve2d(t, sum_filt, mode='same', boundary='symm')
+        f_sum = convolve2d(f, sum_filt, mode='same', boundary='symm')
+        t2_sum = convolve2d(t * t, sum_filt, mode='same', boundary='symm')
+        f2_sum = convolve2d(f * f, sum_filt, mode='same', boundary='symm')
+        tf_sum = convolve2d(t * f, sum_filt, mode='same', boundary='symm')
+
+        cross = tf_sum - f_sum * t_sum / win_size
+        t_var = t2_sum - t_sum * t_sum / win_size
+        f_var = f2_sum - f_sum * f_sum / win_size
+        cc = cross * cross / (t_var * f_var + 1e-5)
+        err = 1 - cc
+
+    return err.flatten()
 
 
+def read_img_as_gray(fn):
+    im = skimage.io.imread(fn)
+    nDims = len(im.shape)
+    assert nDims in [2, 3], 'this tool does not support multichannel images'
+    if nDims == 3:
+        assert im.shape[-1] in [3, 4], 'this tool does not support multichannel images'
+        if im.shape[-1] == 4:
+            im = rgba2rgb(im)
+        im = rgb2gray(im)
+    im = im.astype(float)
+    im = im / np.max(im)
+    return im
 
-def image_registration(fn_moving, fn_fixed, fn_out, smooth_sigma=1):
-    moving = skimage.io.imread(fn_moving,as_gray=True)
-    fixed = skimage.io.imread(fn_fixed,as_gray=True)
+
+def image_registration(fn_moving, fn_fixed, fn_out, smooth_sigma=3, metric='lcc'):
+    moving = read_img_as_gray(fn_moving)
+    fixed = read_img_as_gray(fn_fixed)
 
     moving = gaussian(moving, sigma=smooth_sigma)
     fixed = gaussian(fixed, sigma=smooth_sigma)
-    
-    x = np.array([1, 0, 0, 0, 1, 0],dtype='float64')
-    result = least_squares(affine_registration, x, args=(moving,fixed))
-        
+
+    x = np.array([1, 0, 0, 0, 1, 0], dtype='float64')
+    result = least_squares(affine_registration, x, args=(moving, fixed, metric))
+
     tmat = np.eye(3)
-    tmat[0,:] = result.x.take([0,1,2])
-    tmat[1,:] = result.x.take([3,4,5])
-    
+    tmat[0, :] = result.x.take([0, 1, 2])
+    tmat[1, :] = result.x.take([3, 4, 5])
+
     pd.DataFrame(tmat).to_csv(fn_out, header=None, index=False, sep="\t")
-    
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="Estimate the transformation matrix")
-    parser.add_argument("fn_moving", help="Name of the moving image.png")
-    parser.add_argument("fn_fixed",  help="Name of the fixed image.png")
-    parser.add_argument("fn_tmat",   help="Name of output file to save the transformation matrix")
+    parser.add_argument("fn_moving", help="Path to the moving image")
+    parser.add_argument("fn_fixed", help="Path to the fixed (reference) image")
+    parser.add_argument("fn_tmat", help="Path to the output file for saving the transformation matrix")
+    parser.add_argument("sigma", type=float, help="Sigma of Gaussian filter for smoothing input images")
+    parser.add_argument("metric", help="Image similarity metric")
     args = parser.parse_args()
-    
-    image_registration(args.fn_moving, args.fn_fixed, args.fn_tmat)
+
+    image_registration(args.fn_moving, args.fn_fixed, args.fn_tmat, args.sigma, args.metric)