comparison docker/alphafold/alphafold/model/quat_affine.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
comparison
equal deleted inserted replaced
0:7ae9d78b06f5 1:6c92e000d684
1 # Copyright 2021 DeepMind Technologies Limited
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Quaternion geometry modules.
16
17 This introduces a representation of coordinate frames that is based around a
18 ‘QuatAffine’ object. This object describes an array of coordinate frames.
19 It consists of vectors corresponding to the
20 origin of the frames as well as orientations, which are stored in two
21 ways: as unit quaternions and as rotation matrices.
22 The rotation matrices are derived from the unit quaternions and the two are kept
23 in sync.
24 For an explanation of the relation between unit quaternions and rotations see
25 https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
26
27 This representation is used in the model for the backbone frames.
28
29 One important thing to note here is that, while we update both representations,
30 the jit compiler is going to ensure that only the parts that are
31 actually used are executed.
32 """
33
34
35 import functools
36 from typing import Tuple
37
38 import jax
39 import jax.numpy as jnp
40 import numpy as np
41
# pylint: disable=bad-whitespace
# QUAT_TO_ROT[a, b] is the 3x3 matrix of coefficients of q[a] * q[b] in the
# rotation matrix of a unit quaternion q ordered (r, i, j, k). Contracting it
# with the outer product q q^T (see quat_to_rot) yields the rotation matrix.
# Only entries with a <= b are filled; the remaining ones stay zero, so each
# cross term is counted once and its coefficient already carries the factor 2.
QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)

# Squared terms only touch the diagonal.
QUAT_TO_ROT[0, 0] = np.diag([ 1,  1,  1])  # rr
QUAT_TO_ROT[1, 1] = np.diag([ 1, -1, -1])  # ii
QUAT_TO_ROT[2, 2] = np.diag([-1,  1, -1])  # jj
QUAT_TO_ROT[3, 3] = np.diag([-1, -1,  1])  # kk

# Products of two imaginary components give symmetric off-diagonal entries.
QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]]  # ij
QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]]  # ik
QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]]  # jk

# Products with the real component give antisymmetric off-diagonal entries.
QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]]  # ir
QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]]  # jr
QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]]  # kr

# QUAT_MULTIPLY[a, b, c] is the coefficient of q1[a] * q2[b] in component c of
# the Hamilton product q1 * q2, with both quaternions ordered (r, i, j, k).
QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)

# Real part: r1*r2 - i1*i2 - j1*j2 - k1*k2.
QUAT_MULTIPLY[:, :, 0] = np.diag([1, -1, -1, -1])

# i part: r1*i2 + i1*r2 + j1*k2 - k1*j2.
QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0],
                          [1, 0, 0, 0],
                          [0, 0, 0, 1],
                          [0, 0, -1, 0]]

# j part: r1*j2 - i1*k2 + j1*r2 + k1*i2.
QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0],
                          [0, 0, 0, -1],
                          [1, 0, 0, 0],
                          [0, 1, 0, 0]]

# k part: r1*k2 + i1*j2 - j1*i2 + k1*r2.
QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1],
                          [0, 0, 1, 0],
                          [0, -1, 0, 0],
                          [1, 0, 0, 0]]

# For a pure-vector second operand (0, x, y, z) only the i, j, k slices along
# the second-operand axis contribute.
QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :]
# pylint: enable=bad-whitespace
81
82
def rot_to_quat(rot, unstack_inputs=False):
    """Convert rotation matrix to quaternion.

    The quaternion is taken as the eigenvector belonging to the largest
    eigenvalue of a symmetric 4x4 matrix built from the rotation entries. The
    eigendecomposition uses jnp.linalg.eigh, which can be expensive on
    accelerators; if at all possible, run this function on the CPU.

    Args:
      rot: rotation matrix. If unstack_inputs is True, an array of shape
        (..., 3, 3); otherwise a 3x3 nested list (rot[row][col]) of arrays.
      unstack_inputs: If True, `rot` is a (..., 3, 3) array that is first
        unstacked into the nested-list form.

    Returns:
      Quaternion as a (..., 4) array ordered (r, i, j, k). The sign is not
      canonicalized: q and -q describe the same rotation and either may be
      returned.
    """
    if unstack_inputs:
        # Bring the row axis, then the column axis, to the front so the array
        # unpacks like a nested list rot[row][col].
        rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)]

    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot

    # pylint: disable=bad-whitespace
    k = [[xx + yy + zz, zy - yz,      xz - zx,      yx - xy],
         [zy - yz,      xx - yy - zz, xy + yx,      xz + zx],
         [xz - zx,      xy + yx,      yy - xx - zz, yz + zy],
         [yx - xy,      xz + zx,      yz + zy,      zz - xx - yy]]
    # pylint: enable=bad-whitespace

    k = (1. / 3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k],
                              axis=-2)

    # eigh returns eigenvalues in ascending order, with the matching
    # eigenvectors as the columns of `qs`; the last column therefore belongs to
    # the largest eigenvalue.
    _, qs = jnp.linalg.eigh(k)
    return qs[..., -1]
115
116
def rot_list_to_tensor(rot_list):
    """Stack a 3x3 nested list of arrays into a single (..., 3, 3) array.

    Args:
      rot_list: Nested list where rot_list[row][col] is an array; only the
        first three rows are read.

    Returns:
      Array with the row index on axis -2 and the column index on axis -1.
    """
    stacked_rows = [jnp.stack(rot_list[row], axis=-1) for row in range(3)]
    return jnp.stack(stacked_rows, axis=-2)
124
125
def vec_list_to_tensor(vec_list):
    """Stack per-coordinate arrays along a new trailing axis.

    Args:
      vec_list: Sequence of same-shaped arrays, e.g. the (x, y, z) components.

    Returns:
      Array of shape (..., len(vec_list)).
    """
    components = vec_list
    return jnp.stack(components, axis=-1)
129
130
def quat_to_rot(normalized_quat):
    """Build the 3x3 rotation matrix of a normalized quaternion.

    Args:
      normalized_quat: Array whose last axis has size 4, holding unit
        quaternions in the (r, i, j, k) order used by QUAT_TO_ROT.

    Returns:
      3x3 nested list; entry [row][col] is an array with the leading shape of
      normalized_quat.
    """
    coeffs = np.reshape(QUAT_TO_ROT, (4, 4, 9))
    # Weight each coefficient matrix by q[a] and then by q[b] (same
    # multiplication order as before), then sum out both quaternion axes.
    weighted = coeffs * normalized_quat[..., :, None, None]
    weighted = weighted * normalized_quat[..., None, :, None]
    flat = jnp.sum(weighted, axis=(-3, -2))
    # Move the 9 flattened matrix entries to the front so they can be indexed.
    entries = jnp.moveaxis(flat, -1, 0)
    return [[entries[3 * row + col] for col in range(3)] for row in range(3)]
142
143
def quat_multiply_by_vec(quat, vec):
    """Hamilton product of a quaternion with a pure-vector quaternion.

    Args:
      quat: Array whose last axis has size 4, ordered (r, i, j, k).
      vec: Array whose last axis has size 3, treated as the quaternion
        (0, vec[..., 0], vec[..., 1], vec[..., 2]).

    Returns:
      Array whose last axis has size 4 holding quat * (0, vec).
    """
    lhs = quat[..., :, None, None]
    rhs = vec[..., None, :, None]
    terms = QUAT_MULTIPLY_BY_VEC * lhs * rhs
    return jnp.sum(terms, axis=(-3, -2))
151
152
def quat_multiply(quat1, quat2):
    """Hamilton product quat1 * quat2 of two quaternions.

    Args:
      quat1: Array whose last axis has size 4, ordered (r, i, j, k).
      quat2: Array whose last axis has size 4; presumably broadcast-compatible
        with quat1.

    Returns:
      Array whose last axis has size 4 holding the product.
    """
    lhs = quat1[..., :, None, None]
    rhs = quat2[..., None, :, None]
    # QUAT_MULTIPLY[a, b, c] weights lhs[a] * rhs[b] into output component c.
    terms = QUAT_MULTIPLY * lhs * rhs
    return jnp.sum(terms, axis=(-3, -2))
160
161
def apply_rot_to_vec(rot, vec, unstack=False):
    """Multiply rotation matrix by a vector.

    Args:
      rot: 3x3 nested indexable (rot[row][col]) of arrays or scalars.
      vec: If `unstack` is False, a length-3 sequence (x, y, z). If `unstack`
        is True, an array whose last axis has size 3 holding x, y, z; any
        number of leading axes is accepted.
      unstack: Whether `vec` must first be split along its last axis.

    Returns:
      List [x', y', z'] with x' = rot[0][0]*x + rot[0][1]*y + rot[0][2]*z, etc.
    """
    if unstack:
        # Index the last axis (not axis 1) so vec may have any leading shape,
        # matching the "last axis holds xyz" convention used elsewhere here.
        x, y, z = [vec[..., i] for i in range(3)]
    else:
        x, y, z = vec
    return [rot[r][0] * x + rot[r][1] * y + rot[r][2] * z for r in range(3)]
171
172
def apply_inverse_rot_to_vec(rot, vec):
    """Multiply the transpose (inverse) of a rotation matrix by a vector.

    Args:
      rot: 3x3 nested indexable (rot[row][col]) of arrays or scalars.
      vec: Indexable whose first three entries are x, y, z.

    Returns:
      List of three components; component c is sum_r rot[r][c] * vec[r].
    """
    # A rotation's inverse is its transpose, so walk down the columns of rot.
    vx, vy, vz = vec[0], vec[1], vec[2]
    return [rot[0][c] * vx + rot[1][c] * vy + rot[2][c] * vz for c in range(3)]
179
180
class QuatAffine(object):
    """Affine transformation represented by quaternion and vector.

    The rotation is held twice: as `quaternion` (an array whose last axis has
    size 4, or None when only a rotation matrix was supplied) and as
    `rotation`, a 3x3 nested list with rotation[row][col] an array.
    `translation` is a list of three arrays (x, y, z).
    """

    def __init__(self, quaternion, translation, rotation=None, normalize=True,
                 unstack_inputs=False):
        """Initialize from quaternion and translation.

        Args:
          quaternion: Rotation represented by a quaternion, to be applied
            before translation. Must be a unit quaternion unless
            normalize==True. May be None when `rotation` is given.
          translation: Translation represented as a vector: a length-3
            sequence (x, y, z), or an array whose last axis has size 3 when
            unstack_inputs is True.
          rotation: Same rotation as the quaternion, as a 3x3 nested list, or
            as a (..., 3, 3) array when unstack_inputs is True. If None,
            rotation will be calculated from the quaternion.
          normalize: If True, l2 normalize the quaternion on input.
          unstack_inputs: If True, `translation` (and `rotation`, if given)
            are arrays with the xyz components on their last axis (axes) and
            are unstacked into the list forms stored on the object.

        Raises:
          ValueError: If both `quaternion` and `rotation` are None.
        """
        if quaternion is not None:
            assert quaternion.shape[-1] == 4

        if unstack_inputs:
            if rotation is not None:
                # Move the row axis, then the column axis, to the front so the
                # array can be indexed as rotation[row][col].
                rotation = [jnp.moveaxis(x, -1, 0)
                            for x in jnp.moveaxis(rotation, -2, 0)]
            translation = jnp.moveaxis(translation, -1, 0)

        if normalize and quaternion is not None:
            quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1,
                                                      keepdims=True)

        if rotation is None:
            if quaternion is None:
                raise ValueError(
                    'QuatAffine needs a quaternion or a rotation; got neither.')
            rotation = quat_to_rot(quaternion)

        self.quaternion = quaternion
        self.rotation = [list(row) for row in rotation]
        self.translation = list(translation)

        assert all(len(row) == 3 for row in self.rotation)
        assert len(self.translation) == 3

    def to_tensor(self):
        """Pack into one array: the quaternion (4) followed by x, y, z.

        Returns:
          Array whose last axis has size 7, in the layout read by from_tensor.
        """
        return jnp.concatenate(
            [self.quaternion] +
            [jnp.expand_dims(x, axis=-1) for x in self.translation],
            axis=-1)

    def apply_tensor_fn(self, tensor_fn):
        """Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient).

        tensor_fn is applied to the quaternion, every translation component and
        every rotation entry; the quaternion is not re-normalized.
        """
        return QuatAffine(
            tensor_fn(self.quaternion),
            [tensor_fn(x) for x in self.translation],
            rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
            normalize=False)

    def apply_rotation_tensor_fn(self, tensor_fn):
        """Return a new QuatAffine with tensor_fn applied to the rotation part.

        Both rotation representations (quaternion and matrix) are transformed;
        the translation components are passed through unchanged.
        """
        return QuatAffine(
            tensor_fn(self.quaternion),
            list(self.translation),
            rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
            normalize=False)

    def scale_translation(self, position_scale):
        """Return a new quat affine with a different scale for translation."""
        return QuatAffine(
            self.quaternion,
            [x * position_scale for x in self.translation],
            rotation=[list(row) for row in self.rotation],
            normalize=False)

    @classmethod
    def from_tensor(cls, tensor, normalize=False):
        """Build from an array laid out as in to_tensor (quaternion, x, y, z).

        The rotation matrix is recomputed from the quaternion part.
        """
        quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
        return cls(quaternion,
                   [tx[..., 0], ty[..., 0], tz[..., 0]],
                   normalize=normalize)

    def pre_compose(self, update):
        """Return a new QuatAffine which applies the transformation update first.

        Args:
          update: Length-6 vector. 3-vector of x, y, and z such that the
            quaternion update is (1, x, y, z) and zero for the 3-vector is the
            identity quaternion. 3-vector for translation concatenated.

        Returns:
          New QuatAffine object. The translation update is rotated by
          self.rotation before being added, and the new quaternion is
          normalized by the constructor.
        """
        vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5],
                                                      axis=-1)
        trans_update = [jnp.squeeze(x, axis=-1),
                        jnp.squeeze(y, axis=-1),
                        jnp.squeeze(z, axis=-1)]

        # q * (1, v) == q + q * (0, v).
        new_quaternion = (self.quaternion +
                          quat_multiply_by_vec(self.quaternion,
                                               vector_quaternion_update))

        trans_update = apply_rot_to_vec(self.rotation, trans_update)
        new_translation = [
            self.translation[0] + trans_update[0],
            self.translation[1] + trans_update[1],
            self.translation[2] + trans_update[2]]

        return QuatAffine(new_quaternion, new_translation)

    def _broadcast_frame(self, extra_dims):
        """Return (rotation, translation) with `extra_dims` trailing size-1 axes.

        The extra axes let the frame broadcast against points that carry
        additional trailing dimensions.
        """
        rotation = self.rotation
        translation = self.translation
        expand_fn = functools.partial(jnp.expand_dims, axis=-1)
        for _ in range(extra_dims):
            # jax.tree_util.tree_map is the stable API; the jax.tree_map alias
            # is deprecated and no longer available in newer JAX releases.
            rotation = jax.tree_util.tree_map(expand_fn, rotation)
            translation = jax.tree_util.tree_map(expand_fn, translation)
        return rotation, translation

    def apply_to_point(self, point, extra_dims=0):
        """Apply affine to a point.

        Args:
          point: List of 3 tensors to apply affine.
          extra_dims: Number of dimensions at the end of the transformed_point
            shape that are not present in the rotation and translation. The
            most common use is rotation N points at once with extra_dims=1 for
            use in a network.

        Returns:
          Transformed point after applying affine.
        """
        rotation, translation = self._broadcast_frame(extra_dims)
        rot_point = apply_rot_to_vec(rotation, point)
        return [
            rot_point[0] + translation[0],
            rot_point[1] + translation[1],
            rot_point[2] + translation[2]]

    def invert_point(self, transformed_point, extra_dims=0):
        """Apply inverse of transformation to a point.

        Args:
          transformed_point: List of 3 tensors to apply affine
          extra_dims: Number of dimensions at the end of the transformed_point
            shape that are not present in the rotation and translation. The
            most common use is rotation N points at once with extra_dims=1 for
            use in a network.

        Returns:
          Transformed point after applying the inverse affine: the translation
          is subtracted first, then the transposed rotation is applied.
        """
        rotation, translation = self._broadcast_frame(extra_dims)
        rot_point = [
            transformed_point[0] - translation[0],
            transformed_point[1] - translation[1],
            transformed_point[2] - translation[2]]
        return apply_inverse_rot_to_vec(rotation, rot_point)

    def __repr__(self):
        return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation)
342
343
344 def _multiply(a, b):
345 return jnp.stack([
346 jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0],
347 a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1],
348 a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]),
349
350 jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0],
351 a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1],
352 a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]),
353
354 jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0],
355 a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1],
356 a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])])
357
358
def make_canonical_transform(
        n_xyz: jnp.ndarray,
        ca_xyz: jnp.ndarray,
        c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns translation and rotation matrices to canonicalize residue atoms.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.
    NOTE(review): the final x-axis rotation below always gives N a y-coordinate
    of sqrt(n_y**2 + n_z**2) >= 0, so confirm what "non-standard" input refers
    to before relying on the sign claim above.

    Args:
      n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
      ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
      c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

    Returns:
      A tuple (translation, rotation) where:
        translation is an array of shape [batch, 3] defining the translation.
        rotation is an array of shape [batch, 3, 3] defining the rotation.
      The translation is applied first and the rotation second, i.e. a point
      x maps to rotation @ (x + translation). Afterwards, for every residue:
        * CA is at the origin,
        * C lies on the positive x-axis,
        * N has been rotated about the x-axis into the x-y plane.
    """
    assert len(n_xyz.shape) == 2, n_xyz.shape
    assert n_xyz.shape[-1] == 3, n_xyz.shape
    assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
        n_xyz.shape, ca_xyz.shape, c_xyz.shape)

    # Place CA at the origin.
    translation = -ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Place C on the x-axis.
    c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
    # Rotate by angle c1 in the x-y plane (around the z-axis). The 1e-20 keeps
    # the division finite when C lies on the z-axis.
    c_xy_norm = jnp.sqrt(1e-20 + c_x**2 + c_y**2)
    sin_c1 = -c_y / c_xy_norm
    cos_c1 = c_x / c_xy_norm
    zeros = jnp.zeros_like(sin_c1)
    ones = jnp.ones_like(sin_c1)
    # pylint: disable=bad-whitespace
    c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
                               jnp.array([sin_c1,  cos_c1, zeros]),
                               jnp.array([zeros,   zeros,  ones])])

    # Rotate by angle c2 in the x-z plane (around the y-axis).
    c_xyz_norm = jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
    sin_c2 = c_z / c_xyz_norm
    cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / c_xyz_norm
    c2_rot_matrix = jnp.stack([jnp.array([cos_c2,  zeros, sin_c2]),
                               jnp.array([zeros,   ones,  zeros]),
                               jnp.array([-sin_c2, zeros, cos_c2])])

    # Rotation matrices here are laid out (3, 3, batch).
    c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
    # apply_rot_to_vec returns [x, y, z] each of shape (batch,); stack and
    # transpose back to (batch, 3).
    n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T

    # Place N in the x-y plane.
    _, n_y, n_z = [n_xyz[:, i] for i in range(3)]
    # Rotate by angle alpha in the y-z plane (around the x-axis).
    n_yz_norm = jnp.sqrt(1e-20 + n_y**2 + n_z**2)
    sin_n = -n_z / n_yz_norm
    cos_n = n_y / n_yz_norm
    n_rot_matrix = jnp.stack([jnp.array([ones,  zeros,  zeros]),
                              jnp.array([zeros, cos_n, -sin_n]),
                              jnp.array([zeros, sin_n,  cos_n])])
    # pylint: enable=bad-whitespace

    # Compose (N rotation after C rotation) and move the batch axis to the front.
    return (translation,
            jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
429
430
def make_transform_from_reference(
        n_xyz: jnp.ndarray,
        ca_xyz: jnp.ndarray,
        c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the rotation and translation that convert from the reference frame.

    Note that this method does not take care of symmetries. If you provide the
    atom positions in the non-standard way, the N atom will end up not at
    [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
    need to take care of such cases in your code.

    Args:
      n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
      ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
      c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.

    Returns:
      A tuple (rotation, translation) where:
        rotation is an array of shape [batch, 3, 3] defining the rotation.
        translation is an array of shape [batch, 3] defining the translation.
      After applying the rotation and then the translation to the reference
      backbone, the coordinates will approximately equal the input coordinates.

    The order of translation and rotation differs from make_canonical_transform
    because the rotation from this function should be applied before the
    translation, unlike make_canonical_transform.
    """
    translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
    # Invert y = R (x + t): x = R^T y - t. Use jnp (as make_canonical_transform
    # does) so the result stays a JAX array rather than mixing in NumPy calls.
    return jnp.transpose(rotation, (0, 2, 1)), -translation