diff auto_threshold.py @ 10:2ee04d2ebdcf draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/2d_auto_threshold/ commit 71f7ecabba78de48147d4a5e6ea380b6b70b16e8
author imgteam
date Sat, 03 Jan 2026 14:43:10 +0000
parents 50fa6150e340
children
line wrap: on
line diff
--- a/auto_threshold.py	Sat Jun 07 18:38:31 2025 +0000
+++ b/auto_threshold.py	Sat Jan 03 14:43:10 2026 +0000
@@ -1,45 +1,47 @@
 """
-Copyright 2017-2024 Biomedical Computer Vision Group, Heidelberg University.
+Copyright 2017-2025 Biomedical Computer Vision Group, Heidelberg University.
 
 Distributed under the MIT license.
 See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
 """
 
-import argparse
-
+import giatools
 import numpy as np
 import skimage.filters
 import skimage.util
-from giatools.image import Image
+
+# Fail early if an optional backend is not available
+giatools.require_backend('omezarr')
 
 
 class DefaultThresholdingMethod:
 
-    def __init__(self, thres, accept: list[str] | None = None, **kwargs):
+    def __init__(self, thres, **kwargs):
         self.thres = thres
-        self.accept = accept if accept else []
         self.kwargs = kwargs
 
     def __call__(self, image, *args, offset=0, **kwargs):
-        accepted_kwargs = self.kwargs.copy()
-        for key, val in kwargs.items():
-            if key in self.accept:
-                accepted_kwargs[key] = val
-        thres = self.thres(image, *args, **accepted_kwargs)
+        thres = self.thres(image, *args, **(self.kwargs | kwargs))
         return image > thres + offset
 
+    def __str__(self):
+        return self.thres.__name__
+
 
 class ManualThresholding:
 
-    def __call__(self, image, thres1: float, thres2: float | None, **kwargs):
-        if thres2 is None:
-            return image > thres1
+    def __call__(self, image, threshold1: float, threshold2: float | None, **kwargs):
+        if threshold2 is None:
+            return image > threshold1
         else:
-            thres1, thres2 = sorted((thres1, thres2))
-            return skimage.filters.apply_hysteresis_threshold(image, thres1, thres2)
+            threshold1, threshold2 = sorted((threshold1, threshold2))
+            return skimage.filters.apply_hysteresis_threshold(image, threshold1, threshold2)
+
+    def __str__(self):
+        return 'Manual'
 
 
-th_methods = {
+methods = {
     'manual': ManualThresholding(),
 
     'otsu': DefaultThresholdingMethod(skimage.filters.threshold_otsu),
@@ -47,71 +49,37 @@
     'yen': DefaultThresholdingMethod(skimage.filters.threshold_yen),
     'isodata': DefaultThresholdingMethod(skimage.filters.threshold_isodata),
 
-    'loc_gaussian': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='gaussian'),
-    'loc_median': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='median'),
-    'loc_mean': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='mean'),
+    'loc_gaussian': DefaultThresholdingMethod(skimage.filters.threshold_local, method='gaussian'),
+    'loc_median': DefaultThresholdingMethod(skimage.filters.threshold_local, method='median'),
+    'loc_mean': DefaultThresholdingMethod(skimage.filters.threshold_local, method='mean'),
 }
 
 
-def do_thresholding(
-    input_filepath: str,
-    output_filepath: str,
-    th_method: str,
-    block_size: int,
-    offset: float,
-    threshold1: float,
-    threshold2: float | None,
-    invert_output: bool,
-):
-    assert th_method in th_methods, f'Unknown method "{th_method}"'
+if __name__ == "__main__":
+    tool = giatools.ToolBaseplate()
+    tool.add_input_image('input')
+    tool.add_output_image('output')
+    tool.parse_args()
 
-    # Load image
-    img_in = Image.read(input_filepath)
+    # Retrieve general parameters
+    method = tool.args.params.pop('method')
+    invert = tool.args.params.pop('invert')
 
     # Perform thresholding
-    result = th_methods[th_method](
-        image=img_in.data,
-        block_size=block_size,
-        offset=offset,
-        thres1=threshold1,
-        thres2=threshold2,
-    )
-    if invert_output:
-        result = np.logical_not(result)
-
-    # Convert to canonical representation for binary images
-    result = (result * 255).astype(np.uint8)
-
-    # Write result
-    Image(
-        data=skimage.util.img_as_ubyte(result),
-        axes=img_in.axes,
-    ).normalize_axes_like(
-        img_in.original_axes,
-    ).write(
-        output_filepath,
+    method_impl = methods[method]
+    print(
+        'Thresholding:',
+        str(method_impl),
+        'with',
+        ', '.join(
+            f'{key}={repr(value)}' for key, value in (tool.args.params | dict(invert=invert)).items()
+        ),
     )
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Automatic image thresholding')
-    parser.add_argument('input', type=str, help='Path to the input image')
-    parser.add_argument('output', type=str, help='Path to the output image (uint8)')
-    parser.add_argument('th_method', choices=th_methods.keys(), help='Thresholding method')
-    parser.add_argument('block_size', type=int, help='Odd size of pixel neighborhood for calculating the threshold')
-    parser.add_argument('offset', type=float, help='Offset of automatically determined threshold value')
-    parser.add_argument('threshold1', type=float, help='Manual threshold value')
-    parser.add_argument('--threshold2', type=float, help='Second manual threshold value (for hysteresis thresholding)')
-    parser.add_argument('--invert_output', default=False, action='store_true', help='Values below/above the threshold are labeled with 0/255 by default, and with 255/0 if this argument is used')
-    args = parser.parse_args()
-
-    do_thresholding(
-        args.input,
-        args.output,
-        args.th_method,
-        args.block_size,
-        args.offset,
-        args.threshold1,
-        args.threshold2,
-        args.invert_output,
-    )
+    for section in tool.run('ZYX', output_dtype_hint='binary'):
+        section_output = method_impl(
+            image=np.asarray(section['input'].data),  # some implementations have issues with Dask arrays
+            **tool.args.params,
+        )
+        if invert:
+            section_output = np.logical_not(section_output)
+        section['output'] = section_output