comparison toolbox/PartitionOfImage.py @ 0:99308601eaa6 draft

"planemo upload for repository https://github.com/ohsu-comp-bio/UNetCoreograph commit fb90660a1805b3f68fcff80d525b5459c3f7dfd6-dirty"
author perssond
date Wed, 19 May 2021 21:34:38 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:99308601eaa6
1 import numpy as np
2 from toolbox.imtools import *
3 # from toolbox.ftools import *
4 # import sys
5
class PI2D:
    """Split a 2-D image into overlapping square patches and stitch results back.

    All state lives in class attributes, so only one image can be processed
    at a time; every method is called on the class itself, e.g.
    ``PI2D.setup(...)``.

    Typical sequence: ``setup`` -> ``createOutput`` -> for each patch index
    ``getPatch`` / ``patchOutput`` -> ``getValidOutput``.

    The image is zero-padded so that it is tiled exactly by patches of side
    ``PatchSize`` placed every ``SubPatchSize = PatchSize - 2*Margin`` pixels;
    neighbouring patches therefore overlap by ``2*Margin`` pixels.
    """

    Image = None        # image passed to setup()
    PaddedImage = None  # zero-padded copy of Image
    PatchSize = 128     # side length of a patch
    Margin = 14         # border width around the central sub-patch
    SubPatchSize = 100  # PatchSize - 2*Margin, the stride between patches
    PC = None           # patch coordinates, list of [r0, r1, c0, c1]
    NumPatches = 0      # len(PC)
    Output = None       # padded-size output buffer, see createOutput()
    Count = None        # accumulated weights ('accumulate' mode only)
    NR = None           # rows of the original image
    NC = None           # cols of the original image
    NRPI = None         # rows of the padded image
    NCPI = None         # cols of the padded image
    Mode = None         # 'replace' or 'accumulate'
    W = None            # (PatchSize, PatchSize) blending weights

    @staticmethod
    def setup(image, patchSize, margin, mode):
        """Store the image, pad it and compute the patch grid and weights.

        Args:
            image: array of shape (rows, cols) or (channels, rows, cols).
            patchSize: side length of each square patch.
            margin: border width; patches are placed every
                ``patchSize - 2*margin`` pixels.
            mode: 'replace' or 'accumulate'. The value is stored as given;
                ``patchOutput`` only acts on these two values.

        Raises:
            ValueError: if ``image`` is not 2-D or 3-D, if ``margin`` is
                negative, or if ``patchSize - 2*margin`` is not positive.
        """
        nDims = len(image.shape)
        if nDims not in (2, 3):
            raise ValueError(
                'PI2D.setup expects a 2-D or 3-D (channels first) image, '
                'got {} dimension(s)'.format(nDims))
        if margin < 0:
            raise ValueError('margin must be non-negative, got {}'.format(margin))
        subPatchSize = patchSize - 2 * margin
        if subPatchSize <= 0:
            raise ValueError(
                'patchSize ({}) must be larger than 2*margin ({})'.format(
                    patchSize, 2 * margin))

        PI2D.Image = image
        PI2D.PatchSize = patchSize
        PI2D.Margin = margin
        PI2D.SubPatchSize = subPatchSize

        # Blending window: 0 on the outermost ring, rising linearly with the
        # distance to the patch border, 1 once that distance reaches 2*margin.
        W = np.ones((patchSize, patchSize))
        if margin > 0:
            # With margin == 0 patches do not overlap, so a zero border would
            # leave pixels with zero total weight (0/0 in getValidOutput).
            W[[0, -1], :] = 0
            W[:, [0, -1]] = 0
        rampLength = 2 * margin
        for d in range(1, rampLength):
            v = d / rampLength
            W[d, d:-d] = v
            W[-d - 1, d:-d] = v
            W[d:-d, d] = v
            W[d:-d, -d - 1] = v
        PI2D.W = W

        # Rows/cols are always the last two axes (channels come first).
        nr, nc = image.shape[-2:]
        PI2D.NR = nr
        PI2D.NC = nc

        npr = int(np.ceil(nr / subPatchSize))  # number of patch rows
        npc = int(np.ceil(nc / subPatchSize))  # number of patch cols

        nrpi = npr * subPatchSize + 2 * margin  # number of rows in padded image
        ncpi = npc * subPatchSize + 2 * margin  # number of cols in padded image
        PI2D.NRPI = nrpi
        PI2D.NCPI = ncpi

        # The image sits at offset `margin` inside the zero-filled canvas.
        padded = np.zeros(tuple(image.shape[:-2]) + (nrpi, ncpi))
        padded[..., margin:margin + nr, margin:margin + nc] = image
        PI2D.PaddedImage = padded

        # Patch coordinates [r0, r1, c0, c1], row-major over the patch grid.
        PI2D.PC = [
            [r0, r0 + patchSize, c0, c0 + patchSize]
            for r0 in range(0, npr * subPatchSize, subPatchSize)
            for c0 in range(0, npc * subPatchSize, subPatchSize)
        ]
        PI2D.NumPatches = len(PI2D.PC)
        PI2D.Mode = mode  # 'replace' or 'accumulate'

    @staticmethod
    def getPatch(i):
        """Return patch ``i`` of the padded image as a view.

        The result has shape (PatchSize, PatchSize), with a leading channel
        axis when the image given to ``setup`` was 3-D.
        """
        r0, r1, c0, c1 = PI2D.PC[i]
        return PI2D.PaddedImage[..., r0:r1, c0:c1]

    @staticmethod
    def createOutput(nChannels):
        """Allocate the zero-filled float16 output buffer (padded size).

        The buffer is 2-D when ``nChannels == 1`` and
        (nChannels, rows, cols) otherwise. In 'accumulate' mode a matching
        2-D weight-count buffer is allocated as well.
        """
        if nChannels == 1:
            PI2D.Output = np.zeros((PI2D.NRPI, PI2D.NCPI), np.float16)
        else:
            PI2D.Output = np.zeros((nChannels, PI2D.NRPI, PI2D.NCPI), np.float16)
        if PI2D.Mode == 'accumulate':
            PI2D.Count = np.zeros((PI2D.NRPI, PI2D.NCPI), np.float16)

    @staticmethod
    def patchOutput(i, P):
        """Write the result ``P`` for patch ``i`` into the output buffer.

        ``P`` is (PatchSize, PatchSize) or (channels, PatchSize, PatchSize).
        In 'replace' mode the patch overwrites its region; in 'accumulate'
        mode ``P * W`` is added to the output and ``W`` to the counts.
        """
        r0, r1, c0, c1 = PI2D.PC[i]
        accumulate = PI2D.Mode == 'accumulate'
        if accumulate:
            PI2D.Count[r0:r1, c0:c1] += PI2D.W

        nDims = len(P.shape)
        if nDims == 2:
            if accumulate:
                PI2D.Output[r0:r1, c0:c1] += np.multiply(P, PI2D.W)
            elif PI2D.Mode == 'replace':
                PI2D.Output[r0:r1, c0:c1] = P
        elif nDims == 3:
            if accumulate:
                for ch in range(P.shape[0]):
                    PI2D.Output[ch, r0:r1, c0:c1] += np.multiply(P[ch, :, :], PI2D.W)
            elif PI2D.Mode == 'replace':
                PI2D.Output[:, r0:r1, c0:c1] = P

    @staticmethod
    def getValidOutput():
        """Return the output cropped to the extent of the original image.

        In 'accumulate' mode the accumulated values are divided by the
        accumulated weights. The division produces a new array and never
        modifies ``PI2D.Output``, so the method can be called repeatedly.
        In 'replace' mode a view of ``PI2D.Output`` is returned.
        """
        margin = PI2D.Margin
        rows = slice(margin, margin + PI2D.NR)
        cols = slice(margin, margin + PI2D.NC)
        accumulate = PI2D.Mode == 'accumulate'

        if PI2D.Output.ndim == 2:
            valid = PI2D.Output[rows, cols]
            if accumulate:
                return np.divide(valid, PI2D.Count[rows, cols])
            if PI2D.Mode == 'replace':
                return valid
            return None

        if PI2D.Output.ndim == 3:
            valid = PI2D.Output[:, rows, cols]
            if accumulate:
                # The (rows, cols) counts broadcast over the channel axis.
                return np.divide(valid, PI2D.Count[rows, cols])
            return valid

        return None

    @staticmethod
    def demo():
        """Round-trip a random image through the patch pipeline and show it.

        Prints the maximum absolute difference between input and
        reconstruction, then displays input, reconstruction and difference.
        Relies on ``cat`` and ``imshow`` from the module-level star import
        of ``toolbox.imtools``.
        """
        I = np.random.rand(128, 128)
        PI2D.setup(I, 64, 4, 'replace')

        nChannels = 2
        PI2D.createOutput(nChannels)

        for i in range(PI2D.NumPatches):
            P = PI2D.getPatch(i)
            # Replicate the patch into every channel as a stand-in result.
            Q = np.zeros((nChannels, P.shape[0], P.shape[1]))
            for j in range(nChannels):
                Q[j, :, :] = P
            PI2D.patchOutput(i, Q)

        J = PI2D.getValidOutput()
        J = J[0, :, :]

        D = np.abs(I - J)
        print(np.max(D))

        K = cat(1, cat(1, I, J), D)
        imshow(K)
148
149
class PI3D:
    """Split a 3-D volume into overlapping cubic patches and stitch results back.

    All state lives in class attributes, so only one volume can be processed
    at a time; every method is called on the class itself, e.g.
    ``PI3D.setup(...)``.

    Typical sequence: ``setup`` -> ``createOutput`` -> for each patch index
    ``getPatch`` / ``patchOutput`` -> ``getValidOutput``.

    The volume is zero-padded so that it is tiled exactly by cubes of side
    ``PatchSize`` placed every ``SubPatchSize = PatchSize - 2*Margin`` voxels
    along each spatial axis; neighbouring patches overlap by ``2*Margin``.
    """

    Image = None        # volume passed to setup()
    PaddedImage = None  # zero-padded copy of Image
    PatchSize = 128     # side length of a patch
    Margin = 14         # border width around the central sub-patch
    SubPatchSize = 100  # PatchSize - 2*Margin, the stride between patches
    PC = None           # patch coordinates, list of [z0, z1, r0, r1, c0, c1]
    NumPatches = 0      # len(PC)
    Output = None       # padded-size output buffer, see createOutput()
    Count = None        # accumulated weights ('accumulate' mode only)
    NR = None           # rows
    NC = None           # cols
    NZ = None           # planes
    NRPI = None         # rows of the padded volume
    NCPI = None         # cols of the padded volume
    NZPI = None         # planes of the padded volume
    Mode = None         # 'replace' or 'accumulate'
    W = None            # (PatchSize, PatchSize, PatchSize) blending weights

    @staticmethod
    def setup(image, patchSize, margin, mode):
        """Store the volume, pad it and compute the patch grid and weights.

        Args:
            image: array of shape (planes, rows, cols) or
                (planes, channels, rows, cols).
            patchSize: side length of each cubic patch.
            margin: border width; patches are placed every
                ``patchSize - 2*margin`` voxels.
            mode: 'replace' or 'accumulate'. The value is stored as given;
                ``patchOutput`` only acts on these two values.

        Raises:
            ValueError: if ``image`` is not 3-D or 4-D, if ``margin`` is
                negative, or if ``patchSize - 2*margin`` is not positive.
        """
        nDims = len(image.shape)
        if nDims == 3:
            nz, nr, nc = image.shape
            nw = None
        elif nDims == 4:  # multi-channel image
            nz, nw, nr, nc = image.shape
        else:
            raise ValueError(
                'PI3D.setup expects a 3-D or 4-D (planes, channels, rows, cols) '
                'image, got {} dimension(s)'.format(nDims))
        if margin < 0:
            raise ValueError('margin must be non-negative, got {}'.format(margin))
        subPatchSize = patchSize - 2 * margin
        if subPatchSize <= 0:
            raise ValueError(
                'patchSize ({}) must be larger than 2*margin ({})'.format(
                    patchSize, 2 * margin))

        PI3D.Image = image
        PI3D.PatchSize = patchSize
        PI3D.Margin = margin
        PI3D.SubPatchSize = subPatchSize

        # Blending window: 0 on the outermost shell, rising linearly with the
        # distance to the patch surface, 1 once that distance reaches 2*margin.
        W = np.ones((patchSize, patchSize, patchSize))
        if margin > 0:
            # With margin == 0 patches do not overlap, so a zero shell would
            # leave voxels with zero total weight (0/0 in getValidOutput).
            W[[0, -1], :, :] = 0
            W[:, [0, -1], :] = 0
            W[:, :, [0, -1]] = 0
        rampLength = 2 * margin
        for d in range(1, rampLength):
            v = d / rampLength
            W[[d, -d - 1], d:-d, d:-d] = v
            W[d:-d, [d, -d - 1], d:-d] = v
            W[d:-d, d:-d, [d, -d - 1]] = v
        PI3D.W = W

        PI3D.NR = nr
        PI3D.NC = nc
        PI3D.NZ = nz

        npr = int(np.ceil(nr / subPatchSize))  # number of patch rows
        npc = int(np.ceil(nc / subPatchSize))  # number of patch cols
        npz = int(np.ceil(nz / subPatchSize))  # number of patch planes

        nrpi = npr * subPatchSize + 2 * margin  # number of rows in padded image
        ncpi = npc * subPatchSize + 2 * margin  # number of cols in padded image
        nzpi = npz * subPatchSize + 2 * margin  # number of plns in padded image
        PI3D.NRPI = nrpi
        PI3D.NCPI = ncpi
        PI3D.NZPI = nzpi

        # The volume sits at offset `margin` inside the zero-filled canvas.
        zs = slice(margin, margin + nz)
        rs = slice(margin, margin + nr)
        cs = slice(margin, margin + nc)
        if nDims == 3:
            PI3D.PaddedImage = np.zeros((nzpi, nrpi, ncpi))
            PI3D.PaddedImage[zs, rs, cs] = image
        else:
            PI3D.PaddedImage = np.zeros((nzpi, nw, nrpi, ncpi))
            PI3D.PaddedImage[zs, :, rs, cs] = image

        # Patch coordinates [z0, z1, r0, r1, c0, c1]; planes vary slowest.
        PI3D.PC = [
            [z0, z0 + patchSize, r0, r0 + patchSize, c0, c0 + patchSize]
            for z0 in range(0, npz * subPatchSize, subPatchSize)
            for r0 in range(0, npr * subPatchSize, subPatchSize)
            for c0 in range(0, npc * subPatchSize, subPatchSize)
        ]
        PI3D.NumPatches = len(PI3D.PC)
        PI3D.Mode = mode  # 'replace' or 'accumulate'

    @staticmethod
    def getPatch(i):
        """Return patch ``i`` of the padded volume as a view.

        The channel axis (second axis) is kept whole when the volume given
        to ``setup`` was 4-D.
        """
        z0, z1, r0, r1, c0, c1 = PI3D.PC[i]
        if len(PI3D.PaddedImage.shape) == 3:
            return PI3D.PaddedImage[z0:z1, r0:r1, c0:c1]
        if len(PI3D.PaddedImage.shape) == 4:
            return PI3D.PaddedImage[z0:z1, :, r0:r1, c0:c1]

    @staticmethod
    def createOutput(nChannels):
        """Allocate the zero-filled output buffer (padded size).

        The buffer is 3-D when ``nChannels == 1`` and
        (planes, nChannels, rows, cols) otherwise. In 'accumulate' mode a
        matching 3-D weight-count buffer is allocated as well.
        """
        if nChannels == 1:
            PI3D.Output = np.zeros((PI3D.NZPI, PI3D.NRPI, PI3D.NCPI))
        else:
            PI3D.Output = np.zeros((PI3D.NZPI, nChannels, PI3D.NRPI, PI3D.NCPI))
        if PI3D.Mode == 'accumulate':
            PI3D.Count = np.zeros((PI3D.NZPI, PI3D.NRPI, PI3D.NCPI))

    @staticmethod
    def patchOutput(i, P):
        """Write the result ``P`` for patch ``i`` into the output buffer.

        ``P`` is (PatchSize, PatchSize, PatchSize) or has an extra channel
        axis in second position. In 'replace' mode the patch overwrites its
        region; in 'accumulate' mode ``P * W`` is added to the output and
        ``W`` to the counts.
        """
        z0, z1, r0, r1, c0, c1 = PI3D.PC[i]
        accumulate = PI3D.Mode == 'accumulate'
        if accumulate:
            PI3D.Count[z0:z1, r0:r1, c0:c1] += PI3D.W

        nDims = len(P.shape)
        if nDims == 3:
            if accumulate:
                PI3D.Output[z0:z1, r0:r1, c0:c1] += np.multiply(P, PI3D.W)
            elif PI3D.Mode == 'replace':
                PI3D.Output[z0:z1, r0:r1, c0:c1] = P
        elif nDims == 4:
            if accumulate:
                for ch in range(P.shape[1]):
                    PI3D.Output[z0:z1, ch, r0:r1, c0:c1] += np.multiply(P[:, ch, :, :], PI3D.W)
            elif PI3D.Mode == 'replace':
                PI3D.Output[z0:z1, :, r0:r1, c0:c1] = P

    @staticmethod
    def getValidOutput():
        """Return the output cropped to the extent of the original volume.

        In 'accumulate' mode the accumulated values are divided by the
        accumulated weights. The division produces a new array and never
        modifies ``PI3D.Output``, so the method can be called repeatedly.
        In 'replace' mode a view of ``PI3D.Output`` is returned.
        """
        margin = PI3D.Margin
        planes = slice(margin, margin + PI3D.NZ)
        rows = slice(margin, margin + PI3D.NR)
        cols = slice(margin, margin + PI3D.NC)
        accumulate = PI3D.Mode == 'accumulate'

        if PI3D.Output.ndim == 3:
            valid = PI3D.Output[planes, rows, cols]
            if accumulate:
                return np.divide(valid, PI3D.Count[planes, rows, cols])
            if PI3D.Mode == 'replace':
                return valid
            return None

        if PI3D.Output.ndim == 4:
            valid = PI3D.Output[planes, :, rows, cols]
            if accumulate:
                counts = PI3D.Count[planes, rows, cols]
                # Insert a length-1 channel axis so the counts broadcast.
                return np.divide(valid, counts[:, np.newaxis, :, :])
            return valid

        return None

    @staticmethod
    def demo():
        """Round-trip a random volume through the patch pipeline and show it.

        Prints the maximum absolute difference between input and
        reconstruction, then displays plane 64 of input, reconstruction and
        difference. Relies on ``cat`` and ``imshow`` from the module-level
        star import of ``toolbox.imtools``.
        """
        I = np.random.rand(128, 128, 128)
        PI3D.setup(I, 64, 4, 'accumulate')

        nChannels = 2
        PI3D.createOutput(nChannels)

        for i in range(PI3D.NumPatches):
            P = PI3D.getPatch(i)
            # Replicate the patch into every channel as a stand-in result.
            Q = np.zeros((P.shape[0], nChannels, P.shape[1], P.shape[2]))
            for j in range(nChannels):
                Q[:, j, :, :] = P
            PI3D.patchOutput(i, Q)

        J = PI3D.getValidOutput()
        J = J[:, 0, :, :]

        D = np.abs(I - J)
        print(np.max(D))

        pI = I[64, :, :]
        pJ = J[64, :, :]
        pD = D[64, :, :]

        K = cat(1, cat(1, pI, pJ), pD)
        imshow(K)
305