diff docker/alphafold/alphafold/model/quat_affine.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/quat_affine.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,459 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Quaternion geometry modules.
+
+This introduces a representation of coordinate frames that is based around a
+‘QuatAffine’ object. This object describes an array of coordinate frames.
+It consists of vectors corresponding to the
+origin of the frames as well as orientations which are stored in two
+ways, as unit quaternions as well as a rotation matrices.
+The rotation matrices are derived from the unit quaternions and the two are kept
+in sync.
+For an explanation of the relation between unit quaternions and rotations see
+https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
+
+This representation is used in the model for the backbone frames.
+
+One important thing to note here, is that while we update both representations
+the jit compiler is going to ensure that only the parts that are
+actually used are executed.
+"""
+
+
+import functools
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+# pylint: disable=bad-whitespace
+QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
+
+QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]]  # rr
+QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]]  # ii
+QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]]  # jj
+QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]]  # kk
+
+QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]]  # ij
+QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]]  # ik
+QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]]  # jk
+
+QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]]  # ir
+QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]]  # jr
+QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]]  # kr
+
+QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
+QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
+                          [ 0,-1, 0, 0],
+                          [ 0, 0,-1, 0],
+                          [ 0, 0, 0,-1]]
+
+QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
+                          [ 1, 0, 0, 0],
+                          [ 0, 0, 0, 1],
+                          [ 0, 0,-1, 0]]
+
+QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
+                          [ 0, 0, 0,-1],
+                          [ 1, 0, 0, 0],
+                          [ 0, 1, 0, 0]]
+
+QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
+                          [ 0, 0, 1, 0],
+                          [ 0,-1, 0, 0],
+                          [ 1, 0, 0, 0]]
+
+QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :]
+# pylint: enable=bad-whitespace
+
+
+def rot_to_quat(rot, unstack_inputs=False):
+  """Convert rotation matrix to quaternion.
+
+  Note that this function calls self_adjoint_eig which is extremely expensive on
+  the GPU. If at all possible, this function should run on the CPU.
+
+  Args:
+     rot: rotation matrix (see below for format).
+     unstack_inputs:  If true, rotation matrix should be shape (..., 3, 3)
+       otherwise the rotation matrix should be a list of lists of tensors.
+
+  Returns:
+    Quaternion as (..., 4) tensor.
+  """
+  if unstack_inputs:
+    rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)]
+
+  [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
+
+  # pylint: disable=bad-whitespace
+  k = [[ xx + yy + zz,      zy - yz,      xz - zx,      yx - xy,],
+       [      zy - yz, xx - yy - zz,      xy + yx,      xz + zx,],
+       [      xz - zx,      xy + yx, yy - xx - zz,      yz + zy,],
+       [      yx - xy,      xz + zx,      yz + zy, zz - xx - yy,]]
+  # pylint: enable=bad-whitespace
+
+  k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k],
+                          axis=-2)
+
+  # Get eigenvalues in non-decreasing order and associated.
+  _, qs = jnp.linalg.eigh(k)
+  return qs[..., -1]
+
+
+def rot_list_to_tensor(rot_list):
+  """Convert list of lists to rotation tensor."""
+  return jnp.stack(
+      [jnp.stack(rot_list[0], axis=-1),
+       jnp.stack(rot_list[1], axis=-1),
+       jnp.stack(rot_list[2], axis=-1)],
+      axis=-2)
+
+
+def vec_list_to_tensor(vec_list):
+  """Convert list to vector tensor."""
+  return jnp.stack(vec_list, axis=-1)
+
+
+def quat_to_rot(normalized_quat):
+  """Convert a normalized quaternion to a rotation matrix."""
+  rot_tensor = jnp.sum(
+      np.reshape(QUAT_TO_ROT, (4, 4, 9)) *
+      normalized_quat[..., :, None, None] *
+      normalized_quat[..., None, :, None],
+      axis=(-3, -2))
+  rot = jnp.moveaxis(rot_tensor, -1, 0)  # Unstack.
+  return [[rot[0], rot[1], rot[2]],
+          [rot[3], rot[4], rot[5]],
+          [rot[6], rot[7], rot[8]]]
+
+
+def quat_multiply_by_vec(quat, vec):
+  """Multiply a quaternion by a pure-vector quaternion."""
+  return jnp.sum(
+      QUAT_MULTIPLY_BY_VEC *
+      quat[..., :, None, None] *
+      vec[..., None, :, None],
+      axis=(-3, -2))
+
+
+def quat_multiply(quat1, quat2):
+  """Multiply a quaternion by another quaternion."""
+  return jnp.sum(
+      QUAT_MULTIPLY *
+      quat1[..., :, None, None] *
+      quat2[..., None, :, None],
+      axis=(-3, -2))
+
+
+def apply_rot_to_vec(rot, vec, unstack=False):
+  """Multiply rotation matrix by a vector."""
+  if unstack:
+    x, y, z = [vec[:, i] for i in range(3)]
+  else:
+    x, y, z = vec
+  return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z,
+          rot[1][0] * x + rot[1][1] * y + rot[1][2] * z,
+          rot[2][0] * x + rot[2][1] * y + rot[2][2] * z]
+
+
+def apply_inverse_rot_to_vec(rot, vec):
+  """Multiply the inverse of a rotation matrix by a vector."""
+  # Inverse rotation is just transpose
+  return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2],
+          rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2],
+          rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]]
+
+
+class QuatAffine(object):
+  """Affine transformation represented by quaternion and vector."""
+
+  def __init__(self, quaternion, translation, rotation=None, normalize=True,
+               unstack_inputs=False):
+    """Initialize from quaternion and translation.
+
+    Args:
+      quaternion: Rotation represented by a quaternion, to be applied
+        before translation.  Must be a unit quaternion unless normalize==True.
+      translation: Translation represented as a vector.
+      rotation: Same rotation as the quaternion, represented as a (..., 3, 3)
+        tensor.  If None, rotation will be calculated from the quaternion.
+      normalize: If True, l2 normalize the quaternion on input.
+      unstack_inputs: If True, translation is a vector with last component 3
+    """
+
+    if quaternion is not None:
+      assert quaternion.shape[-1] == 4
+
+    if unstack_inputs:
+      if rotation is not None:
+        rotation = [jnp.moveaxis(x, -1, 0)   # Unstack.
+                    for x in jnp.moveaxis(rotation, -2, 0)]  # Unstack.
+      translation = jnp.moveaxis(translation, -1, 0)  # Unstack.
+
+    if normalize and quaternion is not None:
+      quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1,
+                                                keepdims=True)
+
+    if rotation is None:
+      rotation = quat_to_rot(quaternion)
+
+    self.quaternion = quaternion
+    self.rotation = [list(row) for row in rotation]
+    self.translation = list(translation)
+
+    assert all(len(row) == 3 for row in self.rotation)
+    assert len(self.translation) == 3
+
+  def to_tensor(self):
+    return jnp.concatenate(
+        [self.quaternion] +
+        [jnp.expand_dims(x, axis=-1) for x in self.translation],
+        axis=-1)
+
+  def apply_tensor_fn(self, tensor_fn):
+    """Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient)."""
+    return QuatAffine(
+        tensor_fn(self.quaternion),
+        [tensor_fn(x) for x in self.translation],
+        rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
+        normalize=False)
+
+  def apply_rotation_tensor_fn(self, tensor_fn):
+    """Return a new QuatAffine with tensor_fn applied to the rotation part."""
+    return QuatAffine(
+        tensor_fn(self.quaternion),
+        [x for x in self.translation],
+        rotation=[[tensor_fn(x) for x in row] for row in self.rotation],
+        normalize=False)
+
+  def scale_translation(self, position_scale):
+    """Return a new quat affine with a different scale for translation."""
+
+    return QuatAffine(
+        self.quaternion,
+        [x * position_scale for x in self.translation],
+        rotation=[[x for x in row] for row in self.rotation],
+        normalize=False)
+
+  @classmethod
+  def from_tensor(cls, tensor, normalize=False):
+    quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1)
+    return cls(quaternion,
+               [tx[..., 0], ty[..., 0], tz[..., 0]],
+               normalize=normalize)
+
+  def pre_compose(self, update):
+    """Return a new QuatAffine which applies the transformation update first.
+
+    Args:
+      update: Length-6 vector. 3-vector of x, y, and z such that the quaternion
+        update is (1, x, y, z) and zero for the 3-vector is the identity
+        quaternion. 3-vector for translation concatenated.
+
+    Returns:
+      New QuatAffine object.
+    """
+    vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1)
+    trans_update = [jnp.squeeze(x, axis=-1),
+                    jnp.squeeze(y, axis=-1),
+                    jnp.squeeze(z, axis=-1)]
+
+    new_quaternion = (self.quaternion +
+                      quat_multiply_by_vec(self.quaternion,
+                                           vector_quaternion_update))
+
+    trans_update = apply_rot_to_vec(self.rotation, trans_update)
+    new_translation = [
+        self.translation[0] + trans_update[0],
+        self.translation[1] + trans_update[1],
+        self.translation[2] + trans_update[2]]
+
+    return QuatAffine(new_quaternion, new_translation)
+
+  def apply_to_point(self, point, extra_dims=0):
+    """Apply affine to a point.
+
+    Args:
+      point: List of 3 tensors to apply affine.
+      extra_dims:  Number of dimensions at the end of the transformed_point
+        shape that are not present in the rotation and translation.  The most
+        common use is rotation N points at once with extra_dims=1 for use in a
+        network.
+
+    Returns:
+      Transformed point after applying affine.
+    """
+    rotation = self.rotation
+    translation = self.translation
+    for _ in range(extra_dims):
+      expand_fn = functools.partial(jnp.expand_dims, axis=-1)
+      rotation = jax.tree_map(expand_fn, rotation)
+      translation = jax.tree_map(expand_fn, translation)
+
+    rot_point = apply_rot_to_vec(rotation, point)
+    return [
+        rot_point[0] + translation[0],
+        rot_point[1] + translation[1],
+        rot_point[2] + translation[2]]
+
+  def invert_point(self, transformed_point, extra_dims=0):
+    """Apply inverse of transformation to a point.
+
+    Args:
+      transformed_point: List of 3 tensors to apply affine
+      extra_dims:  Number of dimensions at the end of the transformed_point
+        shape that are not present in the rotation and translation.  The most
+        common use is rotation N points at once with extra_dims=1 for use in a
+        network.
+
+    Returns:
+      Transformed point after applying affine.
+    """
+    rotation = self.rotation
+    translation = self.translation
+    for _ in range(extra_dims):
+      expand_fn = functools.partial(jnp.expand_dims, axis=-1)
+      rotation = jax.tree_map(expand_fn, rotation)
+      translation = jax.tree_map(expand_fn, translation)
+
+    rot_point = [
+        transformed_point[0] - translation[0],
+        transformed_point[1] - translation[1],
+        transformed_point[2] - translation[2]]
+
+    return apply_inverse_rot_to_vec(rotation, rot_point)
+
+  def __repr__(self):
+    return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation)
+
+
+def _multiply(a, b):
+  return jnp.stack([
+      jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0],
+                 a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1],
+                 a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]),
+
+      jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0],
+                 a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1],
+                 a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]),
+
+      jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0],
+                 a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1],
+                 a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])])
+
+
+def make_canonical_transform(
+    n_xyz: jnp.ndarray,
+    ca_xyz: jnp.ndarray,
+    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Returns translation and rotation matrices to canonicalize residue atoms.
+
+  Note that this method does not take care of symmetries. If you provide the
+  atom positions in the non-standard way, the N atom will end up not at
+  [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+  need to take care of such cases in your code.
+
+  Args:
+    n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
+    ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
+    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
+
+  Returns:
+    A tuple (translation, rotation) where:
+      translation is an array of shape [batch, 3] defining the translation.
+      rotation is an array of shape [batch, 3, 3] defining the rotation.
+    After applying the translation and rotation to all atoms in a residue:
+      * All atoms will be shifted so that CA is at the origin,
+      * All atoms will be rotated so that C is at the x-axis,
+      * All atoms will be shifted so that N is in the xy plane.
+  """
+  assert len(n_xyz.shape) == 2, n_xyz.shape
+  assert n_xyz.shape[-1] == 3, n_xyz.shape
+  assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
+      n_xyz.shape, ca_xyz.shape, c_xyz.shape)
+
+  # Place CA at the origin.
+  translation = -ca_xyz
+  n_xyz = n_xyz + translation
+  c_xyz = c_xyz + translation
+
+  # Place C on the x-axis.
+  c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
+  # Rotate by angle c1 in the x-y plane (around the z-axis).
+  sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
+  cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
+  zeros = jnp.zeros_like(sin_c1)
+  ones = jnp.ones_like(sin_c1)
+  # pylint: disable=bad-whitespace
+  c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
+                             jnp.array([sin_c1,  cos_c1, zeros]),
+                             jnp.array([zeros,    zeros,  ones])])
+
+  # Rotate by angle c2 in the x-z plane (around the y-axis).
+  sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
+  cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt(
+      1e-20 + c_x**2 + c_y**2 + c_z**2)
+  c2_rot_matrix = jnp.stack([jnp.array([cos_c2,  zeros, sin_c2]),
+                             jnp.array([zeros,    ones,  zeros]),
+                             jnp.array([-sin_c2, zeros, cos_c2])])
+
+  c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
+  n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T
+
+  # Place N in the x-y plane.
+  _, n_y, n_z = [n_xyz[:, i] for i in range(3)]
+  # Rotate by angle alpha in the y-z plane (around the x-axis).
+  sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
+  cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
+  n_rot_matrix = jnp.stack([jnp.array([ones,  zeros,  zeros]),
+                            jnp.array([zeros, cos_n, -sin_n]),
+                            jnp.array([zeros, sin_n,  cos_n])])
+  # pylint: enable=bad-whitespace
+
+  return (translation,
+          jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
+
+
+def make_transform_from_reference(
+    n_xyz: jnp.ndarray,
+    ca_xyz: jnp.ndarray,
+    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Returns rotation and translation matrices to convert from reference.
+
+  Note that this method does not take care of symmetries. If you provide the
+  atom positions in the non-standard way, the N atom will end up not at
+  [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+  need to take care of such cases in your code.
+
+  Args:
+    n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
+    ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
+    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
+
+  Returns:
+    A tuple (rotation, translation) where:
+      rotation is an array of shape [batch, 3, 3] defining the rotation.
+      translation is an array of shape [batch, 3] defining the translation.
+    After applying the translation and rotation to the reference backbone,
+    the coordinates will approximately equal to the input coordinates.
+
+    The order of translation and rotation differs from make_canonical_transform
+    because the rotation from this function should be applied before the
+    translation, unlike make_canonical_transform.
+  """
+  translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
+  return np.transpose(rotation, (0, 2, 1)), -translation