UNet2DtCycifTRAINCoreograph.py @ 0:99308601eaa6 (draft)
"planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
| author | perssond |
|---|---|
| date | Wed, 19 May 2021 21:34:38 +0000 |
| parents | |
| children | 57f1260ca94e |
comparison: -1:000000000000 → 0:99308601eaa6
```python
import numpy as np
from scipy import misc
import tensorflow as tf
import shutil
import scipy.io as sio
import os,fnmatch,PIL,glob

import sys
sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience')
from toolbox.imtools import *
from toolbox.ftools import *
from toolbox.PartitionOfImage import PI2D


def concat3(lst):
    return tf.concat(lst,3)

class UNet2D:
    hp = None # hyper-parameters
    nn = None # network
    tfTraining = None # if training or not (to handle batch norm)
    tfData = None # data placeholder
    Session = None
    DatasetMean = 0
    DatasetStDev = 0

    def setupWithHP(hp):
        UNet2D.setup(hp['imSize'],
                     hp['nChannels'],
                     hp['nClasses'],
                     hp['nOut0'],
                     hp['featMapsFact'],
                     hp['downSampFact'],
                     hp['ks'],
                     hp['nExtraConvs'],
                     hp['stdDev0'],
                     hp['nLayers'],
                     hp['batchSize'])

    def setup(imSize,nChannels,nClasses,nOut0,featMapsFact,downSampFact,kernelSize,nExtraConvs,stdDev0,nDownSampLayers,batchSize):
        UNet2D.hp = {'imSize':imSize,
                     'nClasses':nClasses,
                     'nChannels':nChannels,
                     'nExtraConvs':nExtraConvs,
                     'nLayers':nDownSampLayers,
                     'featMapsFact':featMapsFact,
                     'downSampFact':downSampFact,
                     'ks':kernelSize,
                     'nOut0':nOut0,
                     'stdDev0':stdDev0,
                     'batchSize':batchSize}

        nOutX = [UNet2D.hp['nChannels'],UNet2D.hp['nOut0']]
        dsfX = []
        for i in range(UNet2D.hp['nLayers']):
            nOutX.append(nOutX[-1]*UNet2D.hp['featMapsFact'])
            dsfX.append(UNet2D.hp['downSampFact'])

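        # Worked example (illustrative only, using the values passed in __main__
        # below: nChannels=1, nOut0=20, featMapsFact=2, nLayers=4):
        #   nOutX = [1, 20, 40, 80, 160, 320]   feature maps per depth level
        #   dsfX  = [2, 2, 2, 2]                down-sampling factor per level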

        # --------------------------------------------------
        # downsampling layer
        # --------------------------------------------------

        with tf.name_scope('placeholders'):
            UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
            UNet2D.tfData = tf.placeholder("float", shape=[None,UNet2D.hp['imSize'],UNet2D.hp['imSize'],UNet2D.hp['nChannels']],name='data')

        def down_samp_layer(data,index):
            with tf.name_scope('ld%d' % index):
                ldXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index+1]], stddev=stdDev0),name='kernel1')
                ldXWeightsExtra = []
                for i in range(nExtraConvs):
                    ldXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernelExtra%d' % i))

                c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
                for i in range(nExtraConvs):
                    c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')

                ldXWeightsShortcut = tf.Variable(tf.truncated_normal([1, 1, nOutX[index], nOutX[index+1]], stddev=stdDev0),name='shortcutWeights')
                shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')

                bn = tf.layers.batch_normalization(tf.nn.relu(c00+shortcut), training=UNet2D.tfTraining)

                return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1], strides=[1, dsfX[index], dsfX[index], 1], padding='SAME',name='maxpool')

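        # Shape sketch (same example values): each down_samp_layer is a residual-style
        # conv block (1x1 shortcut added to the conv stack) followed by 2x max-pooling,
        # so with imSize=128 the spatial size shrinks 128 -> 64 -> 32 -> 16 -> 8 over
        # the four levels while channel counts follow nOutX above.
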
        # --------------------------------------------------
        # bottom layer
        # --------------------------------------------------

        with tf.name_scope('lb'):
            lbWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers']+1]], stddev=stdDev0),name='kernel1')
            def lb(hidden):
                return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),name='conv')

        # --------------------------------------------------
        # downsampling
        # --------------------------------------------------

        with tf.name_scope('downsampling'):
            dsX = []
            dsX.append(UNet2D.tfData)

            for i in range(UNet2D.hp['nLayers']):
                dsX.append(down_samp_layer(dsX[i],i))

            b = lb(dsX[UNet2D.hp['nLayers']])

        # --------------------------------------------------
        # upsampling layer
        # --------------------------------------------------

        def up_samp_layer(data,index):
            with tf.name_scope('lu%d' % index):
                luXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+2]], stddev=stdDev0),name='kernel1')
                luXWeights2 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index]+nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2')
                luXWeightsExtra = []
                for i in range(nExtraConvs):
                    luXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2Extra%d' % i))

                outSize = UNet2D.hp['imSize']
                for i in range(index):
                    outSize /= dsfX[i]
                outSize = int(outSize)

                outputShape = [UNet2D.hp['batchSize'],outSize,outSize,nOutX[index+1]]
                us = tf.nn.relu(tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1], padding='SAME'),name='conv1')
                cc = concat3([dsX[index],us])
                cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'),name='conv2')
                for i in range(nExtraConvs):
                    cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),name='conv2Extra%d' % i)
                return cv

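        # Decoder sketch (same example values): up_samp_layer(data, index) transposed-
        # convolves its input up to outSize = imSize / prod(dsfX[:index]), concatenates
        # the matching encoder tensor dsX[index] as a skip connection, then convolves
        # again; e.g. index=3 takes the 8x8 bottom features back up to 16x16.
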
        # --------------------------------------------------
        # final (top) layer
        # --------------------------------------------------

        with tf.name_scope('lt'):
            ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0),name='kernel')
            def lt(hidden):
                return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME',name='conv')


        # --------------------------------------------------
        # upsampling
        # --------------------------------------------------

        with tf.name_scope('upsampling'):
            usX = []
            usX.append(b)

            for i in range(UNet2D.hp['nLayers']):
                usX.append(up_samp_layer(usX[i],UNet2D.hp['nLayers']-1-i))

            t = lt(usX[UNet2D.hp['nLayers']])


        sm = tf.nn.softmax(t,-1)
        UNet2D.nn = sm


    def train(imPath,logPath,modelPath,pmPath,nTrain,nValid,nTest,restoreVariables,nSteps,gpuIndex,testPMIndex):
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex

        outLogPath = logPath
        trainWriterPath = pathjoin(logPath,'Train')
        validWriterPath = pathjoin(logPath,'Valid')
        outModelPath = pathjoin(modelPath,'model.ckpt')
        outPMPath = pmPath

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Train = np.zeros((nTrain,imSize,imSize,nChannels))
        Valid = np.zeros((nValid,imSize,imSize,nChannels))
        Test = np.zeros((nTest,imSize,imSize,nChannels))
        LTrain = np.zeros((nTrain,imSize,imSize,nClasses))
        LValid = np.zeros((nValid,imSize,imSize,nClasses))
        LTest = np.zeros((nTest,imSize,imSize,nClasses))

        print('loading data, computing mean / st dev')
        if not os.path.exists(modelPath):
            os.makedirs(modelPath)
        if restoreVariables:
            datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
            datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
        else:
            datasetMean = 0.09
            datasetStDev = 0.09
            #for iSample in range(nTrain+nValid+nTest):
            #    I = im2double(tifread('%s/I%05d_Img.tif' % (imPath,iSample)))
            #    datasetMean += np.mean(I)
            #    datasetStDev += np.std(I)
            #datasetMean /= (nTrain+nValid+nTest)
            #datasetStDev /= (nTrain+nValid+nTest)
            saveData(datasetMean, pathjoin(modelPath,'datasetMean.data'))
            saveData(datasetStDev, pathjoin(modelPath,'datasetStDev.data'))

        perm = np.arange(nTrain+nValid+nTest)
        np.random.shuffle(perm)

        for iSample in range(0, nTrain):
            path = '%s/I%05d_Img.tif' % (imPath,perm[iSample])
            im = im2double(tifread(path))
            #im = im[0, 0, 0, :, :]
            Train[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTrain[iSample,:,:,i] = (im == i+1)

        for iSample in range(0, nValid):
            path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+iSample])
            im = im2double(tifread(path))
            #im = im[0, 0, 0, :, :]
            Valid[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+iSample])
            im = tifread(path)
            for i in range(nClasses):
                LValid[iSample,:,:,i] = (im == i+1)

        for iSample in range(0, nTest):
            path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+nValid+iSample])
            im = im2double(tifread(path))
            #im = im[0, 0, 0, :, :]
            Test[iSample,:,:,0] = (im-datasetMean)/datasetStDev
            path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+nValid+iSample])
            im = tifread(path)
            for i in range(nClasses):
                LTest[iSample,:,:,i] = (im == i+1)

        # --------------------------------------------------
        # optimization
        # --------------------------------------------------

        tfLabels = tf.placeholder("float", shape=[None,imSize,imSize,nClasses],name='labels')

        globalStep = tf.Variable(0,trainable=False)
        learningRate0 = 0.05
        decaySteps = 1000
        decayRate = 0.95
        learningRate = tf.train.exponential_decay(learningRate0,globalStep,decaySteps,decayRate,staircase=True)

        with tf.name_scope('optim'):
            loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels,tf.log(UNet2D.nn)),3))
            updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # optimizer = tf.train.MomentumOptimizer(1e-3,0.9)
            optimizer = tf.train.MomentumOptimizer(learningRate,0.9)
            # optimizer = tf.train.GradientDescentOptimizer(learningRate)
            with tf.control_dependencies(updateOps):
                optOp = optimizer.minimize(loss,global_step=globalStep)

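        # Note: the loss above is pixel-wise categorical cross-entropy, the mean over
        # batch and pixels of -sum_c y_c*log(p_c) taken directly on the softmax output,
        # and the control dependency on UPDATE_OPS makes every optimizer step also
        # refresh the batch-norm moving statistics.
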
        with tf.name_scope('eval'):
            error = []
            for iClass in range(nClasses):
                labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels,[0,0,0,iClass],[-1,-1,-1,1])),[batchSize,imSize,imSize])
                predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn,3),iClass)),[batchSize,imSize,imSize])
                correct = tf.multiply(labels0,predict0)
                nCorrect0 = tf.reduce_sum(correct)
                nLabels0 = tf.reduce_sum(labels0)
                error.append(1-tf.to_float(nCorrect0)/tf.to_float(nLabels0))
            errors = tf.tuple(error)

        # --------------------------------------------------
        # inspection
        # --------------------------------------------------

        with tf.name_scope('scalars'):
            tf.summary.scalar('avg_cross_entropy', loss)
            for iClass in range(nClasses):
                tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
            tf.summary.scalar('learning_rate', learningRate)
        with tf.name_scope('images'):
            #split0 = tf.slice(UNet2D.nn,[0,0,0,0],[-1,-1,-1,1])
            split0 = tf.slice(UNet2D.nn,[0,0,0,1],[-1,-1,-1,1])
            split1 = tf.slice(tfLabels, [0, 0, 0, 0], [-1, -1, -1, 1])
            if nClasses > 2:
                split2 = tf.slice(UNet2D.nn,[0,0,0,2],[-1,-1,-1,1])
            tf.summary.image('pm0',split0)
            tf.summary.image('pm1',split1)
            if nClasses > 2:
                tf.summary.image('pm2',split2)
        merged = tf.summary.merge_all()


        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        if os.path.exists(outLogPath):
            shutil.rmtree(outLogPath)
        trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
        validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)

        if restoreVariables:
            saver.restore(sess, outModelPath)
            print("Model restored.")
        else:
            sess.run(tf.global_variables_initializer())

        # --------------------------------------------------
        # train
        # --------------------------------------------------

        batchData = np.zeros((batchSize,imSize,imSize,nChannels))
        batchLabels = np.zeros((batchSize,imSize,imSize,nClasses))
        for i in range(nSteps):
            # train

            perm = np.arange(nTrain)
            np.random.shuffle(perm)

            for j in range(batchSize):
                batchData[j,:,:,:] = Train[perm[j],:,:,:]
                batchLabels[j,:,:,:] = LTrain[perm[j],:,:,:]

            summary,_ = sess.run([merged,optOp],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 1})
            trainWriter.add_summary(summary, i)

            # validation

            perm = np.arange(nValid)
            np.random.shuffle(perm)

            for j in range(batchSize):
                batchData[j,:,:,:] = Valid[perm[j],:,:,:]
                batchLabels[j,:,:,:] = LValid[perm[j],:,:,:]

            summary, es = sess.run([merged, errors],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
            validWriter.add_summary(summary, i)

            e = np.mean(es)
            print('step %05d, e: %f' % (i,e))

            if i == 0:
                if restoreVariables:
                    lowestError = e
                else:
                    lowestError = np.inf

            if np.mod(i,100) == 0 and e < lowestError:
                lowestError = e
                print("Model saved in file: %s" % saver.save(sess, outModelPath))

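            # Checkpointing note: a checkpoint is written only every 100 steps and only
            # when the running validation error e beats lowestError, so very short runs
            # (e.g. nSteps=10 with restoreVariables=True, as in __main__) may never save.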

        # --------------------------------------------------
        # test
        # --------------------------------------------------

        if not os.path.exists(outPMPath):
            os.makedirs(outPMPath)

        for i in range(nTest):
            j = np.mod(i,batchSize)

            batchData[j,:,:,:] = Test[i,:,:,:]
            batchLabels[j,:,:,:] = LTest[i,:,:,:]

            if j == batchSize-1 or i == nTest-1:

                output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})

                for k in range(j+1):
                    pm = output[k,:,:,testPMIndex]
                    gt = batchLabels[k,:,:,testPMIndex]
                    im = np.sqrt(normalize(batchData[k,:,:,0]))
                    imwrite(np.uint8(255*np.concatenate((im,np.concatenate((pm,gt),axis=1)),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))


        # --------------------------------------------------
        # save hyper-parameters, clean-up
        # --------------------------------------------------

        saveData(UNet2D.hp,pathjoin(modelPath,'hp.data'))

        trainWriter.close()
        validWriter.close()
        sess.close()

    def deploy(imPath,nImages,modelPath,pmPath,gpuIndex,pmIndex):
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex

        variablesPath = pathjoin(modelPath,'model.ckpt')
        outPMPath = pmPath

        hp = loadData(pathjoin(modelPath,'hp.data'))
        UNet2D.setupWithHP(hp)

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']
        nClasses = UNet2D.hp['nClasses']

        # --------------------------------------------------
        # data
        # --------------------------------------------------

        Data = np.zeros((nImages,imSize,imSize,nChannels))

        datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
        datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))

        for iSample in range(0, nImages):
            path = '%s/I%05d_Img.tif' % (imPath,iSample)
            im = im2double(tifread(path))
            #im = im[0, 0, 0, :, :]
            Data[iSample,:,:,0] = (im-datasetMean)/datasetStDev

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        saver.restore(sess, variablesPath)
        print("Model restored.")

        # --------------------------------------------------
        # deploy
        # --------------------------------------------------

        batchData = np.zeros((batchSize,imSize,imSize,nChannels))

        if not os.path.exists(outPMPath):
            os.makedirs(outPMPath)

        for i in range(nImages):
            print(i,nImages)

            j = np.mod(i,batchSize)

            batchData[j,:,:,:] = Data[i,:,:,:]

            if j == batchSize-1 or i == nImages-1:

                output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

                for k in range(j+1):
                    pm = output[k,:,:,pmIndex]
                    im = np.sqrt(normalize(batchData[k,:,:,0]))
                    # imwrite(np.uint8(255*np.concatenate((im,pm),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
                    imwrite(np.uint8(255*im),'%s/I%05d_Im.png' % (outPMPath,i-j+k+1))
                    imwrite(np.uint8(255*pm),'%s/I%05d_PM.png' % (outPMPath,i-j+k+1))


        # --------------------------------------------------
        # clean-up
        # --------------------------------------------------

        sess.close()

    def singleImageInferenceSetup(modelPath,gpuIndex):
        os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex

        variablesPath = pathjoin(modelPath,'model.ckpt')

        hp = loadData(pathjoin(modelPath,'hp.data'))
        UNet2D.setupWithHP(hp)

        UNet2D.DatasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
        UNet2D.DatasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
        print(UNet2D.DatasetMean)
        print(UNet2D.DatasetStDev)

        # --------------------------------------------------
        # session
        # --------------------------------------------------

        saver = tf.train.Saver()
        UNet2D.Session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU

        saver.restore(UNet2D.Session, variablesPath)
        print("Model restored.")

    def singleImageInferenceCleanup():
        UNet2D.Session.close()

    def singleImageInference(image,mode,pmIndex):
        print('Inference...')

        batchSize = UNet2D.hp['batchSize']
        imSize = UNet2D.hp['imSize']
        nChannels = UNet2D.hp['nChannels']

        PI2D.setup(image,imSize,int(imSize/8),mode)
        PI2D.createOutput(nChannels)
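        # PI2D (toolbox.PartitionOfImage) tiles the full image into imSize x imSize
        # patches with an overlap margin of imSize/8 (16 px for imSize=128) and later
        # stitches the per-patch outputs back together; this reading of its interface
        # is inferred from the calls in this file rather than from its source.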

        batchData = np.zeros((batchSize,imSize,imSize,nChannels))
        for i in range(PI2D.NumPatches):
            j = np.mod(i,batchSize)
            batchData[j,:,:,0] = (PI2D.getPatch(i)-UNet2D.DatasetMean)/UNet2D.DatasetStDev
            if j == batchSize-1 or i == PI2D.NumPatches-1:
                output = UNet2D.Session.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
                for k in range(j+1):
                    pm = output[k,:,:,pmIndex]
                    PI2D.patchOutput(i-j+k,pm)
                    # PI2D.patchOutput(i-j+k,normalize(imgradmag(PI2D.getPatch(i-j+k),1)))

        return PI2D.getValidOutput()


if __name__ == '__main__':
    logPath = 'D:\\LSP\\UNet\\Coreograph\\TFLogs'
    modelPath = 'D:\\LSP\\Coreograph\\model-4layersMaskAug20New'
    pmPath = 'D:\\LSP\\UNet\\Coreograph\\TFProbMaps'


    # ----- test 1 -----

    # imPath = 'D:\\LSP\\UNet\\tonsil20x1bin1chan\\tonsilAnnotations'
    imPath = 'Z:/IDAC/Clarence/LSP/CyCIF/TMA/training data custom unaveraged'
    # UNet2D.setup(128,1,2,8,2,2,3,1,0.1,2,8)
    # UNet2D.train(imPath,logPath,modelPath,pmPath,500,100,40,True,20000,1,0)
    UNet2D.setup(128, 1, 2, 20, 2, 2, 3, 2, 0.03, 4, 32)
    UNet2D.train(imPath, logPath, modelPath, pmPath, 2053, 513, 641, True, 10, 1, 1)
    UNet2D.deploy(imPath,100,modelPath,pmPath,1,1)

    # I = im2double(tifread('/home/mc457/files/CellBiology/IDAC/Marcelo/Etc/UNetTestSets/SinemSaka_NucleiSegmentation_SingleImageInferenceTest3.tif'))
    # UNet2D.singleImageInferenceSetup(modelPath,0)
    # J = UNet2D.singleImageInference(I,'accumulate',0)
    # UNet2D.singleImageInferenceCleanup()
    # # imshowlist([I,J])
    # # sys.exit(0)
    # # tifwrite(np.uint8(255*I),'/home/mc457/Workspace/I1.tif')
    # # tifwrite(np.uint8(255*J),'/home/mc457/Workspace/I2.tif')
    # K = np.zeros((2,I.shape[0],I.shape[1]))
    # K[0,:,:] = I
    # K[1,:,:] = J
    # tifwrite(np.uint8(255*K),'/home/mc457/Workspace/Sinem_NucSeg.tif')

    # UNet2D.singleImageInferenceSetup(modelPath,0)
    # imagePath = 'Y://sorger//data//RareCyte//Connor//Topacio_P2_AF//ashlar//C0078'
    #
    # fileList = glob.glob(imagePath + '//registration//C0078.ome.tif')
    # print(fileList)
    # for iFile in fileList:
    #     fileName = os.path.basename(iFile)
    #     fileNamePrefix = fileName.split(os.extsep, 1)
    #     I = im2double(tifffile.imread(iFile, key=0))
    #     hsize = int((float(I.shape[0])*float(0.75)))
    #     vsize = int((float(I.shape[1])*float(0.75)))
    #     I = resize(I,(hsize,vsize))
    #     J = UNet2D.singleImageInference(I,'accumulate',1)
    #     K = np.zeros((3,I.shape[0],I.shape[1]))
    #     K[2,:,:] = I
    #     K[0,:,:] = J
    #     J = UNet2D.singleImageInference(I, 'accumulate', 2)
    #     K[1, :, :] = J
    #     outputPath = imagePath + '//prob_maps'
    #     if not os.path.exists(outputPath):
    #         os.makedirs(outputPath)
    #     tifwrite(np.uint8(255*K),outputPath + '//' + fileNamePrefix[0] +'_NucSeg.tif')
    # UNet2D.singleImageInferenceCleanup()


    # ----- test 2 -----

    # imPath = '/home/mc457/files/CellBiology/IDAC/Marcelo/Etc/UNetTestSets/ClarenceYapp_NucleiSegmentation'
    # UNet2D.setup(128,1,2,8,2,2,3,1,0.1,3,4)
    # UNet2D.train(imPath,logPath,modelPath,pmPath,800,100,100,False,10,1)
    # UNet2D.deploy(imPath,100,modelPath,pmPath,1)


    # ----- test 3 -----

    # imPath = '/home/mc457/files/CellBiology/IDAC/Marcelo/Etc/UNetTestSets/CarmanLi_CellTypeSegmentation'
    # # UNet2D.setup(256,1,2,8,2,2,3,1,0.1,3,4)
    # # UNet2D.train(imPath,logPath,modelPath,pmPath,1400,100,164,False,10000,1)
    # UNet2D.deploy(imPath,164,modelPath,pmPath,1)


    # ----- test 4 -----

    # imPath = '/home/cicconet/Downloads/TrainSet1'
    # UNet2D.setup(64,1,2,8,2,2,3,1,0.1,3,4)
    # UNet2D.train(imPath,logPath,modelPath,pmPath,200,8,8,False,2000,1,0)
    # # UNet2D.deploy(imPath,164,modelPath,pmPath,1)
```
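
For reference, a minimal, hypothetical driver for the single-image inference API above (paths are placeholders; `im2double`, `tifread`, and `tifwrite` are assumed to live in the bundled `toolbox.imtools`, as the module's star-imports suggest, and probability-map index 1 matches the two-class model configured in `__main__`):

```python
import numpy as np
from toolbox.imtools import im2double, tifread, tifwrite    # assumed helpers from the bundled toolbox
from UNet2DtCycifTRAINCoreograph import UNet2D               # this module

modelPath = '/path/to/model'           # placeholder: folder holding model.ckpt, hp.data, dataset stats
imagePath = '/path/to/core_image.tif'  # placeholder input image

I = im2double(tifread(imagePath))                     # load and scale the image to [0, 1]

UNet2D.singleImageInferenceSetup(modelPath, 0)        # restore the checkpoint on GPU 0
J = UNet2D.singleImageInference(I, 'accumulate', 1)   # stitched probability map for class 1
UNet2D.singleImageInferenceCleanup()                  # close the TF session

tifwrite(np.uint8(255 * J), '/path/to/core_image_PM.tif')  # save the map as an 8-bit TIFF
```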
