comparison scale_image.py @ 7:e85846f4a05f draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/scale_image/ commit cd908933bd7bd8756c213af57ea6343a90effc12
author imgteam
date Sat, 13 Dec 2025 22:11:29 +0000
parents 85666e555698
children
comparison
equal deleted inserted replaced
6:72b8a6b7661b 7:e85846f4a05f
1 import argparse 1 import argparse
2 import json
2 import sys 3 import sys
4 from typing import (
5 Any,
6 Literal,
7 )
3 8
4 import giatools.io 9 import giatools.io
5 import numpy as np 10 import numpy as np
6 import skimage.io 11 import skimage.io
7 import skimage.transform 12 import skimage.transform
8 import skimage.util 13 import skimage.util
9 from PIL import Image 14
10 15
def get_uniform_scale(
    img: giatools.Image,
    axes: Literal['all', 'spatial'],
    factor: float,
) -> tuple[float, ...]:
    """
    Determine a tuple of `scale` factors for uniform or spatially uniform scaling.

    One factor is returned for each axis in `img.axes`, skipping the channel axis ``C``. With
    ``axes='all'`` every such axis receives `factor`; with ``axes='spatial'`` only the ``Y``, ``X``,
    and ``Z`` axes do. Axes that are not present in the original image data (``img.original_axes``)
    keep a factor of 1, and so do singleton axes when down-scaling (``factor < 1``).

    Raises ``ValueError`` if `axes` is neither ``'all'`` nor ``'spatial'``.
    """
    # Axes that must keep a factor of 1, regardless of the `axes` selection
    ignored = {
        axis
        for idx, axis in enumerate(img.axes)
        if axis not in img.original_axes or (factor < 1 and img.data.shape[idx] == 1)
    }

    if axes == 'all':
        spatial_only = False
    elif axes == 'spatial':
        spatial_only = True
    else:
        raise ValueError(f'Unknown axes for uniform scaling: "{axes}"')

    factors = []
    for axis in img.axes:
        if axis == 'C':
            continue  # no factor is given for the channel axis
        selected = not spatial_only or axis in 'YXZ'
        factors.append(factor if selected and axis not in ignored else 1)
    return tuple(factors)
52
53
def get_scale_for_isotropy(
    img: giatools.Image,
    sample: Literal['up', 'down'],
) -> tuple[float, ...]:
    """
    Determine a tuple of `scale` factors to establish spatial isotropy.

    The `sample` parameter governs whether to up-sample or down-sample the image data:
    with ``'up'`` the pixel/voxel extents are normalized by the smallest extent (factors >= 1), with
    ``'down'`` by the largest extent (factors <= 1). Only the X and Y factors (and the Z factor, if the
    Z axis has more than one slice) differ from 1. The returned tuple has one entry fewer than
    `img.axes` (the channel axis is omitted).

    The pixel size is taken as the reciprocal of ``img.metadata['resolution']``, unpacked in (X, Y)
    order; the voxel depth is taken from ``img.metadata['z_spacing']``. The program exits if the
    resolution is missing, or if the voxel depth is missing for 3-D data. Raises ``ValueError`` for an
    unknown `sample` value.
    """
    # One factor per axis; the channel axis is omitted
    factors = [1] * (len(img.axes) - 1)
    pos = {axis: img.axes.index(axis) for axis in 'ZYX'}

    if 'resolution' not in img.metadata:
        sys.exit('Resolution information missing in image metadata')
    pixel_size = np.divide(1, img.metadata['resolution'])

    def extents_to_factors(extents: np.ndarray) -> list:
        # Normalize all extents by the one that is to be preserved
        if sample == 'up':
            reference = extents.min()
        elif sample == 'down':
            reference = extents.max()
        else:
            raise ValueError(f'Unknown value for sample: "{sample}"')
        return (extents / reference).tolist()

    if img.data.shape[pos['Z']] <= 1:
        # 2-D data: only the X and Y factors are adjusted
        factors[pos['X']], factors[pos['Y']] = extents_to_factors(np.array(pixel_size))
    else:
        # 3-D data: the X, Y, and Z factors are adjusted jointly
        voxel_depth = img.metadata.get('z_spacing', None)
        if voxel_depth is None:
            sys.exit('Voxel depth information missing in image metadata')
        factors[pos['X']], factors[pos['Y']], factors[pos['Z']] = extents_to_factors(
            np.array([*pixel_size, voxel_depth]),
        )

    return tuple(factors)
109
110
def get_aa_sigma_by_scale(scale: float) -> float:
    """
    Determine the optimal size of the Gaussian filter for anti-aliasing.

    For a down-scaling factor (``scale < 1``) the standard deviation ``(1 / scale - 1) / 2`` is
    returned, following the scikit-image documentation linked below. For any other factor no smoothing
    is required and 0 is returned.

    See for details: https://scikit-image.org/docs/0.25.x/api/skimage.transform.html#skimage.transform.rescale
    """
    if scale < 1:
        return (1 / scale - 1) / 2
    return 0
118
119
def get_new_metadata(
    old: giatools.Image,
    scale: float | tuple[float, ...],
    arr: np.ndarray,
) -> dict[str, Any]:
    """
    Determine the result metadata (copy and adapt).

    The metadata of `old` is shallow-copied, and its ``resolution`` and ``z_spacing`` entries are
    updated to reflect the scaling. The metadata of `old` itself is not modified.

    - `scale` is either a single factor (any scalar, e.g. ``2`` or ``0.5``), which is applied to every
      axis except the channel axis, or one factor per axis of `old.axes` with the channel axis omitted.
    - The pixel size (reciprocal of ``resolution``, in (X, Y) order) is divided by the X and Y factors.
      If ``resolution`` is missing, a pixel size of 1 is assumed.
    - ``z_spacing`` is divided by the Z factor. If it is missing, a voxel depth of 1 is assumed.

    The `arr` argument (the re-sampled image data) is currently not used.
    """
    metadata = dict(old.metadata)

    # Broadcast a scalar factor (Python or NumPy, integer or floating-point) to all non-channel axes.
    # Checking for `float` only would send an integer scalar down the sequence path and fail.
    if np.ndim(scale) == 0:
        scales = [scale] * (len(old.axes) - 1)
    else:
        scales = scale

    # NOTE(review): `scales` is indexed with positions from `old.axes`, which assumes that the omitted
    # channel axis comes after X, Y, and Z in `old.axes` — confirm against giatools' axes order.

    # Determine the original pixel size
    old_pixel_size = (
        np.divide(1, old.metadata['resolution'])
        if 'resolution' in old.metadata else (1, 1)
    )

    # Determine the new pixel size and update metadata
    new_pixel_size = np.divide(
        old_pixel_size,
        (
            scales[old.axes.index('X')],
            scales[old.axes.index('Y')],
        ),
    )
    metadata['resolution'] = tuple(1 / new_pixel_size)

    # Update the metadata for the new voxel depth
    old_voxel_depth = old.metadata.get('z_spacing', 1)
    metadata['z_spacing'] = old_voxel_depth / scales[old.axes.index('Z')]

    return metadata
155
156
def metadata_to_str(metadata: dict) -> str:
    """
    Render a metadata dictionary as a single human-readable line.

    Entries are listed as ``key: value`` in sorted key order and separated by ``', '``. Tuple values
    are shown as their comma-separated elements in parentheses. An empty dictionary yields
    ``'has no metadata'``.
    """
    def render(value) -> str:
        if isinstance(value, tuple):
            return '(' + ', '.join(f'{item}' for item in value) + ')'
        return f'{value}'

    text = ', '.join(f'{key}: {render(metadata[key])}' for key in sorted(metadata))
    return text if text else 'has no metadata'
168
169
def write_output(filepath: str, img: giatools.Image):
    """
    Validate that the output file format is suitable for the image data, then write it.

    The shape, axes, and metadata of `img` are reported on stdout first. A ``.png`` target
    (case-insensitive) is only accepted if every axis of `img` is one of ``Y``, ``X``, and ``C``;
    otherwise the program exits with an error message. The image is written via ``img.write``.
    """
    print('Output shape:', img.data.shape)
    print('Output axes:', img.axes)
    print('Output', metadata_to_str(img.metadata))

    is_png = filepath.lower().endswith('.png')
    if is_png and not set(img.axes).issubset('YXC'):
        sys.exit(f'Cannot write PNG file with axes "{img.axes}"')

    img.write(filepath)
185
186
def scale_image(
    input_filepath: str,
    output_filepath: str,
    mode: Literal['uniform', 'explicit', 'isotropy'],
    order: int,
    anti_alias: bool,
    **cfg,
):
    """
    Read an image, re-sample it by scale factors determined according to `mode`, and write the result.

    Modes:
    - ``'uniform'``: factors from `get_uniform_scale`, using ``cfg['axes']`` and ``cfg['factor']``.
    - ``'explicit'``: one factor per non-channel axis from ``cfg['factor_<axis>']`` (lower-case axis
      letter), defaulting to 1.
    - ``'isotropy'``: factors from `get_scale_for_isotropy`, using ``cfg['sample']``.

    `order` is passed to `skimage.transform.rescale`. Anti-aliasing is only enabled if `anti_alias` is
    set and at least one factor is below 1. The result keeps the input dtype (integer data is rounded
    first), gets adapted metadata, and is squeezed before being written.

    Raises ``ValueError`` for an unknown `mode`.
    """
    img = giatools.Image.read(input_filepath)
    print('Input axes:', img.original_axes)
    print('Input', metadata_to_str(img.metadata))

    # Determine `scale` (one factor per axis, except for the channel axis)
    if mode == 'uniform':
        scale = get_uniform_scale(img, cfg['axes'], cfg['factor'])
    elif mode == 'explicit':
        scale = tuple(
            cfg.get(f'factor_{axis.lower()}', 1) for axis in img.axes if axis != 'C'
        )
    elif mode == 'isotropy':
        scale = get_scale_for_isotropy(img, cfg['sample'])
    else:
        raise ValueError(f'Unknown mode: "{mode}"')

    rescale_kwargs = {
        'scale': scale,
        'order': order,
        'preserve_range': True,
        'channel_axis': img.axes.index('C'),
    }

    # Anti-aliasing only matters if at least one axis is down-scaled
    anti_alias = anti_alias and (np.array(scale) < 1).any()
    if anti_alias:
        rescale_kwargs['anti_aliasing'] = anti_alias
        # `skimage.transform.rescale` also expects a sigma for the channel axis, given as 0 here.
        # NOTE(review): appending it last assumes the channel axis is the last axis of `img.axes`.
        rescale_kwargs['anti_aliasing_sigma'] = (*(get_aa_sigma_by_scale(s) for s in scale), 0)
    else:
        rescale_kwargs['anti_aliasing'] = False

    # Report the parameters, then re-sample the image data to perform the scaling
    for name, setting in rescale_kwargs.items():
        print(f'{name}: {setting}')
    arr = skimage.transform.rescale(img.data, **rescale_kwargs)

    # Preserve the `dtype` so that both brightness and range of values is preserved
    source_dtype = img.data.dtype
    if arr.dtype != source_dtype:
        if np.issubdtype(source_dtype, np.integer):
            arr = arr.round()
        arr = arr.astype(source_dtype)

    # Determine the result metadata and save result
    metadata = get_new_metadata(img, scale, arr)
    result = giatools.Image(data=arr, axes=img.axes, metadata=metadata)
    write_output(output_filepath, result.squeeze())
40 254
41 255
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str, help='Path of the input image')
    parser.add_argument('output', type=str, help='Path of the output image')
    parser.add_argument('params', type=str, help='Path of the JSON parameters file')
    args = parser.parse_args()

    # Read the config file. It must supply the keyword arguments `mode`, `order`, and `anti_alias`
    # of `scale_image`, plus any mode-specific settings. JSON text is UTF-8 by specification, so it is
    # decoded explicitly rather than with the platform's locale encoding.
    with open(args.params, encoding='utf-8') as cfgf:
        cfg = json.load(cfgf)

    # Perform scaling
    scale_image(
        args.input,
        args.output,
        **cfg,
    )