Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/alphafold/model/quat_affine.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Quaternion geometry modules. | |
16 | |
17 This introduces a representation of coordinate frames that is based around a | |
18 ‘QuatAffine’ object. This object describes an array of coordinate frames. | |
19 It consists of vectors corresponding to the | |
20 origin of the frames as well as orientations which are stored in two | |
21 ways, as unit quaternions as well as rotation matrices. | |
22 The rotation matrices are derived from the unit quaternions and the two are kept | |
23 in sync. | |
24 For an explanation of the relation between unit quaternions and rotations see | |
25 https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation | |
26 | |
27 This representation is used in the model for the backbone frames. | |
28 | |
29 One important thing to note here is that, while we update both representations, | |
30 the jit compiler is going to ensure that only the parts that are | |
31 actually used are executed. | |
32 """ | |
33 | |
34 | |
35 import functools | |
36 from typing import Tuple | |
37 | |
38 import jax | |
39 import jax.numpy as jnp | |
40 import numpy as np | |
41 | |
42 # pylint: disable=bad-whitespace | |
43 QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) | |
44 | |
45 QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr | |
46 QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii | |
47 QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj | |
48 QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk | |
49 | |
50 QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij | |
51 QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik | |
52 QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk | |
53 | |
54 QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir | |
55 QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr | |
56 QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr | |
57 | |
58 QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) | |
59 QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], | |
60 [ 0,-1, 0, 0], | |
61 [ 0, 0,-1, 0], | |
62 [ 0, 0, 0,-1]] | |
63 | |
64 QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], | |
65 [ 1, 0, 0, 0], | |
66 [ 0, 0, 0, 1], | |
67 [ 0, 0,-1, 0]] | |
68 | |
69 QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], | |
70 [ 0, 0, 0,-1], | |
71 [ 1, 0, 0, 0], | |
72 [ 0, 1, 0, 0]] | |
73 | |
74 QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], | |
75 [ 0, 0, 1, 0], | |
76 [ 0,-1, 0, 0], | |
77 [ 1, 0, 0, 0]] | |
78 | |
79 QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :] | |
80 # pylint: enable=bad-whitespace | |
81 | |
82 | |
def rot_to_quat(rot, unstack_inputs=False):
    """Convert rotation matrix to quaternion.

    Note that this function performs a symmetric eigendecomposition
    (jnp.linalg.eigh), which is extremely expensive on the GPU. If at all
    possible, this function should run on the CPU.

    Args:
      rot: rotation matrix (see below for format).
      unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
        otherwise the rotation matrix should be a list of lists of tensors.

    Returns:
      Quaternion as (..., 4) tensor.
    """
    if unstack_inputs:
        # Turn the (..., 3, 3) tensor into a 3x3 nested list of tensors.
        rot = [jnp.moveaxis(row, -1, 0) for row in jnp.moveaxis(rot, -2, 0)]

    (xx, xy, xz), (yx, yy, yz), (zx, zy, zz) = rot

    # Rows of the symmetric 4x4 matrix whose eigenvectors are examined below.
    # pylint: disable=bad-whitespace
    rows = (
        (xx + yy + zz, zy - yz,      xz - zx,      yx - xy),
        (zy - yz,      xx - yy - zz, xy + yx,      xz + zx),
        (xz - zx,      xy + yx,      yy - xx - zz, yz + zy),
        (yx - xy,      xz + zx,      yz + zy,      zz - xx - yy),
    )
    # pylint: enable=bad-whitespace

    stacked = jnp.stack([jnp.stack(row, axis=-1) for row in rows], axis=-2)
    k_matrix = (1. / 3.) * stacked

    # Get eigenvalues in non-decreasing order and the associated eigenvectors;
    # the eigenvector of the largest eigenvalue is returned as the quaternion.
    _, eigenvectors = jnp.linalg.eigh(k_matrix)
    return eigenvectors[..., -1]
115 | |
116 | |
def rot_list_to_tensor(rot_list):
    """Convert list of lists to rotation tensor.

    Args:
      rot_list: Nested list whose first three rows each hold the tensors of
        one matrix row.

    Returns:
      Tensor in which the entries of each row are stacked on the last axis and
      the three rows are stacked on the second-to-last axis.
    """
    stacked_rows = [jnp.stack(rot_list[row], axis=-1) for row in range(3)]
    return jnp.stack(stacked_rows, axis=-2)
124 | |
125 | |
def vec_list_to_tensor(vec_list):
    """Convert list to vector tensor.

    Args:
      vec_list: List of equally shaped tensors, one per vector component.

    Returns:
      Tensor with the components stacked along a new last axis.
    """
    stacked = jnp.stack(vec_list, axis=-1)
    return stacked
129 | |
130 | |
def quat_to_rot(normalized_quat):
    """Convert a normalized quaternion to a rotation matrix.

    Args:
      normalized_quat: Tensor whose last axis holds the 4 quaternion
        components.

    Returns:
      The rotation as a 3x3 nested list of tensors (row-major).
    """
    # Flatten each 3x3 coefficient matrix so that one sum yields all 9 entries.
    coefficients = np.reshape(QUAT_TO_ROT, (4, 4, 9))
    first = normalized_quat[..., :, None, None]
    second = normalized_quat[..., None, :, None]
    flat_rot = jnp.sum(coefficients * first * second, axis=(-3, -2))

    # Move the 9 entries to the front so they can be indexed individually.
    entries = jnp.moveaxis(flat_rot, -1, 0)
    return [[entries[3 * row + col] for col in range(3)] for row in range(3)]
142 | |
143 | |
def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion.

    Args:
      quat: Tensor whose last axis holds the 4 quaternion components.
      vec: Tensor whose last axis holds the 3 vector components.

    Returns:
      Tensor whose last axis holds the 4 components of the product.
    """
    left = quat[..., :, None, None]
    right = vec[..., None, :, None]
    # Sum over both operand axes, leaving the output-component axis.
    weighted = QUAT_MULTIPLY_BY_VEC * left * right
    return jnp.sum(weighted, axis=(-3, -2))
151 | |
152 | |
def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion.

    Args:
      quat1: Left operand; last axis holds the 4 quaternion components.
      quat2: Right operand; last axis holds the 4 quaternion components.

    Returns:
      Tensor whose last axis holds the 4 components of the product.
    """
    left = quat1[..., :, None, None]
    right = quat2[..., None, :, None]
    # Sum over both operand axes, leaving the output-component axis.
    weighted = QUAT_MULTIPLY * left * right
    return jnp.sum(weighted, axis=(-3, -2))
160 | |
161 | |
def apply_rot_to_vec(rot, vec, unstack=False):
    """Multiply rotation matrix by a vector.

    Args:
      rot: Rotation as a 3x3 nested sequence, indexed as rot[row][col].
      vec: If unstack is False, a sequence of exactly 3 components. If unstack
        is True, a single array whose LAST axis holds the 3 components; any
        number of leading axes is accepted.
      unstack: Whether vec has to be split along its last axis first.

    Returns:
      List of the 3 components of the rotated vector.
    """
    if unstack:
        # Split on the last axis (not a fixed axis 1) so that unbatched and
        # multiply-batched inputs work, matching how the rest of this module
        # unstacks tensors.
        x, y, z = [vec[..., i] for i in range(3)]
    else:
        x, y, z = vec
    return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
            rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
            rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]
171 | |
172 | |
def apply_inverse_rot_to_vec(rot, vec):
    """Multiply the inverse of a rotation matrix by a vector.

    Args:
      rot: Rotation as a 3x3 nested sequence, indexed as rot[row][col].
      vec: Sequence whose first three items are the vector components.

    Returns:
      List of the 3 components of the result.
    """
    # The inverse of a rotation is its transpose, so output component `col`
    # is the dot product of vec with column `col` of rot.
    return [
        rot[0][col] * vec[0] + rot[1][col] * vec[1] + rot[2][col] * vec[2]
        for col in range(3)
    ]
179 | |
180 | |
class QuatAffine:
    """Affine transformation represented by quaternion and vector.

    Attributes:
      quaternion: Tensor whose last axis holds the 4 quaternion components
        (may be None if None was passed to the constructor).
      rotation: The rotation as a 3x3 nested list of tensors.
      translation: The translation as a list of 3 tensors.
    """

    def __init__(self, quaternion, translation, rotation=None, normalize=True,
                 unstack_inputs=False):
        """Initialize from quaternion and translation.

        Args:
          quaternion: Rotation represented by a quaternion, to be applied
            before translation. Must be a unit quaternion unless
            normalize==True.
          translation: Translation represented as a vector: a sequence of 3
            tensors, or a single tensor with last dimension 3 if
            unstack_inputs is True.
          rotation: Same rotation as the quaternion: a 3x3 nested sequence of
            tensors, or a single (..., 3, 3) tensor if unstack_inputs is True.
            If None, rotation will be calculated from the quaternion.
          normalize: If True, l2 normalize the quaternion on input.
          unstack_inputs: If True, translation (and rotation, when given) are
            single tensors that are unstacked along their trailing axes.
        """
        if quaternion is not None:
            assert quaternion.shape[-1] == 4

        if unstack_inputs:
            if rotation is not None:
                rotation = [jnp.moveaxis(x, -1, 0)  # Unstack.
                            for x in jnp.moveaxis(rotation, -2, 0)]  # Unstack.
            translation = jnp.moveaxis(translation, -1, 0)  # Unstack.

        if normalize and quaternion is not None:
            quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1,
                                                      keepdims=True)

        if rotation is None:
            rotation = quat_to_rot(quaternion)

        self.quaternion = quaternion
        self.rotation = [list(row) for row in rotation]
        self.translation = list(translation)

        assert all(len(row) == 3 for row in self.rotation)
        assert len(self.translation) == 3

    def to_tensor(self):
        """Return quaternion and translation concatenated on the last axis."""
        return jnp.concatenate(
            [self.quaternion] +
            [jnp.expand_dims(x, axis=-1) for x in self.translation],
            axis=-1)

    def apply_tensor_fn(self, tensor_fn):
        """Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient)."""
        return QuatAffine(
            tensor_fn(self.quaternion),
            [tensor_fn(x) for x in self.translation],
            rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
            normalize=False)

    def apply_rotation_tensor_fn(self, tensor_fn):
        """Return a new QuatAffine with tensor_fn applied to the rotation part."""
        return QuatAffine(
            tensor_fn(self.quaternion),
            list(self.translation),
            rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
            normalize=False)

    def scale_translation(self, position_scale):
        """Return a new quat affine with a different scale for translation."""
        return QuatAffine(
            self.quaternion,
            [x * position_scale for x in self.translation],
            rotation=[list(row) for row in self.rotation],
            normalize=False)

    @classmethod
    def from_tensor(cls, tensor, normalize=False):
        """Build an instance from a tensor with last dimension 7.

        The first 4 entries of the last axis are the quaternion, the remaining
        3 the translation (the layout produced by to_tensor).
        """
        quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
        return cls(quaternion,
                   [tx[..., 0], ty[..., 0], tz[..., 0]],
                   normalize=normalize)

    def pre_compose(self, update):
        """Return a new QuatAffine which applies the transformation update first.

        Args:
          update: Length-6 vector. 3-vector of x, y, and z such that the
            quaternion update is (1, x, y, z) and zero for the 3-vector is the
            identity quaternion. 3-vector for translation concatenated.

        Returns:
          New QuatAffine object.
        """
        vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1)
        trans_update = [jnp.squeeze(x, axis=-1),
                        jnp.squeeze(y, axis=-1),
                        jnp.squeeze(z, axis=-1)]

        # q * (1, v) == q + q * (0, v); the result is re-normalized by the
        # QuatAffine constructor below (normalize defaults to True).
        new_quaternion = (self.quaternion +
                          quat_multiply_by_vec(self.quaternion,
                                               vector_quaternion_update))

        # The translation update is rotated by the current rotation before
        # being added to the current translation.
        trans_update = apply_rot_to_vec(self.rotation, trans_update)
        new_translation = [
            self.translation[0] + trans_update[0],
            self.translation[1] + trans_update[1],
            self.translation[2] + trans_update[2]]

        return QuatAffine(new_quaternion, new_translation)

    def _expand_trailing_dims(self, extra_dims):
        """Return (rotation, translation) with extra_dims trailing unit axes."""
        rotation = self.rotation
        translation = self.translation
        expand_fn = functools.partial(jnp.expand_dims, axis=-1)
        for _ in range(extra_dims):
            # jax.tree_util.tree_map is used because the top-level
            # jax.tree_map alias is deprecated and absent from recent JAX.
            rotation = jax.tree_util.tree_map(expand_fn, rotation)
            translation = jax.tree_util.tree_map(expand_fn, translation)
        return rotation, translation

    def apply_to_point(self, point, extra_dims=0):
        """Apply affine to a point.

        Args:
          point: List of 3 tensors to apply affine.
          extra_dims: Number of dimensions at the end of the transformed_point
            shape that are not present in the rotation and translation. The
            most common use is rotation N points at once with extra_dims=1 for
            use in a network.

        Returns:
          Transformed point after applying affine.
        """
        rotation, translation = self._expand_trailing_dims(extra_dims)

        rot_point = apply_rot_to_vec(rotation, point)
        return [
            rot_point[0] + translation[0],
            rot_point[1] + translation[1],
            rot_point[2] + translation[2]]

    def invert_point(self, transformed_point, extra_dims=0):
        """Apply inverse of transformation to a point.

        Args:
          transformed_point: List of 3 tensors to apply the inverse affine to.
          extra_dims: Number of dimensions at the end of the transformed_point
            shape that are not present in the rotation and translation. The
            most common use is rotation N points at once with extra_dims=1 for
            use in a network.

        Returns:
          Point after applying the inverse affine: the translation is
          subtracted first, then the inverse rotation is applied.
        """
        rotation, translation = self._expand_trailing_dims(extra_dims)

        rot_point = [
            transformed_point[0] - translation[0],
            transformed_point[1] - translation[1],
            transformed_point[2] - translation[2]]

        return apply_inverse_rot_to_vec(rotation, rot_point)

    def __repr__(self):
        return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation)
342 | |
343 | |
344 def _multiply(a, b): | |
345 return jnp.stack([ | |
346 jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0], | |
347 a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1], | |
348 a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]), | |
349 | |
350 jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0], | |
351 a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1], | |
352 a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]), | |
353 | |
354 jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0], | |
355 a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1], | |
356 a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])]) | |
357 | |
358 | |
def make_canonical_transform(
    n_xyz: jnp.ndarray,
    ca_xyz: jnp.ndarray,
    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns translation and rotation matrices to canonicalize residue atoms.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
      n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
      ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
      c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

    Returns:
      A tuple (translation, rotation) where:
        translation is an array of shape [batch, 3] defining the translation.
        rotation is an array of shape [batch, 3, 3] defining the rotation.
      After applying the translation and rotation to all atoms in a residue:
        * All atoms will be shifted so that CA is at the origin,
        * All atoms will be rotated so that C is at the x-axis,
        * All atoms will be rotated so that N is in the xy plane.
    """
    assert len(n_xyz.shape) == 2, n_xyz.shape
    assert n_xyz.shape[-1] == 3, n_xyz.shape
    assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
        n_xyz.shape, ca_xyz.shape, c_xyz.shape)

    # Place CA at the origin.
    translation = -ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Place C on the x-axis.
    c_x, c_y, c_z = c_xyz[:, 0], c_xyz[:, 1], c_xyz[:, 2]

    # Rotate by angle c1 in the x-y plane (around the z-axis). The 1e-20 term
    # keeps the denominator non-zero.
    c_norm_xy = jnp.sqrt(1e-20 + c_x**2 + c_y**2)
    sin_c1 = -c_y / c_norm_xy
    cos_c1 = c_x / c_norm_xy
    zeros = jnp.zeros_like(sin_c1)
    ones = jnp.ones_like(sin_c1)
    # pylint: disable=bad-whitespace
    c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
                               jnp.array([sin_c1,  cos_c1, zeros]),
                               jnp.array([zeros,   zeros,  ones])])

    # Rotate by angle c2 in the x-z plane (around the y-axis).
    c_norm_xyz = jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
    sin_c2 = c_z / c_norm_xyz
    cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / c_norm_xyz
    c2_rot_matrix = jnp.stack([jnp.array([cos_c2,  zeros, sin_c2]),
                               jnp.array([zeros,   ones,  zeros]),
                               jnp.array([-sin_c2, zeros, cos_c2])])

    c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
    # Apply the combined C rotation to N; the result is brought back to
    # [batch, 3] layout by stacking the components and transposing.
    n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T

    # Place N in the x-y plane.
    n_y, n_z = n_xyz[:, 1], n_xyz[:, 2]
    # Rotate by angle alpha in the y-z plane (around the x-axis).
    n_norm_yz = jnp.sqrt(1e-20 + n_y**2 + n_z**2)
    sin_n = -n_z / n_norm_yz
    cos_n = n_y / n_norm_yz
    n_rot_matrix = jnp.stack([jnp.array([ones,  zeros,  zeros]),
                              jnp.array([zeros, cos_n, -sin_n]),
                              jnp.array([zeros, sin_n,  cos_n])])
    # pylint: enable=bad-whitespace

    # The matrices above are laid out as [3, 3, batch]; move batch to the front.
    full_rotation = _multiply(n_rot_matrix, c_rot_matrix)
    return translation, jnp.transpose(full_rotation, [2, 0, 1])
429 | |
430 | |
def make_transform_from_reference(
    n_xyz: jnp.ndarray,
    ca_xyz: jnp.ndarray,
    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns rotation and translation matrices to convert from reference.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
      n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
      ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
      c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

    Returns:
      A tuple (rotation, translation) where:
        rotation is an array of shape [batch, 3, 3] defining the rotation.
        translation is an array of shape [batch, 3] defining the translation.
      After applying the translation and rotation to the reference backbone,
      the coordinates will approximately equal to the input coordinates.

      The order of translation and rotation differs from
      make_canonical_transform because the rotation from this function should
      be applied before the translation, unlike make_canonical_transform.
    """
    canonical_translation, canonical_rotation = make_canonical_transform(
        n_xyz, ca_xyz, c_xyz)
    # Invert the canonicalizing transform: a rotation is inverted by
    # transposing its two matrix axes, a translation by negating it.
    inverse_rotation = jnp.transpose(canonical_rotation, (0, 2, 1))
    return inverse_rotation, -canonical_translation