Mercurial > repos > imgteam > scale_image
comparison scale_image.py @ 7:e85846f4a05f draft default tip
planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/scale_image/ commit cd908933bd7bd8756c213af57ea6343a90effc12
| author | imgteam |
|---|---|
| date | Sat, 13 Dec 2025 22:11:29 +0000 |
| parents | 85666e555698 |
| children |
comparison
equal
deleted
inserted
replaced
| 6:72b8a6b7661b | 7:e85846f4a05f |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | |
| 2 import sys | 3 import sys |
| 4 from typing import ( | |
| 5 Any, | |
| 6 Literal, | |
| 7 ) | |
| 3 | 8 |
| 4 import giatools.io | 9 import giatools.io |
| 5 import numpy as np | 10 import numpy as np |
| 6 import skimage.io | 11 import skimage.io |
| 7 import skimage.transform | 12 import skimage.transform |
| 8 import skimage.util | 13 import skimage.util |
| 9 from PIL import Image | 14 |
| 10 | 15 |
def get_uniform_scale(
    img: giatools.Image,
    axes: Literal['all', 'spatial'],
    factor: float,
) -> tuple[float, ...]:
    """
    Determine a tuple of `scale` factors for uniform or spatially uniform scaling.

    The result contains one entry per axis of `img.axes`, in that order, except
    for the channel axis `C`, which is omitted.

    An axis is left unscaled (factor 1) if it is not present in the original
    image data (`img.original_axes`), or if the image is being down-scaled
    (`factor < 1`) and the axis has only a single element.

    With `axes='all'`, every remaining axis is scaled by `factor`. With
    `axes='spatial'`, only the `X`, `Y`, and `Z` axes are scaled by `factor`.

    Raises `ValueError` if `axes` is not a known value or if `factor` is not
    strictly positive.
    """
    if axes not in ('all', 'spatial'):
        raise ValueError(f'Unknown axes for uniform scaling: "{axes}"')

    # Fail early with a clear message (the negated comparison also rejects NaN)
    if not factor > 0:
        raise ValueError(f'Scale factor must be positive, got: {factor}')

    # Axes that must keep a scale factor of 1
    ignored_axes = {
        axis for axis_idx, axis in enumerate(img.axes)
        if axis not in img.original_axes or (
            factor < 1 and img.data.shape[axis_idx] == 1
        )
    }

    # Axes that are scaled, unless they are ignored
    eligible_axes = frozenset(img.axes) if axes == 'all' else frozenset('ZYX')

    return tuple(
        factor if axis in eligible_axes and axis not in ignored_axes else 1
        for axis in img.axes if axis != 'C'
    )
| 52 | |
| 53 | |
def get_scale_for_isotropy(
    img: giatools.Image,
    sample: Literal['up', 'down'],
) -> tuple[float, ...]:
    """
    Determine a tuple of `scale` factors to establish spatial isotropy.

    The `sample` parameter governs whether to up-sample or down-sample the image data.
    For `'up'`, the voxel sizes are divided by the smallest voxel size (all factors
    are >= 1). For `'down'`, they are divided by the largest voxel size (all factors
    are <= 1).

    The result contains one entry per axis of `img.axes`, in that order, except for
    the channel axis `C`, which is omitted. Non-spatial axes get the factor 1.

    The pixel size is derived from `img.metadata['resolution']` (read as X, Y order)
    and, if the image has more than one Z slice, the voxel depth is taken from
    `img.metadata['z_spacing']`. The process is terminated via `sys.exit` if the
    required metadata is missing.

    Raises `ValueError` if `sample` is not a known value.
    """
    # The scale tuple has no entry for the channel axis, so the positions of the
    # spatial axes must be looked up among the non-channel axes only (using
    # positions from `img.axes` directly is only correct if `C` is the last axis)
    scale_axes = [axis for axis in img.axes if axis != 'C']
    scale = [1] * len(scale_axes)
    z_axis, y_axis, x_axis = (scale_axes.index(axis) for axis in 'ZYX')

    # Determine the pixel size of the image
    if 'resolution' not in img.metadata:
        sys.exit('Resolution information missing in image metadata')
    pixel_size = np.divide(1, img.metadata['resolution'])

    # Define unified transformation of pixel/voxel sizes to scale factors
    def voxel_size_to_scale(voxel_size: np.ndarray) -> list:
        if sample == 'up':
            reference = voxel_size.min()
        elif sample == 'down':
            reference = voxel_size.max()
        else:
            raise ValueError(f'Unknown value for sample: "{sample}"')
        return (voxel_size / reference).tolist()

    # Handle the 3-D case (the data array is indexed with the full axes, including `C`)
    if img.data.shape[img.axes.index('Z')] > 1:

        # Determine the voxel depth of the image
        voxel_depth = img.metadata.get('z_spacing', None)
        if voxel_depth is None:
            sys.exit('Voxel depth information missing in image metadata')

        # Determine the XYZ scale factors
        scale[x_axis], scale[y_axis], scale[z_axis] = voxel_size_to_scale(
            np.array([*pixel_size, voxel_depth]),
        )

    # Handle the 2-D case
    else:

        # Determine the XY scale factors
        scale[x_axis], scale[y_axis] = voxel_size_to_scale(
            np.array(pixel_size),
        )

    return tuple(scale)
| 109 | |
| 110 | |
def get_aa_sigma_by_scale(scale: float) -> float:
    """
    Determine the optimal size of the Gaussian filter for anti-aliasing.

    For down-scaling (`scale < 1`) the result is `(1 / scale - 1) / 2`, i.e.
    `(s - 1) / 2` for the down-sampling factor `s = 1 / scale`. Otherwise no
    smoothing is required and 0 is returned.

    Raises `ValueError` if `scale` is not strictly positive.

    See for details: https://scikit-image.org/docs/0.25.x/api/skimage.transform.html#skimage.transform.rescale
    """
    # Reject zero (division by zero), negative values (negative sigma), and NaN
    if not scale > 0:
        raise ValueError(f'Scale factor must be positive, got: {scale}')
    return (1 / scale - 1) / 2 if scale < 1 else 0
| 118 | |
| 119 | |
def get_new_metadata(
    old: giatools.Image,
    scale: float | tuple[float, ...],
    arr: np.ndarray,
) -> dict[str, Any]:
    """
    Determine the result metadata (copy and adapt).

    `scale` is either a single factor that applies to all non-channel axes, or a
    sequence with one factor per axis of `old.axes` (in that order, with the
    channel axis `C` omitted).

    The returned dictionary is a shallow copy of `old.metadata`, in which the
    `resolution` (read and written in X, Y order) and the `z_spacing` are adapted
    to the scaling. If the original metadata has no `resolution`, a pixel size of
    1 is assumed, and if it has no `z_spacing`, a voxel depth of 1 is assumed, so
    both keys are always present in the result.

    The `arr` parameter (the scaled image data) is currently not used.
    """
    metadata = dict(old.metadata)

    # The scale factors have no entry for the channel axis, so positions must be
    # looked up among the non-channel axes only
    scale_axes = [axis for axis in old.axes if axis != 'C']

    # Accept any scalar (including `int` and NumPy scalars), not just `float`
    scales = [scale] * len(scale_axes) if np.ndim(scale) == 0 else scale

    # Determine the original pixel size
    if 'resolution' in old.metadata:
        old_pixel_size = np.divide(1, old.metadata['resolution'])
    else:
        old_pixel_size = (1, 1)

    # Determine the new pixel size and update metadata
    new_pixel_size = np.divide(
        old_pixel_size,
        (
            scales[scale_axes.index('X')],
            scales[scale_axes.index('Y')],
        ),
    )
    metadata['resolution'] = tuple(float(res) for res in 1 / new_pixel_size)

    # Update the metadata for the new voxel depth
    old_voxel_depth = old.metadata.get('z_spacing', 1)
    metadata['z_spacing'] = old_voxel_depth / scales[scale_axes.index('Z')]

    return metadata
| 155 | |
| 156 | |
def metadata_to_str(metadata: dict) -> str:
    """
    Render the metadata as a single line of comma-separated `key: value` pairs.

    The pairs are sorted by key. Values that are tuples are rendered as a
    parenthesized, comma-separated list of their items. If there is nothing to
    render, the string `'has no metadata'` is returned.
    """

    def render(value):
        # Tuples are spelled out item by item; all other values are used as they are
        if isinstance(value, tuple):
            return '(' + ', '.join(f'{item}' for item in value) + ')'
        return value

    summary = ', '.join(
        f'{key}: {render(metadata[key])}' for key in sorted(metadata)
    )
    return summary if summary else 'has no metadata'
| 168 | |
| 169 | |
def write_output(filepath: str, img: giatools.Image):
    """
    Report the properties of the image, check the output format, and write the file.

    The shape, the axes, and the metadata of `img` are printed first. If `filepath`
    ends with `.png` (case-insensitive) and the image has any axes other than
    `Y`, `X`, and `C`, the process is terminated via `sys.exit` instead of writing
    the file.
    """
    print('Output shape:', img.data.shape)
    print('Output axes:', img.axes)
    print('Output', metadata_to_str(img.metadata))

    # PNG files are only written for images with no axes other than Y, X, and C
    is_png = filepath.lower().endswith('.png')
    if is_png and not set(img.axes).issubset('YXC'):
        sys.exit(f'Cannot write PNG file with axes "{img.axes}"')

    # Write image data to the output file
    img.write(filepath)
| 185 | |
| 186 | |
def scale_image(
    input_filepath: str,
    output_filepath: str,
    mode: Literal['uniform', 'explicit', 'isotropy'],
    order: int,
    anti_alias: bool,
    **cfg,
):
    """
    Read an image, re-sample it according to `mode`, and write the result.

    The scale factors (one per axis, with the channel axis `C` omitted) are
    determined depending on `mode`:

    - `'uniform'`: via `get_uniform_scale`, using `cfg['axes']` and `cfg['factor']`.
    - `'explicit'`: from the `cfg` entries `factor_<axis>` (lower-case axis name),
      where missing entries default to 1.
    - `'isotropy'`: via `get_scale_for_isotropy`, using `cfg['sample']`.

    Anti-aliasing is only applied if `anti_alias` is set and at least one of the
    scale factors is below 1. The `dtype` of the input image data is preserved.

    Raises `ValueError` if `mode` is not a known value.
    """
    img = giatools.Image.read(input_filepath)
    print('Input axes:', img.original_axes)
    print('Input', metadata_to_str(img.metadata))

    # Determine `scale` for scaling
    if mode == 'uniform':
        scale = get_uniform_scale(img, cfg['axes'], cfg['factor'])
    elif mode == 'explicit':
        non_channel_axes = [axis for axis in img.axes if axis != 'C']
        scale = tuple(
            cfg.get(f'factor_{axis.lower()}', 1) for axis in non_channel_axes
        )
    elif mode == 'isotropy':
        scale = get_scale_for_isotropy(img, cfg['sample'])
    else:
        raise ValueError(f'Unknown mode: "{mode}"')

    # Assemble remaining `rescale` parameters
    rescale_kwargs = {
        'scale': scale,
        'order': order,
        'preserve_range': True,
        'channel_axis': img.axes.index('C'),
    }
    use_anti_aliasing = anti_alias and (np.array(scale) < 1).any()
    if use_anti_aliasing:
        sigmas = [get_aa_sigma_by_scale(factor) for factor in scale]
        sigmas.append(0)  # `skimage.transform.rescale` also expects a value for the channel axis
        rescale_kwargs['anti_aliasing'] = use_anti_aliasing
        rescale_kwargs['anti_aliasing_sigma'] = tuple(sigmas)
    else:
        rescale_kwargs['anti_aliasing'] = False

    # Re-sample the image data to perform the scaling
    for name, value in rescale_kwargs.items():
        print(f'{name}: {value}')
    arr = skimage.transform.rescale(img.data, **rescale_kwargs)

    # Preserve the `dtype` so that both brightness and range of values is preserved
    source_dtype = img.data.dtype
    if arr.dtype != source_dtype:
        if np.issubdtype(source_dtype, np.integer):
            arr = arr.round()
        arr = arr.astype(source_dtype)

    # Determine the result metadata and save result
    result = giatools.Image(
        data=arr,
        axes=img.axes,
        metadata=get_new_metadata(img, scale, arr),
    )
    write_output(output_filepath, result.squeeze())
| 40 | 254 |
| 41 | 255 |
if __name__ == "__main__":
    # Command line: the input image, the output image, and a JSON file with the parameters
    parser = argparse.ArgumentParser()
    for positional in ('input', 'output', 'params'):
        parser.add_argument(positional, type=str)
    cli_args = parser.parse_args()

    # Read the config file
    with open(cli_args.params) as params_file:
        params = json.load(params_file)

    # Perform scaling (the config entries are passed as keyword arguments)
    scale_image(cli_args.input, cli_args.output, **params)
