comparison concat_channels.py @ 6:999c5941a6f0 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/concat_channels/ commit a94f04c109c545a9f892a6ce7a5ffef152253201
author imgteam
date Fri, 12 Dec 2025 21:15:56 +0000
parents 8d50a0a9e4af
children
comparison
equal deleted inserted replaced
5:8d50a0a9e4af 6:999c5941a6f0
1 import argparse 1 import argparse
2 from typing import Any
2 3
3 import giatools 4 import giatools
4 import numpy as np 5 import numpy as np
5 import skimage.io 6 import skimage.io
6 import skimage.util 7 import skimage.util
7 8
8 9
9 normalized_axes = 'QTZYXC'
10
11
12 def concat_channels( 10 def concat_channels(
13 input_image_paths: list[str], 11 input_image_paths: list[str],
14 output_image_path: str, 12 output_image_path: str,
15 axis: str, 13 axis: str,
16 preserve_values: bool, 14 preserve_values: bool,
15 sort_by: str | None,
17 ): 16 ):
18 # Create list of arrays to be concatenated 17 # Create list of arrays to be concatenated
19 images = [] 18 images = list()
19 metadata = dict()
20 for image_path in input_image_paths: 20 for image_path in input_image_paths:
21 21
22 img = giatools.Image.read(image_path, normalize_axes=normalized_axes) 22 img = giatools.Image.read(image_path, normalize_axes=giatools.default_normalized_axes)
23 arr = img.data 23 arr = img.data
24 24
25 # Preserve values: Convert to `float` dtype without changing the values 25 # Preserve values: Convert to `float` dtype without changing the values
26 if preserve_values: 26 if preserve_values:
27 arr = arr.astype(float) 27 arr = arr.astype(float)
28 28
29 # Preserve brightness: Scale values to 0..1 29 # Preserve brightness: Scale values to 0..1
30 else: 30 else:
31 arr = skimage.util.img_as_float(arr) 31 arr = skimage.util.img_as_float(arr)
32 32
33 # Record the metadata
34 for metadata_key, metadata_value in img.metadata.items():
35 metadata.setdefault(metadata_key, list())
36 metadata[metadata_key].append(metadata_value)
37
38 # Record the image data
33 images.append(arr) 39 images.append(arr)
34 40
41 # Perform sorting, if requested
42 if sort_by is not None:
43
44 # Validate that `sort_by` is available as metadata for all images
45 sort_keys = list(
46 filter(
47 lambda value: value is not None,
48 metadata.get(sort_by, list()),
49 ),
50 )
51 if len(sort_keys) != len(images):
52 raise ValueError(
53 f'Requested to sort by "{sort_by}", '
54 f'but this is not available for all {len(images)} images'
55 f' (available for only {len(sort_keys)} images)'
56 )
57
58 # Sort images by the corresponding `sort_key` metadata value
59 sorted_indices = sorted(range(len(images)), key=lambda i: sort_keys[i])
60 images = [images[i] for i in sorted_indices]
61
62 # Determine consensual metadata
63 # TODO: Convert metadata of images with different units of measurement into a common unit
64 final_metadata = dict()
65 for metadata_key, metadata_values in metadata.items():
66 if (metadata_value := reduce_metadata(metadata_values)) is not None:
67 final_metadata[metadata_key] = metadata_value
68
69 # Update the `z_spacing` metadata, if concatenating along the Z-axis and `z_position` is available for all images
70 if axis == 'Z' and len(images) >= 2 and len(z_positions := metadata.get('z_position', list())) == len(images):
71 z_positions = sorted(z_positions) # don't mutate the `metadata` dictionary for easier future code maintenance
72 final_metadata['z_spacing'] = abs(np.subtract(z_positions[1:], z_positions[:-1]).mean())
73
35 # Do the concatenation 74 # Do the concatenation
36 axis_pos = normalized_axes.index(axis) 75 axis_pos = giatools.default_normalized_axes.index(axis)
37 arr = np.concatenate(images, axis_pos) 76 arr = np.concatenate(images, axis_pos)
38 res = giatools.Image(arr, normalized_axes) 77 res = giatools.Image(
78 data=arr,
79 axes=giatools.default_normalized_axes,
80 metadata=final_metadata,
81 )
39 82
40 # Squeeze singleton axes and save 83 # Squeeze singleton axes and save
41 squeezed_axes = ''.join(np.array(list(res.axes))[np.array(arr.shape) > 1]) 84 res = res.squeeze()
42 res = res.squeeze_like(squeezed_axes) 85 print('Output TIFF shape:', res.data.shape)
86 print('Output TIFF axes:', res.axes)
87 print('Output TIFF', metadata_to_str(final_metadata))
43 res.write(output_image_path, backend='tifffile') 88 res.write(output_image_path, backend='tifffile')
89
90
91 def reduce_metadata(values: list[Any]) -> Any | None:
92 non_none_values = list(filter(lambda value: value is not None, values))
93
94 # Reduction is not possible if more than one type is involved (or none)
95 value_types = [type(value) for value in non_none_values]
96 if len(frozenset(value_types)) != 1:
97 return None
98 else:
99 value_type = value_types[0]
100
101 # For floating point types, reduce via arithmetic average
102 if np.issubdtype(value_type, np.floating):
103 return np.mean(non_none_values)
104
105 # For integer types, reduce via the median
106 if np.issubdtype(value_type, np.integer):
107 return int(np.median(non_none_values))
108
109 # For all other types, reduction is only possible if the values are identical
110 if len(frozenset(non_none_values)) == 1:
111 return non_none_values[0]
112 else:
113 return None
114
115
116 def metadata_to_str(metadata: dict) -> str:
117 tokens = list()
118 for key in sorted(metadata.keys()):
119 value = metadata[key]
120 if isinstance(value, tuple):
121 value = '(' + ', '.join([f'{val}' for val in value]) + ')'
122 tokens.append(f'{key}: {value}')
123 return ', '.join(tokens)
44 124
45 125
46 if __name__ == "__main__": 126 if __name__ == "__main__":
47 parser = argparse.ArgumentParser() 127 parser = argparse.ArgumentParser()
48 parser.add_argument('input_files', type=str, nargs='+') 128 parser.add_argument('input_files', type=str, nargs='+')
49 parser.add_argument('out_file', type=str) 129 parser.add_argument('out_file', type=str)
50 parser.add_argument('axis', type=str) 130 parser.add_argument('axis', type=str)
51 parser.add_argument('--preserve_values', default=False, action='store_true') 131 parser.add_argument('--preserve_values', default=False, action='store_true')
132 parser.add_argument('--sort_by', type=str, default=None)
52 args = parser.parse_args() 133 args = parser.parse_args()
53 134
54 concat_channels( 135 concat_channels(
55 args.input_files, 136 args.input_files,
56 args.out_file, 137 args.out_file,
57 args.axis, 138 args.axis,
58 args.preserve_values, 139 args.preserve_values,
140 args.sort_by,
59 ) 141 )