comparison UNetCoreograph.py @ 0:99308601eaa6 draft

"planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
author perssond
date Wed, 19 May 2021 21:34:38 +0000
parents
children 57f1260ca94e
1 import numpy as np
2 from scipy import misc as sm
3 import shutil
4 import scipy.io as sio
5 import os
6 import skimage.exposure as sk
7 import cv2
8 import argparse
9 import pytiff
10 import tifffile
11 import tensorflow as tf
12 from skimage.morphology import *
13 from skimage.exposure import rescale_intensity
14 from skimage.segmentation import chan_vese, find_boundaries, morphological_chan_vese
15 from skimage.measure import regionprops,label, find_contours
16 from skimage.transform import resize
17 from skimage.filters import gaussian
18 from skimage.feature import peak_local_max,blob_log
19 from skimage.color import label2rgb
20 import skimage.io as skio
21 from skimage import img_as_bool
22 from skimage.draw import circle_perimeter
23 from scipy.ndimage.filters import uniform_filter
24 from scipy.ndimage import gaussian_laplace, distance_transform_edt  # distance_transform_edt is used for watershed seeding below
25 from os.path import *
26 from os import listdir, makedirs, remove
27
28
29
30 import sys
31 from typing import Any
32 import matplotlib.pyplot as plt  # used by imshowpair() below; not imported anywhere else in this file
33 #sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience')
34 from toolbox.imtools import *
35 from toolbox.ftools import *
36 from toolbox.PartitionOfImage import PI2D
37
38
39 def concat3(lst):
40 return tf.concat(lst,3)
41
42 class UNet2D:
43 hp = None # hyper-parameters
44 nn = None # network
45 tfTraining = None # if training or not (to handle batch norm)
46 tfData = None # data placeholder
47 Session = None
48 DatasetMean = 0
49 DatasetStDev = 0
50
51 def setupWithHP(hp):
52 UNet2D.setup(hp['imSize'],
53 hp['nChannels'],
54 hp['nClasses'],
55 hp['nOut0'],
56 hp['featMapsFact'],
57 hp['downSampFact'],
58 hp['ks'],
59 hp['nExtraConvs'],
60 hp['stdDev0'],
61 hp['nLayers'],
62 hp['batchSize'])
63
64 def setup(imSize,nChannels,nClasses,nOut0,featMapsFact,downSampFact,kernelSize,nExtraConvs,stdDev0,nDownSampLayers,batchSize):
65 UNet2D.hp = {'imSize':imSize,
66 'nClasses':nClasses,
67 'nChannels':nChannels,
68 'nExtraConvs':nExtraConvs,
69 'nLayers':nDownSampLayers,
70 'featMapsFact':featMapsFact,
71 'downSampFact':downSampFact,
72 'ks':kernelSize,
73 'nOut0':nOut0,
74 'stdDev0':stdDev0,
75 'batchSize':batchSize}
76
77 nOutX = [UNet2D.hp['nChannels'],UNet2D.hp['nOut0']]
78 dsfX = []
79 for i in range(UNet2D.hp['nLayers']):
80 nOutX.append(nOutX[-1]*UNet2D.hp['featMapsFact'])
81 dsfX.append(UNet2D.hp['downSampFact'])
82
83
84 # --------------------------------------------------
85 # downsampling layer
86 # --------------------------------------------------
87
88 with tf.name_scope('placeholders'):
89 UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
90 UNet2D.tfData = tf.placeholder("float", shape=[None,UNet2D.hp['imSize'],UNet2D.hp['imSize'],UNet2D.hp['nChannels']],name='data')
91
92 def down_samp_layer(data,index):
93 with tf.name_scope('ld%d' % index):
94 ldXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index+1]], stddev=stdDev0),name='kernel1')
95 ldXWeightsExtra = []
96 for i in range(nExtraConvs):
97 ldXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernelExtra%d' % i))
98
99 c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
100 for i in range(nExtraConvs):
101 c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')
102
103 ldXWeightsShortcut = tf.Variable(tf.truncated_normal([1, 1, nOutX[index], nOutX[index+1]], stddev=stdDev0),name='shortcutWeights')
104 shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')
105
106 bn = tf.layers.batch_normalization(tf.nn.relu(c00+shortcut), training=UNet2D.tfTraining)
107
108 return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1], strides=[1, dsfX[index], dsfX[index], 1], padding='SAME',name='maxpool')
109
110 # --------------------------------------------------
111 # bottom layer
112 # --------------------------------------------------
113
114 with tf.name_scope('lb'):
115 lbWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers']+1]], stddev=stdDev0),name='kernel1')
116 def lb(hidden):
117 return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),name='conv')
118
119 # --------------------------------------------------
120 # downsampling
121 # --------------------------------------------------
122
123 with tf.name_scope('downsampling'):
124 dsX = []
125 dsX.append(UNet2D.tfData)
126
127 for i in range(UNet2D.hp['nLayers']):
128 dsX.append(down_samp_layer(dsX[i],i))
129
130 b = lb(dsX[UNet2D.hp['nLayers']])
131
132 # --------------------------------------------------
133 # upsampling layer
134 # --------------------------------------------------
135
136 def up_samp_layer(data,index):
137 with tf.name_scope('lu%d' % index):
138 luXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+2]], stddev=stdDev0),name='kernel1')
139 luXWeights2 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index]+nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2')
140 luXWeightsExtra = []
141 for i in range(nExtraConvs):
142 luXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2Extra%d' % i))
143
144 outSize = UNet2D.hp['imSize']
145 for i in range(index):
146 outSize /= dsfX[i]
147 outSize = int(outSize)
148
149 outputShape = [UNet2D.hp['batchSize'],outSize,outSize,nOutX[index+1]]
150 us = tf.nn.relu(tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1], padding='SAME'),name='conv1')
151 cc = concat3([dsX[index],us])
152 cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'),name='conv2')
153 for i in range(nExtraConvs):
154 cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),name='conv2Extra%d' % i)
155 return cv
156
157 # --------------------------------------------------
158 # final (top) layer
159 # --------------------------------------------------
160
161 with tf.name_scope('lt'):
162 ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0),name='kernel')
163 def lt(hidden):
164 return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME',name='conv')
165
166
167 # --------------------------------------------------
168 # upsampling
169 # --------------------------------------------------
170
171 with tf.name_scope('upsampling'):
172 usX = []
173 usX.append(b)
174
175 for i in range(UNet2D.hp['nLayers']):
176 usX.append(up_samp_layer(usX[i],UNet2D.hp['nLayers']-1-i))
177
178 t = lt(usX[UNet2D.hp['nLayers']])
179
180
181 sm = tf.nn.softmax(t,-1)
182 UNet2D.nn = sm
183
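# Illustrative only: a minimal sketch of building the graph directly with setup().
# The hyper-parameter values below are assumptions, not the values shipped with the model;
# deployed models restore their hyper-parameters from 'hp.data' via setupWithHP() instead.
# UNet2D.setup(imSize=128, nChannels=1, nClasses=3, nOut0=16, featMapsFact=2,
#              downSampFact=2, kernelSize=3, nExtraConvs=2, stdDev0=0.03,
#              nDownSampLayers=4, batchSize=1)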
184
185 def train(imPath,logPath,modelPath,pmPath,nTrain,nValid,nTest,restoreVariables,nSteps,gpuIndex,testPMIndex):
186 os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
187
188 outLogPath = logPath
189 trainWriterPath = pathjoin(logPath,'Train')
190 validWriterPath = pathjoin(logPath,'Valid')
191 outModelPath = pathjoin(modelPath,'model.ckpt')
192 outPMPath = pmPath
193
194 batchSize = UNet2D.hp['batchSize']
195 imSize = UNet2D.hp['imSize']
196 nChannels = UNet2D.hp['nChannels']
197 nClasses = UNet2D.hp['nClasses']
198
199 # --------------------------------------------------
200 # data
201 # --------------------------------------------------
202
203 Train = np.zeros((nTrain,imSize,imSize,nChannels))
204 Valid = np.zeros((nValid,imSize,imSize,nChannels))
205 Test = np.zeros((nTest,imSize,imSize,nChannels))
206 LTrain = np.zeros((nTrain,imSize,imSize,nClasses))
207 LValid = np.zeros((nValid,imSize,imSize,nClasses))
208 LTest = np.zeros((nTest,imSize,imSize,nClasses))
209
210 print('loading data, computing mean / st dev')
211 if not os.path.exists(modelPath):
212 os.makedirs(modelPath)
213 if restoreVariables:
214 datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
215 datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
216 else:
217 datasetMean = 0
218 datasetStDev = 0
219 for iSample in range(nTrain+nValid+nTest):
220 I = im2double(tifread('%s/I%05d_Img.tif' % (imPath,iSample)))
221 datasetMean += np.mean(I)
222 datasetStDev += np.std(I)
223 datasetMean /= (nTrain+nValid+nTest)
224 datasetStDev /= (nTrain+nValid+nTest)
225 saveData(datasetMean, pathjoin(modelPath,'datasetMean.data'))
226 saveData(datasetStDev, pathjoin(modelPath,'datasetStDev.data'))
227
228 perm = np.arange(nTrain+nValid+nTest)
229 np.random.shuffle(perm)
230
231 for iSample in range(0, nTrain):
232 path = '%s/I%05d_Img.tif' % (imPath,perm[iSample])
233 im = im2double(tifread(path))
234 Train[iSample,:,:,0] = (im-datasetMean)/datasetStDev
235 path = '%s/I%05d_Ant.tif' % (imPath,perm[iSample])
236 im = tifread(path)
237 for i in range(nClasses):
238 LTrain[iSample,:,:,i] = (im == i+1)
239
240 for iSample in range(0, nValid):
241 path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+iSample])
242 im = im2double(tifread(path))
243 Valid[iSample,:,:,0] = (im-datasetMean)/datasetStDev
244 path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+iSample])
245 im = tifread(path)
246 for i in range(nClasses):
247 LValid[iSample,:,:,i] = (im == i+1)
248
249 for iSample in range(0, nTest):
250 path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+nValid+iSample])
251 im = im2double(tifread(path))
252 Test[iSample,:,:,0] = (im-datasetMean)/datasetStDev
253 path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+nValid+iSample])
254 im = tifread(path)
255 for i in range(nClasses):
256 LTest[iSample,:,:,i] = (im == i+1)
257
258 # --------------------------------------------------
259 # optimization
260 # --------------------------------------------------
261
262 tfLabels = tf.placeholder("float", shape=[None,imSize,imSize,nClasses],name='labels')
263
264 globalStep = tf.Variable(0,trainable=False)
265 learningRate0 = 0.01
266 decaySteps = 1000
267 decayRate = 0.95
268 learningRate = tf.train.exponential_decay(learningRate0,globalStep,decaySteps,decayRate,staircase=True)
269
270 with tf.name_scope('optim'):
271 loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels,tf.log(UNet2D.nn)),3))
272 updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
273 # optimizer = tf.train.MomentumOptimizer(1e-3,0.9)
274 optimizer = tf.train.MomentumOptimizer(learningRate,0.9)
275 # optimizer = tf.train.GradientDescentOptimizer(learningRate)
276 with tf.control_dependencies(updateOps):
277 optOp = optimizer.minimize(loss,global_step=globalStep)
278
279 with tf.name_scope('eval'):
280 error = []
281 for iClass in range(nClasses):
282 labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels,[0,0,0,iClass],[-1,-1,-1,1])),[batchSize,imSize,imSize])
283 predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn,3),iClass)),[batchSize,imSize,imSize])
284 correct = tf.multiply(labels0,predict0)
285 nCorrect0 = tf.reduce_sum(correct)
286 nLabels0 = tf.reduce_sum(labels0)
287 error.append(1-tf.to_float(nCorrect0)/tf.to_float(nLabels0))
288 errors = tf.tuple(error)
289
290 # --------------------------------------------------
291 # inspection
292 # --------------------------------------------------
293
294 with tf.name_scope('scalars'):
295 tf.summary.scalar('avg_cross_entropy', loss)
296 for iClass in range(nClasses):
297 tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
298 tf.summary.scalar('learning_rate', learningRate)
299 with tf.name_scope('images'):
300 split0 = tf.slice(UNet2D.nn,[0,0,0,0],[-1,-1,-1,1])
301 split1 = tf.slice(UNet2D.nn,[0,0,0,1],[-1,-1,-1,1])
302 if nClasses > 2:
303 split2 = tf.slice(UNet2D.nn,[0,0,0,2],[-1,-1,-1,1])
304 tf.summary.image('pm0',split0)
305 tf.summary.image('pm1',split1)
306 if nClasses > 2:
307 tf.summary.image('pm2',split2)
308 merged = tf.summary.merge_all()
309
310
311 # --------------------------------------------------
312 # session
313 # --------------------------------------------------
314
315 saver = tf.train.Saver()
316 sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU
317
318 if os.path.exists(outLogPath):
319 shutil.rmtree(outLogPath)
320 trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
321 validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)
322
323 if restoreVariables:
324 saver.restore(sess, outModelPath)
325 print("Model restored.")
326 else:
327 sess.run(tf.global_variables_initializer())
328
329 # --------------------------------------------------
330 # train
331 # --------------------------------------------------
332
333 batchData = np.zeros((batchSize,imSize,imSize,nChannels))
334 batchLabels = np.zeros((batchSize,imSize,imSize,nClasses))
335 for i in range(nSteps):
336 # train
337
338 perm = np.arange(nTrain)
339 np.random.shuffle(perm)
340
341 for j in range(batchSize):
342 batchData[j,:,:,:] = Train[perm[j],:,:,:]
343 batchLabels[j,:,:,:] = LTrain[perm[j],:,:,:]
344
345 summary,_ = sess.run([merged,optOp],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 1})
346 trainWriter.add_summary(summary, i)
347
348 # validation
349
350 perm = np.arange(nValid)
351 np.random.shuffle(perm)
352
353 for j in range(batchSize):
354 batchData[j,:,:,:] = Valid[perm[j],:,:,:]
355 batchLabels[j,:,:,:] = LValid[perm[j],:,:,:]
356
357 summary, es = sess.run([merged, errors],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
358 validWriter.add_summary(summary, i)
359
360 e = np.mean(es)
361 print('step %05d, e: %f' % (i,e))
362
363 if i == 0:
364 if restoreVariables:
365 lowestError = e
366 else:
367 lowestError = np.inf
368
369 if np.mod(i,100) == 0 and e < lowestError:
370 lowestError = e
371 print("Model saved in file: %s" % saver.save(sess, outModelPath))
372
373
374 # --------------------------------------------------
375 # test
376 # --------------------------------------------------
377
378 if not os.path.exists(outPMPath):
379 os.makedirs(outPMPath)
380
381 for i in range(nTest):
382 j = np.mod(i,batchSize)
383
384 batchData[j,:,:,:] = Test[i,:,:,:]
385 batchLabels[j,:,:,:] = LTest[i,:,:,:]
386
387 if j == batchSize-1 or i == nTest-1:
388
389 output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
390
391 for k in range(j+1):
392 pm = output[k,:,:,testPMIndex]
393 gt = batchLabels[k,:,:,testPMIndex]
394 im = np.sqrt(normalize(batchData[k,:,:,0]))
395 imwrite(np.uint8(255*np.concatenate((im,np.concatenate((pm,gt),axis=1)),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
396
397
398 # --------------------------------------------------
399 # save hyper-parameters, clean-up
400 # --------------------------------------------------
401
402 saveData(UNet2D.hp,pathjoin(modelPath,'hp.data'))
403
404 trainWriter.close()
405 validWriter.close()
406 sess.close()
407
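# Illustrative only: how train() would be invoked on a prepared dataset of I#####_Img.tif /
# I#####_Ant.tif pairs. All paths and counts are placeholders; __main__ below only runs inference.
# UNet2D.setup(128, 1, 3, 16, 2, 2, 3, 2, 0.03, 4, 16)
# UNet2D.train(imPath='/data/train', logPath='/data/logs', modelPath='/data/model',
#              pmPath='/data/pm', nTrain=800, nValid=100, nTest=100,
#              restoreVariables=False, nSteps=10000, gpuIndex=0, testPMIndex=2)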
408 def deploy(imPath,nImages,modelPath,pmPath,gpuIndex,pmIndex):
409 os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
410 variablesPath = pathjoin(modelPath,'model.ckpt')
411 outPMPath = pmPath
412
413 hp = loadData(pathjoin(modelPath,'hp.data'))
414 UNet2D.setupWithHP(hp)
415
416 batchSize = UNet2D.hp['batchSize']
417 imSize = UNet2D.hp['imSize']
418 nChannels = UNet2D.hp['nChannels']
419 nClasses = UNet2D.hp['nClasses']
420
421 # --------------------------------------------------
422 # data
423 # --------------------------------------------------
424
425 Data = np.zeros((nImages,imSize,imSize,nChannels))
426
427 datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
428 datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
429
430 for iSample in range(0, nImages):
431 path = '%s/I%05d_Img.tif' % (imPath,iSample)
432 im = im2double(tifread(path))
433 Data[iSample,:,:,0] = (im-datasetMean)/datasetStDev
434
435 # --------------------------------------------------
436 # session
437 # --------------------------------------------------
438
439 saver = tf.train.Saver()
440 sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU
441
442 saver.restore(sess, variablesPath)
443 print("Model restored.")
444
445 # --------------------------------------------------
446 # deploy
447 # --------------------------------------------------
448
449 batchData = np.zeros((batchSize,imSize,imSize,nChannels))
450
451 if not os.path.exists(outPMPath):
452 os.makedirs(outPMPath)
453
454 for i in range(nImages):
455 print(i,nImages)
456
457 j = np.mod(i,batchSize)
458
459 batchData[j,:,:,:] = Data[i,:,:,:]
460
461 if j == batchSize-1 or i == nImages-1:
462
463 output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
464
465 for k in range(j+1):
466 pm = output[k,:,:,pmIndex]
467 im = np.sqrt(normalize(batchData[k,:,:,0]))
468 # imwrite(np.uint8(255*np.concatenate((im,pm),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
469 imwrite(np.uint8(255*im),'%s/I%05d_Im.png' % (outPMPath,i-j+k+1))
470 imwrite(np.uint8(255*pm),'%s/I%05d_PM.png' % (outPMPath,i-j+k+1))
471
472
473 # --------------------------------------------------
474 # clean-up
475 # --------------------------------------------------
476
477 sess.close()
478
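# Illustrative only: batch deployment over a folder of pre-cropped I#####_Img.tif tiles.
# The paths and image count are placeholders.
# UNet2D.deploy(imPath='/data/tiles', nImages=50, modelPath='/data/model',
#               pmPath='/data/output', gpuIndex=0, pmIndex=1)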
479 def singleImageInferenceSetup(modelPath,gpuIndex):
480 os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
481 variablesPath = pathjoin(modelPath,'model.ckpt')
482 hp = loadData(pathjoin(modelPath,'hp.data'))
483 UNet2D.setupWithHP(hp)
484
485 UNet2D.DatasetMean =loadData(pathjoin(modelPath,'datasetMean.data'))
486 UNet2D.DatasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
487 print(UNet2D.DatasetMean)
488 print(UNet2D.DatasetStDev)
489
490 # --------------------------------------------------
491 # session
492 # --------------------------------------------------
493
494 saver = tf.train.Saver()
495 UNet2D.Session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU
496 #UNet2D.Session = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
497 saver.restore(UNet2D.Session, variablesPath)
498 print("Model restored.")
499
500 def singleImageInferenceCleanup():
501 UNet2D.Session.close()
502
503 def singleImageInference(image,mode,pmIndex):
504 print('Inference...')
505
506 batchSize = UNet2D.hp['batchSize']
507 imSize = UNet2D.hp['imSize']
508 nChannels = UNet2D.hp['nChannels']
509
510 PI2D.setup(image,imSize,int(imSize/8),mode)
511 PI2D.createOutput(nChannels)
512
513 batchData = np.zeros((batchSize,imSize,imSize,nChannels))
514 for i in range(PI2D.NumPatches):
515 j = np.mod(i,batchSize)
516 batchData[j,:,:,0] = (PI2D.getPatch(i)-UNet2D.DatasetMean)/UNet2D.DatasetStDev
517 if j == batchSize-1 or i == PI2D.NumPatches-1:
518 output = UNet2D.Session.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
519 for k in range(j+1):
520 pm = output[k,:,:,pmIndex]
521 PI2D.patchOutput(i-j+k,pm)
522 # PI2D.patchOutput(i-j+k,normalize(imgradmag(PI2D.getPatch(i-j+k),1)))
523
524 return PI2D.getValidOutput()
525
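# Illustrative only: the single-image inference sequence used by getProbMaps() below.
# 'img' stands for any 2-D grayscale array scaled to [0, 1]; the model path is a placeholder.
# UNet2D.singleImageInferenceSetup('/path/to/model', gpuIndex=0)
# probMap = UNet2D.singleImageInference(img, 'accumulate', pmIndex=1)
# UNet2D.singleImageInferenceCleanup()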
526
527 def identifyNumChan(path):
528 tiff = tifffile.TiffFile(path)
529 shape = tiff.pages[0].shape
530 numChan=None
531 for i, page in enumerate(tiff.pages):
532 if page.shape != shape:
533 numChan = i
534 return numChan
536 # else:
537 # raise Exception("Did not find any pyramid subresolutions")
538
539 if not numChan:
540 numChan = len(tiff.pages)
541 return numChan
542
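# identifyNumChan() assumes a pyramidal (OME-)TIFF whose full-resolution channel pages come
# first: the first page whose shape differs from page 0 marks the start of the sub-resolutions.
# Illustrative call (placeholder path): nChan = identifyNumChan('/data/TMA.ome.tif')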
543 def getProbMaps(I,dsFactor,modelPath):
544 hsize = int((float(I.shape[0]) * float(0.5)))
545 vsize = int((float(I.shape[1]) * float(0.5)))
546 imagesub = cv2.resize(I,(vsize,hsize),interpolation=cv2.INTER_NEAREST)
547
548 UNet2D.singleImageInferenceSetup(modelPath, 1)
549
550 for iSize in range(dsFactor):
551 hsize = int((float(I.shape[0]) * float(0.5)))
552 vsize = int((float(I.shape[1]) * float(0.5)))
553 I = cv2.resize(I,(vsize,hsize),interpolation=cv2.INTER_NEAREST)
554 I = im2double(I)
555 I = im2double(sk.rescale_intensity(I, in_range=(np.min(I), np.max(I)), out_range=(0, 0.983)))
556 probMaps = UNet2D.singleImageInference(I,'accumulate',1)
557 UNet2D.singleImageInferenceCleanup()
558 return probMaps
559
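# Illustrative only: producing a core probability map from one channel of the TMA image.
# 'I' is a 2-D grayscale array; the downsample factor and model path are placeholders.
# probMaps = getProbMaps(I, dsFactor=5, modelPath='/path/to/model')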
560 def coreSegmenterOutput(I,probMap,initialmask,preBlur,findCenter):
561 hsize = int((float(I.shape[0]) * float(0.1)))
562 vsize = int((float(I.shape[1]) * float(0.1)))
563 nucGF = cv2.resize(I,(vsize,hsize),interpolation=cv2.INTER_CUBIC)
564 # Irs = cv2.resize(I,(vsize,hsize),cv2.INTER_CUBIC)
565 # I=I.astype(np.float)
566 # r,c = I.shape
567 # I+=np.random.rand(r,c)*1e-6
568 # c1 = uniform_filter(I, 3, mode='reflect')
569 # c2 = uniform_filter(I*I, 3, mode='reflect')
570 # nucGF = np.sqrt(c2 - c1*c1)*np.sqrt(9./8)
571 # nucGF[np.isnan(nucGF)]=0
572 #active contours
573 hsize = int(float(nucGF.shape[0]))
574 vsize = int(float(nucGF.shape[1]))
575 initialmask = cv2.resize(initialmask,(vsize,hsize),interpolation=cv2.INTER_NEAREST)
576 initialmask = dilation(initialmask,disk(15)) >0
577
578 # init=np.argwhere(eroded>0)
579 nucGF = gaussian(nucGF,0.7)
580 nucGF=nucGF/np.amax(nucGF)
581
582
583 # initialmask = nucGF>0
584 nuclearMask = morphological_chan_vese(nucGF, 100, init_level_set=initialmask, smoothing=10,lambda1=1.001, lambda2=1)
585
586 # nuclearMask = chan_vese(nucGF, mu=1.5, lambda1=6, lambda2=1, tol=0.0005, max_iter=2000, dt=15, init_level_set=initialmask, extended_output=True)
587 # nuclearMask = nuclearMask[0]
588
589
590 TMAmask = nuclearMask
591 # nMaskDist =distance_transform_edt(nuclearMask)
592 # fgm = peak_local_max(h_maxima(nMaskDist, 2*preBlur),indices =False)
593 # markers= np.logical_or(erosion(1-nuclearMask,disk(3)),fgm)
594 # TMAmask=watershed(-nMaskDist,label(markers),watershed_line=True)
595 # TMAmask = nuclearMask*(TMAmask>0)
596 TMAmask = remove_small_objects(TMAmask>0, round(TMAmask.shape[0]*TMAmask.shape[1]*0.005))
597 TMAlabel = label(TMAmask)
598 # find object closest to center
599 if findCenter:
600
601 stats= regionprops(TMAlabel)
602 counter=1
603 minDistance =-1
604 index =[]
605 for props in stats:
606 centroid = props.centroid
607 distanceFromCenter = np.sqrt((centroid[0]-nucGF.shape[0]/2)**2+(centroid[1]-nucGF.shape[1]/2)**2)
608 # if distanceFromCenter<0.6/2*np.sqrt(TMAlabel.shape[0]*TMAlabel.shape[1]):
609 if distanceFromCenter<minDistance or minDistance==-1:
610 minDistance =distanceFromCenter
611 index = counter
612 counter=counter+1
613 # dist = 0.6/2*np.sqrt(TMAlabel.shape[0]*TMAlabel.shape[1])
614 TMAmask = binary_closing(TMAlabel==index,disk(3))
615
616 return TMAmask
617
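# Illustrative only: refining one core mask from its image crop and probability-map crop,
# mirroring the call made in the main loop below (coreRad is computed from the detected cores).
# TMAmask = coreSegmenterOutput(coreSlice, singleProbMap, initialmask, coreRad/20, False)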
618 def overlayOutline(outline,img):
619 img2 = img.copy()
620 stacked_img = np.stack((img2,)*3, axis=-1)
621 stacked_img[outline > 0] = [1, 0, 0]
622 imshowpair(img2,stacked_img)
623
624 def imshowpair(A,B):
625 plt.imshow(A,cmap='Purples')
626 plt.imshow(B,cmap='Greens',alpha=0.5)
627 plt.show()
628
629
630 if __name__ == '__main__':
631 parser=argparse.ArgumentParser()
632 parser.add_argument("--imagePath")
633 parser.add_argument("--outputPath")
634 parser.add_argument("--maskPath")
635 parser.add_argument("--downsampleFactor",type = int, default = 5)
636 parser.add_argument("--channel",type = int, default = 0)
637 parser.add_argument("--buffer",type = float, default = 2)
638 parser.add_argument("--outputChan", type=int, nargs = '+', default=[-1])
639 parser.add_argument("--sensitivity",type = float, default=0.3)
640 parser.add_argument("--useGrid",action='store_true')
641 parser.add_argument("--cluster",action='store_true')
642 args = parser.parse_args()
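# Illustrative command line (paths are placeholders; defaults shown explicitly):
# python UNetCoreograph.py --imagePath /data/TMA.ome.tif --outputPath /data/dearray \
#        --downsampleFactor 5 --channel 0 --buffer 2 --sensitivity 0.3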
643
644 outputPath = args.outputPath
645 imagePath = args.imagePath
646 sensitivity = args.sensitivity
647 #scriptPath = os.path.dirname(os.path.realpath(__file__))
648 #modelPath = os.path.join(scriptPath, 'TFModel - 3class 16 kernels 5ks 2 layers')
649 #modelPath = 'D:\\LSP\\Coreograph\\model-4layersMaskAug20'
650 scriptPath = os.path.dirname(os.path.realpath(__file__))
651 modelPath = os.path.join(scriptPath, 'model')
652 # outputPath = 'D:\\LSP\\cycif\\testsets\\exemplar-002\\dearrayPython' ############
653 maskOutputPath = os.path.join(outputPath, 'masks')
654 # imagePath = 'D:\\LSP\\cycif\\testsets\\exemplar-002\\registration\\exemplar-002.ome.tif'###########
655 # imagePath = 'Y:\\sorger\\data\\RareCyte\\Connor\\TMAs\\CAJ_TMA11_13\\original_data\\TMA11\\registration\\TMA11.ome.tif'
656 # imagePath = 'Y:\\sorger\\data\\RareCyte\\Connor\\TMAs\\Z124_TMA20_22\\TMA22\\registration\\TMA22.ome.tif'
657 # classProbsPath = 'D:\\unetcoreograph.tif'
658 # imagePath = 'Y:\\sorger\\data\\RareCyte\\Connor\\Z155_PTCL\\TMA_552\\registration\\TMA_552.ome.tif'
659 # classProbsPath = 'Y:\\sorger\\data\\RareCyte\\Connor\\Z155_PTCL\\TMA_552\\probMapCore\\TMA_552_CorePM_1.tif'
660 # imagePath = 'Y:\\sorger\\data\\RareCyte\\Zoltan\\Z112_TMA17_19\\190403_ashlar\\TMA17_1092.ome.tif'
661 # classProbsPath = 'Z:\\IDAC\\Clarence\\LSP\\CyCIF\\TMA\\probMapCore\\1new_CorePM_1.tif'
662 # imagePath = 'Y:\\sorger\\data\\RareCyte\\ANNIINA\\Julia\\2018\\TMA6\\julia_tma6.ome.tif'
663 # classProbsPath = 'Z:\\IDAC\\Clarence\\LSP\\CyCIF\\TMA\\probMapCore\\3new_CorePM_1.tif'
664
665
666 # if not os.path.exists(outputPath):
667 # os.makedirs(outputPath)
668 # else:
669 # shutil.rmtree(outputPath)
670 if not os.path.exists(maskOutputPath):
671 os.makedirs(maskOutputPath)
672
673
674 channel = args.channel
675 dsFactor = 1/(2**args.downsampleFactor)
676 # I = tifffile.imread(imagePath, key=channel)
677 I = skio.imread(imagePath, img_num=channel)
678
679 imagesub = resize(I,(int((float(I.shape[0]) * dsFactor)),int((float(I.shape[1]) * dsFactor))))
680 numChan = identifyNumChan(imagePath)
681
682 outputChan = args.outputChan
683 if len(outputChan)==1:
684 if outputChan[0]==-1:
685 outputChan = [0, numChan-1]
686 else:
687 outputChan.append(outputChan[0])
688
689 classProbs = getProbMaps(I,args.downsampleFactor,modelPath)
690 # classProbs = tifffile.imread(classProbsPath,key=0)
691 preMask = gaussian(np.uint8(classProbs*255),1)>0.8
692
693 P = regionprops(label(preMask),cache=False)
694 area = [ele.area for ele in P]
695 print(str(len(P)) + ' cores detected!')
696 if len(P) <3:
697 medArea = np.median(area)
698 maxArea = np.percentile(area,99)
699 else:
700 count=0
701 labelpreMask = np.zeros(preMask.shape,dtype=np.uint32)
702 for props in P:
703 count += 1
704 yi = props.coords[:, 0]
705 xi = props.coords[:, 1]
706 labelpreMask[yi, xi] = count
707 P=regionprops(labelpreMask)
708 area = [ele.area for ele in P]
709 medArea = np.median(area)
710 maxArea = np.percentile(area,99)
711 preMask = remove_small_objects(preMask,0.2*medArea)
712 coreRad = round(np.sqrt(medArea/np.pi))
713 estCoreDiam = round(np.sqrt(maxArea/np.pi)*1.2*args.buffer)
714
715 #preprocessing
716 fgFiltered = blob_log(preMask,coreRad*0.6,threshold=sensitivity)
717 Imax = np.zeros(preMask.shape,dtype=np.uint8)
718 for iSpot in range(fgFiltered.shape[0]):
719 yi = np.uint32(round(fgFiltered[iSpot, 0]))
720 xi = np.uint32(round(fgFiltered[iSpot, 1]))
721 Imax[yi, xi] = 1
722 Imax = Imax*preMask
723 Idist = distance_transform_edt(1-Imax)
724 markers = label(Imax)
725 coreLabel = watershed(Idist,markers,watershed_line=True,mask = preMask)
726 P = regionprops(coreLabel)
727 centroids = np.array([ele.centroid for ele in P])/dsFactor
728 numCores = len(centroids)
729 estCoreDiamX = np.ones(numCores)*estCoreDiam/dsFactor
730 estCoreDiamY = np.ones(numCores)*estCoreDiam/dsFactor
731
732 if numCores == 0 and args.cluster:
733 print('No cores detected. Try adjusting the downsample factor')
734 sys.exit(255)
735
736 singleMaskTMA = np.zeros(imagesub.shape)
737 maskTMA = np.zeros(imagesub.shape)
738 bbox = [None] * numCores
739
740
741 x=np.zeros(numCores)
742 xLim=np.zeros(numCores)
743 y=np.zeros(numCores)
744 yLim=np.zeros(numCores)
745
746 # segmenting each core
747 #######################
748 for iCore in range(numCores):
749 x[iCore] = centroids[iCore,1] - estCoreDiamX[iCore]/2
750 xLim[iCore] = x[iCore]+estCoreDiamX[iCore]
751 if xLim[iCore] > I.shape[1]:
752 xLim[iCore] = I.shape[1]
753 if x[iCore]<1:
754 x[iCore]=1
755
756 y[iCore] = centroids[iCore,0] - estCoreDiamY[iCore]/2
757 yLim[iCore] = y[iCore] + estCoreDiamY[iCore]
758 if yLim[iCore] > I.shape[0]:
759 yLim[iCore] = I.shape[0]
760 if y[iCore]<1:
761 y[iCore]=1
762
763 bbox[iCore] = [round(x[iCore]), round(y[iCore]), round(xLim[iCore]), round(yLim[iCore])]
764
765 for iChan in range(outputChan[0],outputChan[1]+1):
766 with pytiff.Tiff(imagePath, "r", encoding='utf-8') as handle:
767 handle.set_page(iChan)
768 coreStack= handle[np.uint32(bbox[iCore][1]):np.uint32(bbox[iCore][3]-1), np.uint32(bbox[iCore][0]):np.uint32(bbox[iCore][2]-1)]
769 skio.imsave(outputPath + os.path.sep + str(iCore+1) + '.tif',coreStack,append=True)
770
771 with pytiff.Tiff(imagePath, "r", encoding='utf-8') as handle:
772 handle.set_page(args.channel)
773 coreSlice= handle[np.uint32(bbox[iCore][1]):np.uint32(bbox[iCore][3]-1), np.uint32(bbox[iCore][0]):np.uint32(bbox[iCore][2]-1)]
774
775 core = (coreLabel ==(iCore+1))
776 initialmask = core[np.uint32(y[iCore]*dsFactor):np.uint32(yLim[iCore]*dsFactor),np.uint32(x[iCore]*dsFactor):np.uint32(xLim[iCore]*dsFactor)]
777 initialmask = resize(initialmask,size(coreSlice),order=0)  # order=0 -> nearest-neighbour
778
779 singleProbMap = classProbs[np.uint32(y[iCore]*dsFactor):np.uint32(yLim[iCore]*dsFactor),np.uint32(x[iCore]*dsFactor):np.uint32(xLim[iCore]*dsFactor)]
780 singleProbMap = resize(np.uint8(255*singleProbMap),size(coreSlice),order=0)
781 TMAmask = coreSegmenterOutput(coreSlice,singleProbMap,initialmask,coreRad/20,False)
782 if np.sum(TMAmask)==0:
783 TMAmask = np.ones(TMAmask.shape)
784 vsize = int(float(coreSlice.shape[0]))
785 hsize = int(float(coreSlice.shape[1]))
786 masksub = resize(resize(TMAmask,(vsize,hsize),order=0),(int((float(coreSlice.shape[0])*dsFactor)),int((float(coreSlice.shape[1])*dsFactor))),order=0)
787 singleMaskTMA[int(y[iCore]*dsFactor):int(y[iCore]*dsFactor)+masksub.shape[0],int(x[iCore]*dsFactor):int(x[iCore]*dsFactor)+masksub.shape[1]]=masksub
788 maskTMA = maskTMA + resize(singleMaskTMA,maskTMA.shape,order=0)
789 cv2.putText(imagesub, str(iCore+1), (int(P[iCore].centroid[1]),int(P[iCore].centroid[0])), 0, 0.5, (np.amax(imagesub), np.amax(imagesub), np.amax(imagesub)), 1, cv2.LINE_AA)
790
791 skio.imsave(maskOutputPath + os.path.sep + str(iCore+1) + '_mask.tif',np.uint8(TMAmask))
792 print('Segmented core ' + str(iCore+1))
793
794 boundaries = find_boundaries(maskTMA)
795 imagesub = imagesub/np.percentile(imagesub,99.9)
796 imagesub[boundaries==1] = 1
797 skio.imsave(outputPath + os.path.sep + 'TMA_MAP.tif' ,np.uint8(imagesub*255))
798 print('Segmented all cores!')
799
800
801 #restore GPU to 0
802 #image load using tifffile