comparison: concat_channels.py @ 6:999c5941a6f0 (draft, branch default, tip)
planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/concat_channels/ commit a94f04c109c545a9f892a6ce7a5ffef152253201
| author | imgteam |
|---|---|
| date | Fri, 12 Dec 2025 21:15:56 +0000 |
| parents | 8d50a0a9e4af |
| children | |
| 5:8d50a0a9e4af | 6:999c5941a6f0 |
|---|---|
| 1 import argparse | 1 import argparse |
| | 2 from typing import Any |
| 2 | 3 |
| 3 import giatools | 4 import giatools |
| 4 import numpy as np | 5 import numpy as np |
| 5 import skimage.io | 6 import skimage.io |
| 6 import skimage.util | 7 import skimage.util |
| 7 | 8 |
| 8 | 9 |
| 9 normalized_axes = 'QTZYXC' | |
| 10 | |
| 11 | |
| 12 def concat_channels( | 10 def concat_channels( |
| 13 input_image_paths: list[str], | 11 input_image_paths: list[str], |
| 14 output_image_path: str, | 12 output_image_path: str, |
| 15 axis: str, | 13 axis: str, |
| 16 preserve_values: bool, | 14 preserve_values: bool, |
| | 15 sort_by: str \| None, |
| 17 ): | 16 ): |
| 18 # Create list of arrays to be concatenated | 17 # Create list of arrays to be concatenated |
| 19 images = [] | 18 images = list() |
| | 19 metadata = dict() |
| 20 for image_path in input_image_paths: | 20 for image_path in input_image_paths: |
| 21 | 21 |
| 22 img = giatools.Image.read(image_path, normalize_axes=normalized_axes) | 22 img = giatools.Image.read(image_path, normalize_axes=giatools.default_normalized_axes) |
| 23 arr = img.data | 23 arr = img.data |
| 24 | 24 |
| 25 # Preserve values: Convert to `float` dtype without changing the values | 25 # Preserve values: Convert to `float` dtype without changing the values |
| 26 if preserve_values: | 26 if preserve_values: |
| 27 arr = arr.astype(float) | 27 arr = arr.astype(float) |
| 28 | 28 |
| 29 # Preserve brightness: Scale values to 0..1 | 29 # Preserve brightness: Scale values to 0..1 |
| 30 else: | 30 else: |
| 31 arr = skimage.util.img_as_float(arr) | 31 arr = skimage.util.img_as_float(arr) |
| 32 | 32 |
| | 33 # Record the metadata |
| | 34 for metadata_key, metadata_value in img.metadata.items(): |
| | 35 metadata.setdefault(metadata_key, list()) |
| | 36 metadata[metadata_key].append(metadata_value) |
| | 37 |
| | 38 # Record the image data |
| 33 images.append(arr) | 39 images.append(arr) |
| 34 | 40 |
| | 41 # Perform sorting, if requested |
| | 42 if sort_by is not None: |
| | 43 |
| | 44 # Validate that `sort_by` is available as metadata for all images |
| | 45 sort_keys = list( |
| | 46 filter( |
| | 47 lambda value: value is not None, |
| | 48 metadata.get(sort_by, list()), |
| | 49 ), |
| | 50 ) |
| | 51 if len(sort_keys) != len(images): |
| | 52 raise ValueError( |
| | 53 f'Requested to sort by "{sort_by}", ' |
| | 54 f'but this is not available for all {len(images)} images' |
| | 55 f' (available for only {len(sort_keys)} images)' |
| | 56 ) |
| | 57 |
| | 58 # Sort images by the corresponding `sort_key` metadata value |
| | 59 sorted_indices = sorted(range(len(images)), key=lambda i: sort_keys[i]) |
| | 60 images = [images[i] for i in sorted_indices] |
| | 61 |
| | 62 # Determine consensual metadata |
| | 63 # TODO: Convert metadata of images with different units of measurement into a common unit |
| | 64 final_metadata = dict() |
| | 65 for metadata_key, metadata_values in metadata.items(): |
| | 66 if (metadata_value := reduce_metadata(metadata_values)) is not None: |
| | 67 final_metadata[metadata_key] = metadata_value |
| | 68 |
| | 69 # Update the `z_spacing` metadata, if concatenating along the Z-axis and `z_position` is available for all images |
| | 70 if axis == 'Z' and len(images) >= 2 and len(z_positions := metadata.get('z_position', list())) == len(images): |
| | 71 z_positions = sorted(z_positions) # don't mutate the `metadata` dictionary for easier future code maintenance |
| | 72 final_metadata['z_spacing'] = abs(np.subtract(z_positions[1:], z_positions[:-1]).mean()) |
| | 73 |
| 35 # Do the concatenation | 74 # Do the concatenation |
| 36 axis_pos = normalized_axes.index(axis) | 75 axis_pos = giatools.default_normalized_axes.index(axis) |
| 37 arr = np.concatenate(images, axis_pos) | 76 arr = np.concatenate(images, axis_pos) |
| 38 res = giatools.Image(arr, normalized_axes) | 77 res = giatools.Image( |
| | 78 data=arr, |
| | 79 axes=giatools.default_normalized_axes, |
| | 80 metadata=final_metadata, |
| | 81 ) |
| 39 | 82 |
| 40 # Squeeze singleton axes and save | 83 # Squeeze singleton axes and save |
| 41 squeezed_axes = ''.join(np.array(list(res.axes))[np.array(arr.shape) > 1]) | 84 res = res.squeeze() |
| 42 res = res.squeeze_like(squeezed_axes) | 85 print('Output TIFF shape:', res.data.shape) |
| | 86 print('Output TIFF axes:', res.axes) |
| | 87 print('Output TIFF', metadata_to_str(final_metadata)) |
| 43 res.write(output_image_path, backend='tifffile') | 88 res.write(output_image_path, backend='tifffile') |
| | 89 |
| | 90 |
| | 91 def reduce_metadata(values: list[Any]) -> Any \| None: |
| | 92 non_none_values = list(filter(lambda value: value is not None, values)) |
| | 93 |
| | 94 # Reduction is not possible if more than one type is involved (or none) |
| | 95 value_types = [type(value) for value in non_none_values] |
| | 96 if len(frozenset(value_types)) != 1: |
| | 97 return None |
| | 98 else: |
| | 99 value_type = value_types[0] |
| | 100 |
| | 101 # For floating point types, reduce via arithmetic average |
| | 102 if np.issubdtype(value_type, np.floating): |
| | 103 return np.mean(non_none_values) |
| | 104 |
| | 105 # For integer types, reduce via the median |
| | 106 if np.issubdtype(value_type, np.integer): |
| | 107 return int(np.median(non_none_values)) |
| | 108 |
| | 109 # For all other types, reduction is only possible if the values are identical |
| | 110 if len(frozenset(non_none_values)) == 1: |
| | 111 return non_none_values[0] |
| | 112 else: |
| | 113 return None |
| | 114 |
| | 115 |
| | 116 def metadata_to_str(metadata: dict) -> str: |
| | 117 tokens = list() |
| | 118 for key in sorted(metadata.keys()): |
| | 119 value = metadata[key] |
| | 120 if isinstance(value, tuple): |
| | 121 value = '(' + ', '.join([f'{val}' for val in value]) + ')' |
| | 122 tokens.append(f'{key}: {value}') |
| | 123 return ', '.join(tokens) |
| 44 | 124 |
| 45 | 125 |
| 46 if __name__ == "__main__": | 126 if __name__ == "__main__": |
| 47 parser = argparse.ArgumentParser() | 127 parser = argparse.ArgumentParser() |
| 48 parser.add_argument('input_files', type=str, nargs='+') | 128 parser.add_argument('input_files', type=str, nargs='+') |
| 49 parser.add_argument('out_file', type=str) | 129 parser.add_argument('out_file', type=str) |
| 50 parser.add_argument('axis', type=str) | 130 parser.add_argument('axis', type=str) |
| 51 parser.add_argument('--preserve_values', default=False, action='store_true') | 131 parser.add_argument('--preserve_values', default=False, action='store_true') |
| | 132 parser.add_argument('--sort_by', type=str, default=None) |
| 52 args = parser.parse_args() | 133 args = parser.parse_args() |
| 53 | 134 |
| 54 concat_channels( | 135 concat_channels( |
| 55 args.input_files, | 136 args.input_files, |
| 56 args.out_file, | 137 args.out_file, |
| 57 args.axis, | 138 args.axis, |
| 58 args.preserve_values, | 139 args.preserve_values, |
| | 140 args.sort_by, |
| 59 ) | 141 ) |
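
The `reduce_metadata` helper added in this revision determines the consensus metadata written to the output image: floating-point values are averaged, integer values are reduced via the median, and any other type is kept only when all inputs agree on it. Below is a minimal, standalone sketch mirroring that behaviour; the example values are hypothetical and serve purely as illustration.

```python
# Standalone sketch mirroring the consensus rules of `reduce_metadata` above;
# the example values at the bottom are made up for illustration only.
import numpy as np


def reduce_metadata(values):
    non_none_values = [value for value in values if value is not None]

    # Reduction is not possible if more than one type is involved (or none)
    value_types = frozenset(type(value) for value in non_none_values)
    if len(value_types) != 1:
        return None
    value_type = next(iter(value_types))

    # Floating-point values are averaged, integer values reduced via the median
    if np.issubdtype(value_type, np.floating):
        return np.mean(non_none_values)
    if np.issubdtype(value_type, np.integer):
        return int(np.median(non_none_values))

    # All other types are kept only if the values are identical across inputs
    return non_none_values[0] if len(frozenset(non_none_values)) == 1 else None


print(reduce_metadata([0.5, 0.75]))           # 0.625 (arithmetic average)
print(reduce_metadata([3, 5, 5]))             # 5 (median)
print(reduce_metadata(['micron', 'micron']))  # micron (identical values survive)
print(reduce_metadata(['micron', 'nm']))      # None (conflicting values are dropped)
```

Along the same lines, when concatenating along the Z-axis and every input carries a `z_position`, the new code derives the output `z_spacing` as the mean difference between consecutive sorted `z_position` values.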
