diff cp_segmentation.py @ 0:1e7334a51725 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/cellpose commit 06dd9637975e3b9d6d27a3d5a773c85e9a52baf2
author bgruening
date Thu, 29 Feb 2024 22:07:26 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/cp_segmentation.py	Thu Feb 29 22:07:26 2024 +0000
@@ -0,0 +1,85 @@
+import argparse
+import json
+import os
+import warnings
+
+import matplotlib.pyplot as plt
+import numpy as np
+import skimage.io
+from cellpose import models, plot, transforms
+
+
+def main(inputs, img_path, img_format, output_dir):
+    """
+    Parameter
+    ---------
+    inputs : str
+        File path to galaxy tool parameter
+    img_path : str
+        File path for the input image
+    img_format : str
+        One of the ['ome.tiff', 'tiff', 'png', 'jpg']
+    output_dir : str
+        Folder to save the outputs.
+    """
+    warnings.simplefilter('ignore')
+
+    with open(inputs, 'r') as param_handler:
+        params = json.load(param_handler)
+
+    gpu = params['use_gpu']
+    model_type = params['model_type']
+    chan = params['chan']
+    chan2 = params['chan2']
+    chan_first = params['chan_first']
+    if chan is None:
+        channels = None
+    else:
+        channels = [int(chan), int(chan2) if chan2 is not None else None]
+
+    options = params['options']
+
+    img = skimage.io.imread(img_path)
+
+    print(f"Image shape: {img.shape}")
+    # transpose to Ly x Lx x nchann and reshape based on channels
+    if img_format.endswith('tiff'):
+        img = np.transpose(img, (1, 2, 0))
+        img = transforms.reshape(img, channels=channels, chan_first=chan_first)
+
+    print(f"Image shape: {img.shape}")
+    model = models.Cellpose(gpu=gpu, model_type=model_type)
+    masks, flows, styles, diams = model.eval(img, channels=channels, **options)
+
+    # save masks to tiff
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        skimage.io.imsave(os.path.join(output_dir, 'cp_masks.tif'),
+                          masks.astype(np.uint16))
+
+    # make segmentation show #
+    if params['show_segmentation']:
+        img = skimage.io.imread(img_path)
+        # uniform image
+        if img_format.endswith('tiff'):
+            img = np.transpose(img, (1, 2, 0))
+            img = transforms.reshape(img, channels=channels, chan_first=chan_first)
+
+        maski = masks
+        flowi = flows[0]
+        fig = plt.figure(figsize=(12, 3))
+        # can save images (set save_dir=None if not)
+        plot.show_segmentation(fig, img, maski, flowi, channels=channels)
+        fig.savefig(os.path.join(output_dir, 'segm_show.png'), dpi=300)
+        plt.close(fig)
+
+
+if __name__ == '__main__':
+    aparser = argparse.ArgumentParser()
+    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
+    aparser.add_argument("-p", "--img_path", dest="img_path")
+    aparser.add_argument("-f", "--img_format", dest="img_format")
+    aparser.add_argument("-O", "--output_dir", dest="output_dir")
+    args = aparser.parse_args()
+
+    main(args.inputs, args.img_path, args.img_format, args.output_dir)