comparison run-segmetrics.py @ 0:0729657d9e4e draft

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tools/segmetrics/ commit 3b911df716a7b42115c6cd773f666bc90a2bb10f
author imgteam
date Fri, 07 Oct 2022 22:05:59 +0000
parents
children c90b52773d2e
comparison
equal deleted inserted replaced
-1:000000000000 0:0729657d9e4e
1 """
2 Copyright 2022 Leonid Kostrykin, Biomedical Computer Vision Group, Heidelberg University.
3
4 Distributed under the MIT license.
5 See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
6
7 """
8
9 import argparse
10 import csv
11 import itertools
12 import pathlib
13 import tempfile
14 import zipfile
15
16 import numpy as np
17 import segmetrics as sm
18 import skimage.io
19
20
# Registry of all supported performance measures.
# Each entry is a 3-tuple:
#   [0] CLI key     — suffix for the `-measure-<key>` command-line flag,
#   [1] display name — column header used in the results CSV and as the
#                      measure name registered with the segmetrics Study,
#   [2] measure object — the segmetrics measure instance to evaluate.
# Regional measures compare label overlap, boundary measures compare contour
# distances ('sym'/'e2a'/'a2e' select the Hausdorff distance direction;
# the 'o_'-prefixed variants are the object-based versions), and detection
# measures count split/merge/spurious/missing objects.
measures = [
    ('dice', 'Dice', sm.regional.Dice()),
    ('seg', 'SEG', sm.regional.ISBIScore()),
    ('jc', 'Jaccard coefficient', sm.regional.JaccardSimilarityIndex()),
    ('ji', 'Jaccard index', sm.regional.JaccardIndex()),
    ('ri', 'Rand index', sm.regional.RandIndex()),
    ('ari', 'Adjusted Rand index', sm.regional.AdjustedRandIndex()),
    ('hsd_sym', 'HSD (sym)', sm.boundary.Hausdorff('sym')),
    ('hsd_e2a', 'HSD (e2a)', sm.boundary.Hausdorff('e2a')),
    ('hsd_a2e', 'HSD (a2e)', sm.boundary.Hausdorff('a2e')),
    ('nsd', 'NSD', sm.boundary.NSD()),
    ('o_hsd_sym', 'Ob. HSD (sym)', sm.boundary.ObjectBasedDistance(sm.boundary.Hausdorff('sym'))),
    ('o_hsd_e2a', 'Ob. HSD (e2a)', sm.boundary.ObjectBasedDistance(sm.boundary.Hausdorff('e2a'))),
    ('o_hsd_a2e', 'Ob. HSD (a2e)', sm.boundary.ObjectBasedDistance(sm.boundary.Hausdorff('a2e'))),
    ('o_nsd', 'Ob. NSD', sm.boundary.ObjectBasedDistance(sm.boundary.NSD())),
    ('fs', 'Split', sm.detection.FalseSplit()),
    ('fm', 'Merge', sm.detection.FalseMerge()),
    ('fp', 'Spurious', sm.detection.FalsePositive()),
    ('fn', 'Missing', sm.detection.FalseNegative()),
]
41
42
def process_batch(study, gt_filelist, seg_filelist, namelist, gt_is_unique, seg_is_unique):
    """Feed each (ground truth, segmentation) image pair into the study.

    The three lists are iterated in lockstep: for every position, the ground
    truth image is loaded and set as the expected result, then the matching
    segmentation image is processed under the chunk id given in `namelist`.
    The `*_is_unique` flags are forwarded to segmetrics and indicate whether
    the label images contain uniquely labeled objects.
    """
    for gt_file, seg_file, chunk_name in zip(gt_filelist, seg_filelist, namelist):
        ref_image = skimage.io.imread(gt_file)
        seg_image = skimage.io.imread(seg_file)
        study.set_expected(ref_image, unique=gt_is_unique)
        study.process(seg_image, unique=seg_is_unique, chunk_id=chunk_name)
49
50
def aggregate(measure, values):
    """Reduce a sequence of per-image values to a single score.

    Accumulative measures (e.g. object counts) are summed; all others are
    averaged.
    """
    if measure.ACCUMULATIVE:
        return np.sum(values)
    return np.mean(values)
54
55
def is_zip_filepath(filepath):
    """Return True if the path names a ZIP archive (case-insensitive extension check)."""
    normalized = filepath.lower()
    return normalized.endswith('.zip')
58
59
def is_image_filepath(filepath):
    """Return True if the path names a supported image file (PNG or TIFF).

    The extension check is case-insensitive.
    """
    # str.endswith accepts a tuple of suffixes — one call instead of a
    # manual any()-over-generator loop.
    return filepath.lower().endswith(('.png', '.tif', '.tiff'))
63
64
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Image segmentation and object detection performance measures for 2-D image data')
    parser.add_argument('input_seg', help='Path to the segmented image or image archive (ZIP)')
    parser.add_argument('input_gt', help='Path to the ground truth image or image archive (ZIP)')
    parser.add_argument('results', help='Path to the results file (CSV)')
    parser.add_argument('-unzip', action='store_true')
    parser.add_argument('-seg_unique', action='store_true')
    parser.add_argument('-gt_unique', action='store_true')
    for measure in measures:
        # argparse turns `-measure-<key>` into the attribute `measure_<key>`
        # (dashes become underscores in the destination name).
        parser.add_argument(f'-measure-{measure[0]}', action='store_true', help=f'Include {measure[1]}')

    args = parser.parse_args()
    study = sm.study.Study()

    # Register only the measures requested on the command line.
    used_measures = []
    for measure in measures:
        if getattr(args, f'measure_{measure[0]}'):
            used_measures.append(measure)
            study.add_measure(measure[2], measure[1])

    if args.unzip:
        # Batch mode: both inputs are ZIP archives. Pairs are formed by
        # identical file paths inside the two archives; only image files
        # present in both are evaluated.
        # FIX: the original opened both ZipFile objects without ever closing
        # them — use context managers so the file handles are released.
        with zipfile.ZipFile(args.input_seg) as zipfile_seg, zipfile.ZipFile(args.input_gt) as zipfile_gt:
            namelist = [filepath for filepath in zipfile_seg.namelist() if is_image_filepath(filepath) and filepath in zipfile_gt.namelist()]
            print('namelist:', namelist)
            with tempfile.TemporaryDirectory() as tmpdir:
                basepath = pathlib.Path(tmpdir)
                gt_path, seg_path = basepath / 'gt', basepath / 'seg'
                zipfile_seg.extractall(str(seg_path))
                zipfile_gt.extractall(str(gt_path))
                gt_filelist, seg_filelist = list(), list()
                for filepath in namelist:
                    seg_filelist.append(str(seg_path / filepath))
                    gt_filelist.append(str(gt_path / filepath))
                # Process inside the `with` blocks so the extracted files
                # still exist while the images are read.
                process_batch(study, gt_filelist, seg_filelist, namelist, args.gt_unique, args.seg_unique)

    else:
        # Single-image mode: each input is one image file; the single chunk
        # gets an empty name.
        namelist = ['']
        process_batch(study, [args.input_gt], [args.input_seg], namelist, args.gt_unique, args.seg_unique)

    # Header row: empty corner cell, then one column per selected measure.
    rows = [[''] + [measure[1] for measure in used_measures]]

    # Per-file rows — only emitted in batch mode (a single image run gets
    # just the header and the aggregate footer).
    if len(namelist) > 1:
        for chunk_id in namelist:
            row = [chunk_id]
            for measure in used_measures:
                measure_name = measure[1]
                measure = study.measures[measure_name]
                chunks = study.results[measure_name]
                row += [aggregate(measure, chunks[chunk_id])]
            rows.append(row)

    # Footer row: each measure aggregated over the values of all chunks.
    rows.append([''])
    for measure in used_measures:
        measure_name = measure[1]
        measure = study.measures[measure_name]
        chunks = study.results[measure_name]
        values = list(itertools.chain(*[chunks[chunk_id] for chunk_id in chunks]))
        val = aggregate(measure, values)
        rows[-1].append(val)

    # Write the results as a semicolon-delimited CSV ('|' as quote char).
    with open(args.results, 'w', newline='') as fout:
        csv_writer = csv.writer(fout, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        for row in rows:
            csv_writer.writerow(row)