comparison points2label.py @ 6:22bb32eae6a1 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/points2labelimage/ commit edac062b00490276ef00d094e0594abdb3a3f23c
author imgteam
date Thu, 06 Nov 2025 09:59:34 +0000
parents 4a49f74a3c14
children
comparison
equal deleted inserted replaced
5:4a49f74a3c14 6:22bb32eae6a1
1 import argparse 1 import argparse
2 import json 2 import json
3 import os
4 import warnings 3 import warnings
5 from typing import ( 4 from typing import (
5 Any,
6 Dict, 6 Dict,
7 List, 7 Optional,
8 Tuple, 8 Tuple,
9 Union,
10 ) 9 )
11 10
12 import giatools.pandas 11 import giatools.pandas
13 import numpy as np 12 import numpy as np
14 import numpy.typing as npt 13 import numpy.typing as npt
15 import pandas as pd 14 import pandas as pd
16 import scipy.ndimage as ndi 15 import scipy.ndimage as ndi
16 import skimage.draw
17 import skimage.io 17 import skimage.io
18 import skimage.segmentation 18 import skimage.segmentation
19 19
20 20
21 def is_rectangular(points: Union[List[Tuple[float, float]], npt.NDArray]) -> bool: 21 def get_list_depth(nested_list: Any) -> int:
22 points = np.asarray(points) 22 if isinstance(nested_list, list):
23 23 if len(nested_list) > 0:
24 # Rectangle must have 5 points, where first and last are identical 24 return 1 + max(map(get_list_depth, nested_list))
25 if len(points) != 5 or not (points[0] == points[-1]).all(): 25 else:
26 return False 26 return 1
27 27 else:
28 # Check that all edges align with the axes 28 return 0
29 edges = points[1:] - points[:-1] 29
30 if any((edge == 0).sum() != 1 for edge in edges): 30
31 return False 31 class AutoLabel:
32 32 """
33 # All checks have passed, the geometry is rectangular 33 Creates a sequence of unique labels (non-negative values).
34 return True 34 """
35 35
36 36 def __init__(self, reserved_labels):
37 def geojson_to_tabular(geojson: Dict): 37 self.reserved_labels = reserved_labels
38 rows = [] 38 self.next_autolabel = 0
39 labels = [] 39
40 def next(self):
41 """
42 Retrieve the next auto-label (post-increment).
43 """
44 # Fast-forward `next_autolabel` to the first free label
45 while self.next_autolabel in self.reserved_labels:
46 self.next_autolabel += 1
47
48 # Return the free label, then advance `next_autolabel`
49 try:
50 return self.next_autolabel
51 finally:
52 self.next_autolabel += 1
53
54
55 def get_feature_label(feature: Dict) -> Optional[int]:
56 """
57 Get the label of a GeoJSON feature, or `None` if there is no proper label.
58 """
59 label = feature.get('properties', {}).get('name', None)
60 if label is None:
61 return None
62
63 # If the `label` is given as a string, try to parse as integer
64 if isinstance(label, str):
65 try:
66 label = int(label)
67 except ValueError:
68 pass
69
70 # Finally, if `label` is an integer, only use it if it is non-negative
71 if isinstance(label, int) and label >= 0:
72 return label
73 else:
74 return None
75
76
77 def rasterize(
78 geojson: Dict,
79 shape: Tuple[int, int],
80 bg_value: int = 0,
81 fg_value: Optional[int] = None,
82 ) -> npt.NDArray:
83 """
84 Rasterize GeoJSON into a pixel image, that is returned as a NumPy array.
85 """
86
87 # Determine which labels are reserved (not used by auto-label)
88 reserved_labels = [bg_value]
89 if fg_value is None:
90 for feature in geojson['features']:
91 label = get_feature_label(feature)
92 if label is not None:
93 reserved_labels.append(label)
94
95 # Convert `reserved_labels` into a `set` for faster look-ups
96 reserved_labels = frozenset(reserved_labels)
97
98 # Define routine to retrieve the next auto-label
99 autolabel = AutoLabel(reserved_labels)
100
101 # Rasterize the image
102 img = np.full(shape, dtype=np.uint16, fill_value=bg_value)
40 for feature in geojson['features']: 103 for feature in geojson['features']:
41 assert feature['geometry']['type'].lower() == 'polygon', ( 104 geom_type = feature['geometry']['type'].lower()
42 f'Unsupported geometry type: "{feature["geometry"]["type"]}"' 105 coords = feature['geometry']['coordinates']
43 ) 106
44 coords = feature['geometry']['coordinates'][0] 107 # Rasterize a `mask` separately for each feature
45 108 if geom_type == 'polygon':
46 # Properties and name (label) are optional 109
47 try: 110 # Normalization: Let there always be a list of polygons
48 label = feature['properties']['name'] 111 if get_list_depth(coords) == 2:
112 coords = [coords]
113
114 # Rasterize each polygon separately, then join via XOR
115 mask = np.zeros(shape, dtype=bool)
116 for polygon_coords in coords:
117 polygon_mask = skimage.draw.polygon2mask(
118 shape,
119 [point[::-1] for point in polygon_coords],
120 )
121 mask = np.logical_xor(mask, polygon_mask)
122
123 elif geom_type == 'point':
124 mask = np.zeros(shape, dtype=bool)
125 mask[coords[1], coords[0]] = True
126 radius = feature.get('properties', {}).get('radius', 0)
127 if radius > 0:
128 mask = (ndi.distance_transform_edt(~mask) <= radius)
129
130 else:
131 raise ValueError(
132 f'Unsupported geometry type: "{feature["geometry"]["type"]}"',
133 )
134
135 # Determine the `label` for the current `mask`
136 if fg_value is None:
137 label = get_feature_label(feature)
138 if label is None:
139 label = autolabel.next()
140 else:
141 label = fg_value
142
143 # Blend the current `mask` with the rasterized image
144 img[mask] = label
145
146 # Return the rasterized image
147 return img
148
149
150 def convert_tabular_to_geojson(
151 tabular_file: str,
152 has_header: bool,
153 ) -> dict:
154 """
155 Read a tabular file and convert it to GeoJSON.
156
157 The GeoJSON data is returned as a dictionary.
158 """
159
160 # Read the tabular file with information from the header
161 if has_header:
162 df = pd.read_csv(tabular_file, delimiter='\t')
163
164 pos_x_column = giatools.pandas.find_column(df, ['pos_x', 'POS_X'])
165 pos_y_column = giatools.pandas.find_column(df, ['pos_y', 'POS_Y'])
166 pos_x_list = df[pos_x_column].round().astype(int)
167 pos_y_list = df[pos_y_column].round().astype(int)
168 assert len(pos_x_list) == len(pos_y_list)
169
170 try:
171 radius_column = giatools.pandas.find_column(df, ['radius', 'RADIUS'])
172 radius_list = df[radius_column]
173 assert len(pos_x_list) == len(radius_list)
49 except KeyError: 174 except KeyError:
50 label = max(labels, default=0) + 1 175 radius_list = [0] * len(pos_x_list)
51 labels.append(label) 176
52 177 try:
53 # Read geometry 178 width_column = giatools.pandas.find_column(df, ['width', 'WIDTH'])
54 xs = [pt[0] for pt in coords] 179 height_column = giatools.pandas.find_column(df, ['height', 'HEIGHT'])
55 ys = [pt[1] for pt in coords] 180 width_list = df[width_column]
56 181 height_list = df[height_column]
57 x = min(xs) 182 assert len(pos_x_list) == len(width_list)
58 y = min(ys) 183 assert len(pos_x_list) == len(height_list)
59 184 except KeyError:
60 width = max(xs) + 1 - x 185 width_list = [0] * len(pos_x_list)
61 height = max(ys) + 1 - y 186 height_list = [0] * len(pos_x_list)
62 187
63 # Validate geometry (must be rectangular) 188 try:
64 assert is_rectangular(list(zip(xs, ys))) 189 label_column = giatools.pandas.find_column(df, ['label', 'LABEL'])
65 190 label_list = df[label_column]
66 # Append the rectangle 191 assert len(pos_x_list) == len(label_list)
67 rows.append({ 192 except KeyError:
68 'pos_x': x,
69 'pos_y': y,
70 'width': width,
71 'height': height,
72 'label': label,
73 })
74 df = pd.DataFrame(rows)
75 point_file = './point_file.tabular'
76 df.to_csv(point_file, sep='\t', index=False)
77 return point_file
78
79
80 def rasterize(point_file, out_file, shape, has_header=False, swap_xy=False, bg_value=0, fg_value=None):
81
82 img = np.full(shape, dtype=np.uint16, fill_value=bg_value)
83 if os.path.exists(point_file) and os.path.getsize(point_file) > 0:
84
85 # Read the tabular file with information from the header
86 if has_header:
87 df = pd.read_csv(point_file, delimiter='\t')
88
89 pos_x_column = giatools.pandas.find_column(df, ['pos_x', 'POS_X'])
90 pos_y_column = giatools.pandas.find_column(df, ['pos_y', 'POS_Y'])
91 pos_x_list = df[pos_x_column].round().astype(int)
92 pos_y_list = df[pos_y_column].round().astype(int)
93 assert len(pos_x_list) == len(pos_y_list)
94
95 try:
96 radius_column = giatools.pandas.find_column(df, ['radius', 'RADIUS'])
97 radius_list = df[radius_column]
98 assert len(pos_x_list) == len(radius_list)
99 except KeyError:
100 radius_list = [0] * len(pos_x_list)
101
102 try:
103 width_column = giatools.pandas.find_column(df, ['width', 'WIDTH'])
104 height_column = giatools.pandas.find_column(df, ['height', 'HEIGHT'])
105 width_list = df[width_column]
106 height_list = df[height_column]
107 assert len(pos_x_list) == len(width_list)
108 assert len(pos_x_list) == len(height_list)
109 except KeyError:
110 width_list = [0] * len(pos_x_list)
111 height_list = [0] * len(pos_x_list)
112
113 try:
114 label_column = giatools.pandas.find_column(df, ['label', 'LABEL'])
115 label_list = df[label_column]
116 assert len(pos_x_list) == len(label_list)
117 except KeyError:
118 label_list = list(range(1, len(pos_x_list) + 1))
119
120 # Read the tabular file without header
121 else:
122 df = pd.read_csv(point_file, header=None, delimiter='\t')
123 pos_x_list = df[0].round().astype(int)
124 pos_y_list = df[1].round().astype(int)
125 assert len(pos_x_list) == len(pos_y_list)
126 radius_list, width_list, height_list = [[0] * len(pos_x_list)] * 3
127 label_list = list(range(1, len(pos_x_list) + 1)) 193 label_list = list(range(1, len(pos_x_list) + 1))
128 194
129 # Optionally swap the coordinates 195 # Read the tabular file without header
130 if swap_xy: 196 else:
131 pos_x_list, pos_y_list = pos_y_list, pos_x_list 197 df = pd.read_csv(tabular_file, header=None, delimiter='\t')
132 198 pos_x_list = df[0].round().astype(int)
133 # Perform the rasterization 199 pos_y_list = df[1].round().astype(int)
134 for y, x, radius, width, height, label in zip( 200 assert len(pos_x_list) == len(pos_y_list)
135 pos_y_list, pos_x_list, radius_list, width_list, height_list, label_list, 201 radius_list, width_list, height_list = [[0] * len(pos_x_list)] * 3
136 ): 202 label_list = list(range(1, len(pos_x_list) + 1))
137 if fg_value is not None: 203
138 label = fg_value 204 # Convert to GeoJSON
139 205 features = []
140 if y < 0 or x < 0 or y >= shape[0] or x >= shape[1]: 206 geojson = {
141 raise IndexError(f'The point x={x}, y={y} exceeds the bounds of the image (width: {shape[1]}, height: {shape[0]})') 207 'type': 'FeatureCollection',
142 208 'features': features,
143 # Rasterize circle and distribute overlapping image area 209 }
144 # Rasterize primitive geometry 210 for y, x, radius, width, height, label in zip(
145 if radius > 0 or (width > 0 and height > 0): 211 pos_y_list, pos_x_list, radius_list, width_list, height_list, label_list,
146 212 ):
147 # Rasterize circle 213 if radius > 0 and width > 0 and height > 0:
148 if radius > 0: 214 raise ValueError('Ambiguous shape type (circle or rectangle)')
149 mask = np.ones(shape, dtype=bool) 215
150 mask[y, x] = False 216 # Create a rectangle
151 mask = (ndi.distance_transform_edt(mask) <= radius) 217 if width > 0 and height > 0:
152 else: 218 geom_type = 'Polygon'
153 mask = np.zeros(shape, dtype=bool) 219 coords = [
154 220 [x, y],
155 # Rasterize rectangle 221 [x + width - 1, y],
156 if width > 0 and height > 0: 222 [x + width - 1, y + height - 1],
157 mask[ 223 [x, y + height - 1],
158 y:min(shape[0], y + width), 224 ]
159 x:min(shape[1], x + height) 225
160 ] = True 226 # Create a point or circle
161 227 else:
162 # Compute the overlap (pretend there is none if the rasterization is binary) 228 geom_type = 'Point'
163 if fg_value is None: 229 coords = [x, y]
164 overlap = np.logical_and(img > 0, mask) 230
165 else: 231 # Create a GeoJSON feature
166 overlap = np.zeros(shape, dtype=bool) 232 feature = {
167 233 'type': 'Feature',
168 # Rasterize the part of the circle which is disjoint from other foreground. 234 'geometry': {
169 # 235 'type': geom_type,
170 # In the current implementation, the result depends on the order of the rasterized circles if somewhere 236 'coordinates': coords,
171 # more than two circles overlap. This is probably negligable for most applications. To achieve results 237 },
172 # that are invariant to the order, first all circles would need to be rasterized independently, and 238 'properties': {
173 # then blended together. This, however, would either strongly increase the memory consumption, or 239 'name': label,
174 # require a more complex implementation which exploits the sparsity of the rasterized masks. 240 },
175 # 241 }
176 disjoint_mask = np.logical_xor(mask, overlap) 242 if radius > 0:
177 if disjoint_mask.any(): 243 feature['properties']['radius'] = radius
178 img[disjoint_mask] = label 244 feature['properties']['subType'] = 'Circle'
179 245 features.append(feature)
180 # Distribute the remaining part of the circle 246
181 if overlap.any(): 247 # Return the GeoJSON object
182 dist = ndi.distance_transform_edt(overlap) 248 return geojson
183 foreground = (img > 0)
184 img[overlap] = 0
185 img = skimage.segmentation.watershed(dist, img, mask=foreground)
186
187 # Rasterize point (there is no overlapping area to be distributed)
188 else:
189 img[y, x] = label
190
191 else:
192 raise Exception('{} is empty or does not exist.'.format(point_file)) # appropriate built-in error?
193
194 with warnings.catch_warnings():
195 warnings.simplefilter("ignore")
196 skimage.io.imsave(out_file, img, plugin='tifffile') # otherwise we get problems with the .dat extension
197 249
198 250
199 if __name__ == '__main__': 251 if __name__ == '__main__':
200 parser = argparse.ArgumentParser() 252 parser = argparse.ArgumentParser()
201 parser.add_argument('in_file', type=argparse.FileType('r'), help='Input point file or GeoJSON file') 253 parser.add_argument('in_ext', type=str, help='Input file format')
202 parser.add_argument('out_file', type=str, help='out file (TIFF)') 254 parser.add_argument('in_file', type=str, help='Input file path (tabular or GeoJSON)')
203 parser.add_argument('shapex', type=int, help='shapex') 255 parser.add_argument('out_file', type=str, help='Output file path (TIFF)')
204 parser.add_argument('shapey', type=int, help='shapey') 256 parser.add_argument('shapex', type=int, help='Output image width')
205 parser.add_argument('--has_header', dest='has_header', default=False, help='set True if point file has header') 257 parser.add_argument('shapey', type=int, help='Output image height')
258 parser.add_argument('--has_header', dest='has_header', default=False, help='Set True if tabular file has a header')
206 parser.add_argument('--swap_xy', dest='swap_xy', default=False, help='Swap X and Y coordinates') 259 parser.add_argument('--swap_xy', dest='swap_xy', default=False, help='Swap X and Y coordinates')
207 parser.add_argument('--binary', dest='binary', default=False, help='Produce binary image') 260 parser.add_argument('--binary', dest='binary', default=False, help='Produce binary image')
208
209 args = parser.parse_args() 261 args = parser.parse_args()
210 262
211 point_file = args.in_file.name 263 # Validate command-line arguments
212 has_header = args.has_header 264 assert args.in_ext in ('tabular', 'geojson'), (
213 265 f'Unexpected input file format: {args.in_ext}'
214 try: 266 )
215 with open(args.in_file.name, 'r') as f: 267
216 content = json.load(f) 268 # Load the GeoJSON data (if the input file is tabular, convert to GeoJSON)
217 if isinstance(content, dict) and content.get('type') == 'FeatureCollection' and isinstance(content.get('features'), list): 269 if args.in_ext == 'tabular':
218 point_file = geojson_to_tabular(content) 270 geojson = convert_tabular_to_geojson(args.in_file, args.has_header)
219 has_header = True # header included in the converted file 271 else:
220 else: 272 with open(args.in_file) as f:
221 raise ValueError('Input is a JSON file but not a valid GeoJSON file') 273 geojson = json.load(f)
222 except json.JSONDecodeError: 274
223 print('Input is not a valid JSON file. Assuming it a tabular file.') 275 # Rasterize the image from GeoJSON
224 276 shape = (args.shapey, args.shapex)
225 rasterize( 277 img = rasterize(
226 point_file, 278 geojson,
227 args.out_file, 279 shape if not args.swap_xy else shape[::-1],
228 (args.shapey, args.shapex),
229 has_header=has_header,
230 swap_xy=args.swap_xy,
231 fg_value=0xffff if args.binary else None, 280 fg_value=0xffff if args.binary else None,
232 ) 281 )
282 if args.swap_xy:
283 img = img.T
284
285 # Write the rasterized image as TIFF
286 with warnings.catch_warnings():
287 warnings.simplefilter('ignore')
288 skimage.io.imsave(args.out_file, img, plugin='tifffile') # otherwise we get problems with the .dat extension