annotate UNet2DtCycifTRAINCoreograph.py @ 0:99308601eaa6 draft

"planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
author perssond
date Wed, 19 May 2021 21:34:38 +0000
parents
children 57f1260ca94e
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
1 import numpy as np
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
2 from scipy import misc
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
3 import tensorflow as tf
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
4 import shutil
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
5 import scipy.io as sio
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
6 import os,fnmatch,PIL,glob
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
7
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
8 import sys
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
9 sys.path.insert(0, 'C:\\Users\\Public\\Documents\\ImageScience')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
10 from toolbox.imtools import *
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
11 from toolbox.ftools import *
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
12 from toolbox.PartitionOfImage import PI2D
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
13
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
14
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
def concat3(lst):
    """Join a sequence of tensors along axis 3.

    In this file it joins a skip-connection tensor with the upsampled
    features (``concat3([dsX[index], us])``); the input placeholder is
    built as ``[None, imSize, imSize, nChannels]``, so axis 3 is the
    channel axis of these NHWC tensors.

    Args:
        lst: list of tensors whose shapes agree on every axis except 3.

    Returns:
        The concatenated tensor.
    """
    return tf.concat(values=lst, axis=3)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
17
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
18 class UNet2D:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
19 hp = None # hyper-parameters
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
20 nn = None # network
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
21 tfTraining = None # if training or not (to handle batch norm)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
22 tfData = None # data placeholder
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
23 Session = None
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
24 DatasetMean = 0
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
25 DatasetStDev = 0
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
26
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
def setupWithHP(hp):
    """Build the network from a hyper-parameter dictionary.

    Thin adapter over ``UNet2D.setup``. It reads each required entry out of
    ``hp`` and forwards the values positionally, in the order ``setup``
    declares its parameters. Every lookup happens before ``setup`` is
    invoked, so a missing key raises ``KeyError`` without building anything.
    Whatever ``setup`` returns is discarded, and this function returns None.

    Args:
        hp: mapping with the keys 'imSize', 'nChannels', 'nClasses',
            'nOut0', 'featMapsFact', 'downSampFact', 'ks', 'nExtraConvs',
            'stdDev0', 'nLayers' and 'batchSize'. Note that 'ks' feeds
            setup's ``kernelSize`` and 'nLayers' feeds ``nDownSampLayers``.
    """
    # This order mirrors setup's positional parameter list exactly.
    orderedKeys = ('imSize', 'nChannels', 'nClasses', 'nOut0',
                   'featMapsFact', 'downSampFact', 'ks', 'nExtraConvs',
                   'stdDev0', 'nLayers', 'batchSize')
    positional = [hp[key] for key in orderedKeys]
    UNet2D.setup(*positional)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
39
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
40 def setup(imSize,nChannels,nClasses,nOut0,featMapsFact,downSampFact,kernelSize,nExtraConvs,stdDev0,nDownSampLayers,batchSize):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
41 UNet2D.hp = {'imSize':imSize,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
42 'nClasses':nClasses,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
43 'nChannels':nChannels,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
44 'nExtraConvs':nExtraConvs,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
45 'nLayers':nDownSampLayers,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
46 'featMapsFact':featMapsFact,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
47 'downSampFact':downSampFact,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
48 'ks':kernelSize,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
49 'nOut0':nOut0,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
50 'stdDev0':stdDev0,
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
51 'batchSize':batchSize}
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
52
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
53 nOutX = [UNet2D.hp['nChannels'],UNet2D.hp['nOut0']]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
54 dsfX = []
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
55 for i in range(UNet2D.hp['nLayers']):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
56 nOutX.append(nOutX[-1]*UNet2D.hp['featMapsFact'])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
57 dsfX.append(UNet2D.hp['downSampFact'])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
58
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
59
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
60 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
61 # downsampling layer
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
62 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
63
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
64 with tf.name_scope('placeholders'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
65 UNet2D.tfTraining = tf.placeholder(tf.bool, name='training')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
66 UNet2D.tfData = tf.placeholder("float", shape=[None,UNet2D.hp['imSize'],UNet2D.hp['imSize'],UNet2D.hp['nChannels']],name='data')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
67
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
68 def down_samp_layer(data,index):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
69 with tf.name_scope('ld%d' % index):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
70 ldXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index], nOutX[index+1]], stddev=stdDev0),name='kernel1')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
71 ldXWeightsExtra = []
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
72 for i in range(nExtraConvs):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
73 ldXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernelExtra%d' % i))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
74
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
75 c00 = tf.nn.conv2d(data, ldXWeights1, strides=[1, 1, 1, 1], padding='SAME')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
76 for i in range(nExtraConvs):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
77 c00 = tf.nn.conv2d(tf.nn.relu(c00), ldXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
78
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
79 ldXWeightsShortcut = tf.Variable(tf.truncated_normal([1, 1, nOutX[index], nOutX[index+1]], stddev=stdDev0),name='shortcutWeights')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
80 shortcut = tf.nn.conv2d(data, ldXWeightsShortcut, strides=[1, 1, 1, 1], padding='SAME')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
81
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
82 bn = tf.layers.batch_normalization(tf.nn.relu(c00+shortcut), training=UNet2D.tfTraining)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
83
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
84 return tf.nn.max_pool(bn, ksize=[1, dsfX[index], dsfX[index], 1], strides=[1, dsfX[index], dsfX[index], 1], padding='SAME',name='maxpool')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
85
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
86 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
87 # bottom layer
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
88 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
89
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
90 with tf.name_scope('lb'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
91 lbWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[UNet2D.hp['nLayers']], nOutX[UNet2D.hp['nLayers']+1]], stddev=stdDev0),name='kernel1')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
92 def lb(hidden):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
93 return tf.nn.relu(tf.nn.conv2d(hidden, lbWeights1, strides=[1, 1, 1, 1], padding='SAME'),name='conv')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
94
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
95 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
96 # downsampling
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
97 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
98
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
99 with tf.name_scope('downsampling'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
100 dsX = []
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
101 dsX.append(UNet2D.tfData)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
102
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
103 for i in range(UNet2D.hp['nLayers']):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
104 dsX.append(down_samp_layer(dsX[i],i))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
105
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
106 b = lb(dsX[UNet2D.hp['nLayers']])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
107
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
108 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
109 # upsampling layer
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
110 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
111
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
112 def up_samp_layer(data,index):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
113 with tf.name_scope('lu%d' % index):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
114 luXWeights1 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+2]], stddev=stdDev0),name='kernel1')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
115 luXWeights2 = tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index]+nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
116 luXWeightsExtra = []
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
117 for i in range(nExtraConvs):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
118 luXWeightsExtra.append(tf.Variable(tf.truncated_normal([UNet2D.hp['ks'], UNet2D.hp['ks'], nOutX[index+1], nOutX[index+1]], stddev=stdDev0),name='kernel2Extra%d' % i))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
119
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
120 outSize = UNet2D.hp['imSize']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
121 for i in range(index):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
122 outSize /= dsfX[i]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
123 outSize = int(outSize)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
124
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
125 outputShape = [UNet2D.hp['batchSize'],outSize,outSize,nOutX[index+1]]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
126 us = tf.nn.relu(tf.nn.conv2d_transpose(data, luXWeights1, outputShape, strides=[1, dsfX[index], dsfX[index], 1], padding='SAME'),name='conv1')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
127 cc = concat3([dsX[index],us])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
128 cv = tf.nn.relu(tf.nn.conv2d(cc, luXWeights2, strides=[1, 1, 1, 1], padding='SAME'),name='conv2')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
129 for i in range(nExtraConvs):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
130 cv = tf.nn.relu(tf.nn.conv2d(cv, luXWeightsExtra[i], strides=[1, 1, 1, 1], padding='SAME'),name='conv2Extra%d' % i)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
131 return cv
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
132
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
133 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
134 # final (top) layer
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
135 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
136
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
137 with tf.name_scope('lt'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
138 ltWeights1 = tf.Variable(tf.truncated_normal([1, 1, nOutX[1], nClasses], stddev=stdDev0),name='kernel')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
139 def lt(hidden):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
140 return tf.nn.conv2d(hidden, ltWeights1, strides=[1, 1, 1, 1], padding='SAME',name='conv')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
141
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
142
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
143 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
144 # upsampling
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
145 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
146
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
147 with tf.name_scope('upsampling'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
148 usX = []
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
149 usX.append(b)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
150
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
151 for i in range(UNet2D.hp['nLayers']):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
152 usX.append(up_samp_layer(usX[i],UNet2D.hp['nLayers']-1-i))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
153
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
154 t = lt(usX[UNet2D.hp['nLayers']])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
155
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
156
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
157 sm = tf.nn.softmax(t,-1)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
158 UNet2D.nn = sm
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
159
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
160
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
161 def train(imPath,logPath,modelPath,pmPath,nTrain,nValid,nTest,restoreVariables,nSteps,gpuIndex,testPMIndex):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
162 os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
163
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
164 outLogPath = logPath
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
165 trainWriterPath = pathjoin(logPath,'Train')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
166 validWriterPath = pathjoin(logPath,'Valid')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
167 outModelPath = pathjoin(modelPath,'model.ckpt')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
168 outPMPath = pmPath
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
169
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
170 batchSize = UNet2D.hp['batchSize']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
171 imSize = UNet2D.hp['imSize']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
172 nChannels = UNet2D.hp['nChannels']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
173 nClasses = UNet2D.hp['nClasses']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
174
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
175 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
176 # data
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
177 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
178
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
179 Train = np.zeros((nTrain,imSize,imSize,nChannels))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
180 Valid = np.zeros((nValid,imSize,imSize,nChannels))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
181 Test = np.zeros((nTest,imSize,imSize,nChannels))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
182 LTrain = np.zeros((nTrain,imSize,imSize,nClasses))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
183 LValid = np.zeros((nValid,imSize,imSize,nClasses))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
184 LTest = np.zeros((nTest,imSize,imSize,nClasses))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
185
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
186 print('loading data, computing mean / st dev')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
187 if not os.path.exists(modelPath):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
188 os.makedirs(modelPath)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
189 if restoreVariables:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
190 datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
191 datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
192 else:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
193 datasetMean = 0.09
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
194 datasetStDev = 0.09
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
195 #for iSample in range(nTrain+nValid+nTest):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
196 # I = im2double(tifread('%s/I%05d_Img.tif' % (imPath,iSample)))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
197 # datasetMean += np.mean(I)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
198 # datasetStDev += np.std(I)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
199 #datasetMean /= (nTrain+nValid+nTest)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
200 #datasetStDev /= (nTrain+nValid+nTest)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
201 saveData(datasetMean, pathjoin(modelPath,'datasetMean.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
202 saveData(datasetStDev, pathjoin(modelPath,'datasetStDev.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
203
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
204 perm = np.arange(nTrain+nValid+nTest)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
205 np.random.shuffle(perm)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
206
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
207 for iSample in range(0, nTrain):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
208 path = '%s/I%05d_Img.tif' % (imPath,perm[iSample])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
209 im = im2double(tifread(path))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
210 #im = im[0, 0, 0, :, :]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
211 Train[iSample,:,:,0] = (im-datasetMean)/datasetStDev
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
212 path = '%s/I%05d_Ant.tif' % (imPath,perm[iSample])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
213 im = tifread(path)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
214 for i in range(nClasses):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
215 LTrain[iSample,:,:,i] = (im == i+1)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
216
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
217 for iSample in range(0, nValid):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
218 path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+iSample])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
219 im = im2double(tifread(path))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
220 #im = im[0, 0, 0, :, :]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
221 Valid[iSample,:,:,0] = (im-datasetMean)/datasetStDev
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
222 path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+iSample])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
223 im = tifread(path)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
224 for i in range(nClasses):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
225 LValid[iSample,:,:,i] = (im == i+1)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
226
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
227 for iSample in range(0, nTest):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
228 path = '%s/I%05d_Img.tif' % (imPath,perm[nTrain+nValid+iSample])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
229 im = im2double(tifread(path))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
230 #im = im[0, 0, 0, :, :]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
231 Test[iSample,:,:,0] = (im-datasetMean)/datasetStDev
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
232 path = '%s/I%05d_Ant.tif' % (imPath,perm[nTrain+nValid+iSample])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
233 im = tifread(path)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
234 for i in range(nClasses):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
235 LTest[iSample,:,:,i] = (im == i+1)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
236
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
237 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
238 # optimization
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
239 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
240
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
241 tfLabels = tf.placeholder("float", shape=[None,imSize,imSize,nClasses],name='labels')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
242
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
243 globalStep = tf.Variable(0,trainable=False)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
244 learningRate0 = 0.05
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
245 decaySteps = 1000
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
246 decayRate = 0.95
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
247 learningRate = tf.train.exponential_decay(learningRate0,globalStep,decaySteps,decayRate,staircase=True)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
248
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
249 with tf.name_scope('optim'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
250 loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tfLabels,tf.log(UNet2D.nn)),3))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
251 updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
252 # optimizer = tf.train.MomentumOptimizer(1e-3,0.9)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
253 optimizer = tf.train.MomentumOptimizer(learningRate,0.9)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
254 # optimizer = tf.train.GradientDescentOptimizer(learningRate)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
255 with tf.control_dependencies(updateOps):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
256 optOp = optimizer.minimize(loss,global_step=globalStep)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
257
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
258 with tf.name_scope('eval'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
259 error = []
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
260 for iClass in range(nClasses):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
261 labels0 = tf.reshape(tf.to_int32(tf.slice(tfLabels,[0,0,0,iClass],[-1,-1,-1,1])),[batchSize,imSize,imSize])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
262 predict0 = tf.reshape(tf.to_int32(tf.equal(tf.argmax(UNet2D.nn,3),iClass)),[batchSize,imSize,imSize])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
263 correct = tf.multiply(labels0,predict0)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
264 nCorrect0 = tf.reduce_sum(correct)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
265 nLabels0 = tf.reduce_sum(labels0)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
266 error.append(1-tf.to_float(nCorrect0)/tf.to_float(nLabels0))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
267 errors = tf.tuple(error)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
268
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
269 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
270 # inspection
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
271 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
272
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
273 with tf.name_scope('scalars'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
274 tf.summary.scalar('avg_cross_entropy', loss)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
275 for iClass in range(nClasses):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
276 tf.summary.scalar('avg_pixel_error_%d' % iClass, error[iClass])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
277 tf.summary.scalar('learning_rate', learningRate)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
278 with tf.name_scope('images'):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
279 #split0 = tf.slice(UNet2D.nn,[0,0,0,0],[-1,-1,-1,1])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
280 split0 = tf.slice(UNet2D.nn,[0,0,0,1],[-1,-1,-1,1])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
281 split1 = tf.slice(tfLabels, [0, 0, 0, 0], [-1, -1, -1, 1])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
282 if nClasses > 2:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
283 split2 = tf.slice(UNet2D.nn,[0,0,0,2],[-1,-1,-1,1])
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
284 tf.summary.image('pm0',split0)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
285 tf.summary.image('pm1',split1)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
286 if nClasses > 2:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
287 tf.summary.image('pm2',split2)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
288 merged = tf.summary.merge_all()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
289
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
290
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
291 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
292 # session
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
293 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
294
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
295 saver = tf.train.Saver()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
296 sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
297
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
298 if os.path.exists(outLogPath):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
299 shutil.rmtree(outLogPath)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
300 trainWriter = tf.summary.FileWriter(trainWriterPath, sess.graph)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
301 validWriter = tf.summary.FileWriter(validWriterPath, sess.graph)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
302
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
303 if restoreVariables:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
304 saver.restore(sess, outModelPath)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
305 print("Model restored.")
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
306 else:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
307 sess.run(tf.global_variables_initializer())
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
308
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
309 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
310 # train
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
311 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
312
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
313 batchData = np.zeros((batchSize,imSize,imSize,nChannels))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
314 batchLabels = np.zeros((batchSize,imSize,imSize,nClasses))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
315 for i in range(nSteps):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
316 # train
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
317
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
318 perm = np.arange(nTrain)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
319 np.random.shuffle(perm)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
320
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
321 for j in range(batchSize):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
322 batchData[j,:,:,:] = Train[perm[j],:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
323 batchLabels[j,:,:,:] = LTrain[perm[j],:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
324
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
325 summary,_ = sess.run([merged,optOp],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 1})
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
326 trainWriter.add_summary(summary, i)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
327
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
328 # validation
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
329
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
330 perm = np.arange(nValid)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
331 np.random.shuffle(perm)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
332
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
333 for j in range(batchSize):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
334 batchData[j,:,:,:] = Valid[perm[j],:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
335 batchLabels[j,:,:,:] = LValid[perm[j],:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
336
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
337 summary, es = sess.run([merged, errors],feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
338 validWriter.add_summary(summary, i)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
339
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
340 e = np.mean(es)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
341 print('step %05d, e: %f' % (i,e))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
342
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
343 if i == 0:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
344 if restoreVariables:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
345 lowestError = e
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
346 else:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
347 lowestError = np.inf
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
348
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
349 if np.mod(i,100) == 0 and e < lowestError:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
350 lowestError = e
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
351 print("Model saved in file: %s" % saver.save(sess, outModelPath))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
352
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
353
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
354 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
355 # test
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
356 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
357
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
358 if not os.path.exists(outPMPath):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
359 os.makedirs(outPMPath)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
360
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
361 for i in range(nTest):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
362 j = np.mod(i,batchSize)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
363
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
364 batchData[j,:,:,:] = Test[i,:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
365 batchLabels[j,:,:,:] = LTest[i,:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
366
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
367 if j == batchSize-1 or i == nTest-1:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
368
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
369 output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, tfLabels: batchLabels, UNet2D.tfTraining: 0})
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
370
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
371 for k in range(j+1):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
372 pm = output[k,:,:,testPMIndex]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
373 gt = batchLabels[k,:,:,testPMIndex]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
374 im = np.sqrt(normalize(batchData[k,:,:,0]))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
375 imwrite(np.uint8(255*np.concatenate((im,np.concatenate((pm,gt),axis=1)),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
376
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
377
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
378 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
379 # save hyper-parameters, clean-up
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
380 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
381
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
382 saveData(UNet2D.hp,pathjoin(modelPath,'hp.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
383
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
384 trainWriter.close()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
385 validWriter.close()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
386 sess.close()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
387
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
388 def deploy(imPath,nImages,modelPath,pmPath,gpuIndex,pmIndex):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
389 os.environ['CUDA_VISIBLE_DEVICES']= '%d' % gpuIndex
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
390
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
391 variablesPath = pathjoin(modelPath,'model.ckpt')
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
392 outPMPath = pmPath
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
393
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
394 hp = loadData(pathjoin(modelPath,'hp.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
395 UNet2D.setupWithHP(hp)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
396
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
397 batchSize = UNet2D.hp['batchSize']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
398 imSize = UNet2D.hp['imSize']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
399 nChannels = UNet2D.hp['nChannels']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
400 nClasses = UNet2D.hp['nClasses']
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
401
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
402 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
403 # data
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
404 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
405
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
406 Data = np.zeros((nImages,imSize,imSize,nChannels))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
407
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
408 datasetMean = loadData(pathjoin(modelPath,'datasetMean.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
409 datasetStDev = loadData(pathjoin(modelPath,'datasetStDev.data'))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
410
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
411 for iSample in range(0, nImages):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
412 path = '%s/I%05d_Img.tif' % (imPath,iSample)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
413 im = im2double(tifread(path))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
414 #im = im[0, 0, 0, :, :]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
415 Data[iSample,:,:,0] = (im-datasetMean)/datasetStDev
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
416
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
417 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
418 # session
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
419 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
420
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
421 saver = tf.train.Saver()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
422 sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # config parameter needed to save variables when using GPU
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
423
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
424 saver.restore(sess, variablesPath)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
425 print("Model restored.")
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
426
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
427 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
428 # deploy
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
429 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
430
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
431 batchData = np.zeros((batchSize,imSize,imSize,nChannels))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
432
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
433 if not os.path.exists(outPMPath):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
434 os.makedirs(outPMPath)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
435
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
436 for i in range(nImages):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
437 print(i,nImages)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
438
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
439 j = np.mod(i,batchSize)
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
440
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
441 batchData[j,:,:,:] = Data[i,:,:,:]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
442
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
443 if j == batchSize-1 or i == nImages-1:
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
444
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
445 output = sess.run(UNet2D.nn,feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
446
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
447 for k in range(j+1):
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
448 pm = output[k,:,:,pmIndex]
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
449 im = np.sqrt(normalize(batchData[k,:,:,0]))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
450 # imwrite(np.uint8(255*np.concatenate((im,pm),axis=1)),'%s/I%05d.png' % (outPMPath,i-j+k+1))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
451 imwrite(np.uint8(255*im),'%s/I%05d_Im.png' % (outPMPath,i-j+k+1))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
452 imwrite(np.uint8(255*pm),'%s/I%05d_PM.png' % (outPMPath,i-j+k+1))
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
453
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
454
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
455 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
456 # clean-up
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
457 # --------------------------------------------------
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
458
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
459 sess.close()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
460
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
def singleImageInferenceSetup(modelPath,gpuIndex):
    """Load a trained UNet2D model so single images can be run through it.

    Steps, in order:
      1. restrict CUDA to the device numbered ``gpuIndex``;
      2. read the saved hyper-parameters (``hp.data``) and hand them to
         ``UNet2D.setupWithHP`` (presumably builds the graph -- confirm);
      3. read the dataset normalisation statistics (``datasetMean.data``,
         ``datasetStDev.data``) onto ``UNet2D`` and echo them;
      4. open a new TensorFlow session on ``UNet2D.Session`` and restore the
         weights from ``model.ckpt`` into it.

    Pair every call with ``singleImageInferenceCleanup`` to close the session.

    Args:
        modelPath: directory holding ``model.ckpt``, ``hp.data``,
            ``datasetMean.data`` and ``datasetStDev.data``.
        gpuIndex: integer formatted into ``CUDA_VISIBLE_DEVICES``.
    """
    # Done first, before any session is opened further down.
    os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpuIndex

    checkpointFile = pathjoin(modelPath, 'model.ckpt')

    savedHyperParams = loadData(pathjoin(modelPath, 'hp.data'))
    UNet2D.setupWithHP(savedHyperParams)

    UNet2D.DatasetMean = loadData(pathjoin(modelPath, 'datasetMean.data'))
    UNet2D.DatasetStDev = loadData(pathjoin(modelPath, 'datasetStDev.data'))
    for statistic in (UNet2D.DatasetMean, UNet2D.DatasetStDev):
        print(statistic)

    # The Saver is created after setupWithHP so it can see the model variables.
    restorer = tf.train.Saver()
    # allow_soft_placement: the original author notes it is needed to save
    # variables when running on a GPU.
    sessionConfig = tf.ConfigProto(allow_soft_placement=True)
    UNet2D.Session = tf.Session(config=sessionConfig)

    restorer.restore(UNet2D.Session, checkpointFile)
    print("Model restored.")
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
483
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
def singleImageInferenceCleanup():
    """Close the TensorFlow session stored on ``UNet2D.Session``.

    Intended to undo ``singleImageInferenceSetup``; the ``UNet2D.Session``
    attribute itself is left in place (now referring to a closed session).
    """
    activeSession = UNet2D.Session
    activeSession.close()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
486
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
def singleImageInference(image,mode,pmIndex):
    """Run the restored UNet2D over one image, patch by patch.

    The image is split into patches by ``PI2D``. Each patch is normalised with
    ``UNet2D.DatasetMean`` / ``UNet2D.DatasetStDev`` and written into channel
    0 of a fixed-size batch. Each batch is fed through ``UNet2D.nn`` with
    ``tfTraining`` set to 0. Channel ``pmIndex`` of every prediction is handed
    back to ``PI2D.patchOutput``, and the stitched result comes from
    ``PI2D.getValidOutput``.

    Requires ``singleImageInferenceSetup`` to have been called first.

    Args:
        image: image passed to ``PI2D.setup`` (presumably a 2-D array --
            confirm against PI2D).
        mode: stitching mode passed to ``PI2D.setup`` (commented-out callers
            use ``'accumulate'``).
        pmIndex: index of the network output channel to keep.

    Returns:
        Whatever ``PI2D.getValidOutput()`` returns.
    """
    print('Inference...')

    hyperParams = UNet2D.hp
    batchSize = hyperParams['batchSize']
    imSize = hyperParams['imSize']
    nChannels = hyperParams['nChannels']

    # Third argument is int(imSize/8); its meaning (margin/overlap?) is
    # defined by PI2D.setup -- TODO confirm.
    PI2D.setup(image, imSize, int(imSize/8), mode)
    PI2D.createOutput(nChannels)

    batchData = np.zeros((batchSize, imSize, imSize, nChannels))
    patchCount = PI2D.NumPatches
    for batchStart in range(0, patchCount, batchSize):
        patchIds = range(batchStart, min(batchStart + batchSize, patchCount))

        for slot, patchId in enumerate(patchIds):
            normalised = (PI2D.getPatch(patchId) - UNet2D.DatasetMean) / UNet2D.DatasetStDev
            batchData[slot, :, :, 0] = normalised

        # A short final batch still carries patches from the previous batch
        # in its trailing slots; only the first len(patchIds) predictions are
        # read below, so those leftovers are ignored.
        prediction = UNet2D.Session.run(
            UNet2D.nn,
            feed_dict={UNet2D.tfData: batchData, UNet2D.tfTraining: 0})

        for slot, patchId in enumerate(patchIds):
            PI2D.patchOutput(patchId, prediction[slot, :, :, pmIndex])

    return PI2D.getValidOutput()
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
509
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
510
99308601eaa6 "planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
perssond
parents:
diff changeset
if __name__ == '__main__':
    # The four paths below used to be hard-coded to the original author's
    # machines. They are now optional command-line flags. Running with no
    # arguments uses the original values, so previous behaviour is unchanged.
    import argparse

    parser = argparse.ArgumentParser(
        description='Train and then deploy the UNet2D Coreograph model.')
    parser.add_argument('--logPath', default='D:\\LSP\\UNet\\Coreograph\\TFLogs',
                        help='directory passed to UNet2D.train as the log path')
    parser.add_argument('--modelPath', default='D:\\LSP\\Coreograph\\model-4layersMaskAug20New',
                        help='model directory passed to UNet2D.train and UNet2D.deploy')
    parser.add_argument('--pmPath', default='D:\\LSP\\UNet\\Coreograph\\TFProbMaps',
                        help='probability-map directory passed to UNet2D.train and UNet2D.deploy')
    parser.add_argument('--imPath', default='Z:/IDAC/Clarence/LSP/CyCIF/TMA/training data custom unaveraged',
                        help='training image directory')
    args = parser.parse_args()

    logPath = args.logPath
    modelPath = args.modelPath
    pmPath = args.pmPath
    imPath = args.imPath

    # ----- test 1 -----

    # imPath = 'D:\\LSP\\UNet\\tonsil20x1bin1chan\\tonsilAnnotations'
    # UNet2D.setup(128,1,2,8,2,2,3,1,0.1,2,8)
    # UNet2D.train(imPath,logPath,modelPath,pmPath,500,100,40,True,20000,1,0)
    UNet2D.setup(128, 1, 2, 20, 2, 2, 3, 2, 0.03, 4, 32)
    UNet2D.train(imPath, logPath, modelPath, pmPath, 2053, 513 , 641, True, 10, 1, 1)
    UNet2D.deploy(imPath,100,modelPath,pmPath,1,1)

    # I = im2double(tifread('/home/mc457/files/CellBiology/IDAC/Marcelo/Etc/UNetTestSets/SinemSaka_NucleiSegmentation_SingleImageInferenceTest3.tif'))
    # UNet2D.singleImageInferenceSetup(modelPath,0)
    # J = UNet2D.singleImageInference(I,'accumulate',0)
    # UNet2D.singleImageInferenceCleanup()
    # # imshowlist([I,J])
    # # sys.exit(0)
    # # tifwrite(np.uint8(255*I),'/home/mc457/Workspace/I1.tif')
    # # tifwrite(np.uint8(255*J),'/home/mc457/Workspace/I2.tif')
    # K = np.zeros((2,I.shape[0],I.shape[1]))
    # K[0,:,:] = I
    # K[1,:,:] = J
    # tifwrite(np.uint8(255*K),'/home/mc457/Workspace/Sinem_NucSeg.tif')

    # UNet2D.singleImageInferenceSetup(modelPath,0)
    # imagePath = 'Y://sorger//data//RareCyte//Connor//Topacio_P2_AF//ashlar//C0078'
    #
    # fileList = glob.glob(imagePath + '//registration//C0078.ome.tif')
    # print(fileList)
    # for iFile in fileList:
    #     fileName = os.path.basename(iFile)
    #     fileNamePrefix = fileName.split(os.extsep, 1)
    #     I = im2double(tifffile.imread(iFile, key=0))
    #     hsize = int((float(I.shape[0])*float(0.75)))
    #     vsize = int((float(I.shape[1])*float(0.75)))
    #     I = resize(I,(hsize,vsize))
    #     J = UNet2D.singleImageInference(I,'accumulate',1)
    #     K = np.zeros((3,I.shape[0],I.shape[1]))
    #     K[2,:,:] = I
    #     K[0,:,:] = J
    #     J = UNet2D.singleImageInference(I, 'accumulate', 2)
    #     K[1, :, :] = J
    #     outputPath = imagePath + '//prob_maps'
    #     if not os.path.exists(outputPath):
    #         os.makedirs(outputPath)
    #     tifwrite(np.uint8(255*K),outputPath + '//' + fileNamePrefix[0] +'_NucSeg.tif')
    # UNet2D.singleImageInferenceCleanup()


    # ----- test 2 -----

    # imPath = '/home/mc457/files/CellBiology/IDAC/Marcelo/Etc/UNetTestSets/ClarenceYapp_NucleiSegmentation'
    # UNet2D.setup(128,1,2,8,2,2,3,1,0.1,3,4)
    # UNet2D.train(imPath,logPath,modelPath,pmPath,800,100,100,False,10,1)
    # UNet2D.deploy(imPath,100,modelPath,pmPath,1)


    # ----- test 3 -----

    # imPath = '/home/mc457/files/CellBiology/IDAC/Marcelo/Etc/UNetTestSets/CarmanLi_CellTypeSegmentation'
    # # UNet2D.setup(256,1,2,8,2,2,3,1,0.1,3,4)
    # # UNet2D.train(imPath,logPath,modelPath,pmPath,1400,100,164,False,10000,1)
    # UNet2D.deploy(imPath,164,modelPath,pmPath,1)


    # ----- test 4 -----

    # imPath = '/home/cicconet/Downloads/TrainSet1'
    # UNet2D.setup(64,1,2,8,2,2,3,1,0.1,3,4)
    # UNet2D.train(imPath,logPath,modelPath,pmPath,200,8,8,False,2000,1,0)
    # # UNet2D.deploy(imPath,164,modelPath,pmPath,1)