view docker/alphafold/alphafold/model/tf/data_transforms.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
line wrap: on
line source

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data for AlphaFold."""

from alphafold.common import residue_constants
from alphafold.model.tf import shape_helpers
from alphafold.model.tf import shape_placeholders
from alphafold.model.tf import utils
import numpy as np
import tensorflow.compat.v1 as tf

# Pylint gets confused by the curry1 decorator because it changes the number
#   of arguments to the function.
# pylint:disable=no-value-for-parameter


NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES


def cast_64bit_ints(protein):

  for k, v in protein.items():
    if v.dtype == tf.int64:
      protein[k] = tf.cast(v, tf.int32)
  return protein


_MSA_FEATURE_NAMES = [
    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
    'true_msa'
]


def make_seq_mask(protein):
  protein['seq_mask'] = tf.ones(
      shape_helpers.shape_list(protein['aatype']), dtype=tf.float32)
  return protein


def make_template_mask(protein):
  protein['template_mask'] = tf.ones(
      shape_helpers.shape_list(protein['template_domain_names']),
      dtype=tf.float32)
  return protein


def curry1(f):
  """Supply all arguments but the first."""

  def fc(*args, **kwargs):
    return lambda x: f(x, *args, **kwargs)

  return fc


@curry1
def add_distillation_flag(protein, distillation):
  protein['is_distillation'] = tf.constant(float(distillation),
                                           shape=[],
                                           dtype=tf.float32)
  return protein


def make_all_atom_aatype(protein):
  protein['all_atom_aatype'] = protein['aatype']
  return protein


def fix_templates_aatype(protein):
  """Fixes aatype encoding of templates."""
  # Map one-hot to indices.
  protein['template_aatype'] = tf.argmax(
      protein['template_aatype'], output_type=tf.int32, axis=-1)
  # Map hhsearch-aatype to our aatype.
  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
  new_order = tf.constant(new_order_list, dtype=tf.int32)
  protein['template_aatype'] = tf.gather(params=new_order,
                                         indices=protein['template_aatype'])
  return protein


def correct_msa_restypes(protein):
  """Correct MSA restype to have the same order as residue_constants."""
  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
  new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype)
  protein['msa'] = tf.gather(new_order, protein['msa'], axis=0)

  perm_matrix = np.zeros((22, 22), dtype=np.float32)
  perm_matrix[range(len(new_order_list)), new_order_list] = 1.

  for k in protein:
    if 'profile' in k:  # Include both hhblits and psiblast profiles
      num_dim = protein[k].shape.as_list()[-1]
      assert num_dim in [20, 21, 22], (
          'num_dim for %s out of expected range: %s' % (k, num_dim))
      protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1)
  return protein


def squeeze_features(protein):
  """Remove singleton and repeated dimensions in protein features."""
  protein['aatype'] = tf.argmax(
      protein['aatype'], axis=-1, output_type=tf.int32)
  for k in [
      'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
      'superfamily', 'deletion_matrix', 'resolution',
      'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
    if k in protein:
      final_dim = shape_helpers.shape_list(protein[k])[-1]
      if isinstance(final_dim, int) and final_dim == 1:
        protein[k] = tf.squeeze(protein[k], axis=-1)

  for k in ['seq_length', 'num_alignments']:
    if k in protein:
      protein[k] = protein[k][0]  # Remove fake sequence dimension
  return protein


def make_random_crop_to_size_seed(protein):
  """Random seed for cropping residues and templates."""
  protein['random_crop_to_size_seed'] = utils.make_random_seed()
  return protein


@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
  """Replace a proportion of the MSA with 'X'."""
  msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) <
              replace_proportion)
  x_idx = 20
  gap_idx = 21
  msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx)
  protein['msa'] = tf.where(msa_mask,
                            tf.ones_like(protein['msa']) * x_idx,
                            protein['msa'])
  aatype_mask = (
      tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) <
      replace_proportion)

  protein['aatype'] = tf.where(aatype_mask,
                               tf.ones_like(protein['aatype']) * x_idx,
                               protein['aatype'])
  return protein


@curry1
def sample_msa(protein, max_seq, keep_extra):
  """Sample MSA randomly, remaining sequences are stored as `extra_*`.

  Args:
    protein: batch to sample msa from.
    max_seq: number of sequences to sample.
    keep_extra: When True sequences not sampled are put into fields starting
      with 'extra_*'.

  Returns:
    Protein with sampled msa.
  """
  num_seq = tf.shape(protein['msa'])[0]
  shuffled = tf.random_shuffle(tf.range(1, num_seq))
  index_order = tf.concat([[0], shuffled], axis=0)
  num_sel = tf.minimum(max_seq, num_seq)

  sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel])

  for k in _MSA_FEATURE_NAMES:
    if k in protein:
      if keep_extra:
        protein['extra_' + k] = tf.gather(protein[k], not_sel_seq)
      protein[k] = tf.gather(protein[k], sel_seq)

  return protein


@curry1
def crop_extra_msa(protein, max_extra_msa):
  """MSA features are cropped so only `max_extra_msa` sequences are kept."""
  num_seq = tf.shape(protein['extra_msa'])[0]
  num_sel = tf.minimum(max_extra_msa, num_seq)
  select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel]
  for k in _MSA_FEATURE_NAMES:
    if 'extra_' + k in protein:
      protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices)

  return protein


def delete_extra_msa(protein):
  for k in _MSA_FEATURE_NAMES:
    if 'extra_' + k in protein:
      del protein['extra_' + k]
  return protein


@curry1
def block_delete_msa(protein, config):
  """Sample MSA by deleting contiguous blocks.

  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

  Arguments:
    protein: batch dict containing the msa
    config: ConfigDict with parameters

  Returns:
    updated protein
  """
  num_seq = shape_helpers.shape_list(protein['msa'])[0]
  block_num_seq = tf.cast(
      tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
      tf.int32)

  if config.randomize_num_blocks:
    nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
  else:
    nb = config.num_blocks

  del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
  del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
  del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
  del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

  # Make sure we keep the original sequence
  sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None],
                                   del_indices[None])
  keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
  keep_indices = tf.concat([[0], keep_indices], axis=0)

  for k in _MSA_FEATURE_NAMES:
    if k in protein:
      protein[k] = tf.gather(protein[k], keep_indices)

  return protein


@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""

  # Determine how much weight we assign to each agreement.  In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask
  weights = tf.concat([
      tf.ones(21),
      gap_agreement_weight * tf.ones(1),
      np.zeros(1)], 0)

  # Make agreement score as weighted Hamming distance
  sample_one_hot = (protein['msa_mask'][:, :, None] *
                    tf.one_hot(protein['msa'], 23))
  extra_one_hot = (protein['extra_msa_mask'][:, :, None] *
                   tf.one_hot(protein['extra_msa'], 23))

  num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot)
  extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot)

  # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
  # in an optimized fashion to avoid possible memory or computation blowup.
  agreement = tf.matmul(
      tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
      tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]),
      transpose_b=True)

  # Assign each sequence in the extra sequences to the closest MSA sample
  protein['extra_cluster_assignment'] = tf.argmax(
      agreement, axis=1, output_type=tf.int32)

  return protein


@curry1
def summarize_clusters(protein):
  """Produce profile and deletion_matrix_mean within each cluster."""
  num_seq = shape_helpers.shape_list(protein['msa'])[0]
  def csum(x):
    return tf.math.unsorted_segment_sum(
        x, protein['extra_cluster_assignment'], num_seq)

  mask = protein['extra_msa_mask']
  mask_counts = 1e-6 + protein['msa_mask'] + csum(mask)  # Include center

  msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))
  msa_sum += tf.one_hot(protein['msa'], 23)  # Original sequence
  protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]

  del msa_sum

  del_sum = csum(mask * protein['extra_deletion_matrix'])
  del_sum += protein['deletion_matrix']  # Original sequence
  protein['cluster_deletion_mean'] = del_sum / mask_counts
  del del_sum

  return protein


def make_msa_mask(protein):
  """Mask features are all ones, but will later be zero-padded."""
  protein['msa_mask'] = tf.ones(
      shape_helpers.shape_list(protein['msa']), dtype=tf.float32)
  protein['msa_row_mask'] = tf.ones(
      shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32)
  return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  """Create pseudo beta features."""
  is_gly = tf.equal(aatype, residue_constants.restype_order['G'])
  ca_idx = residue_constants.atom_order['CA']
  cb_idx = residue_constants.atom_order['CB']
  pseudo_beta = tf.where(
      tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
      all_atom_positions[..., ca_idx, :],
      all_atom_positions[..., cb_idx, :])

  if all_atom_masks is not None:
    pseudo_beta_mask = tf.where(
        is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
    pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32)
    return pseudo_beta, pseudo_beta_mask
  else:
    return pseudo_beta


@curry1
def make_pseudo_beta(protein, prefix=''):
  """Create pseudo-beta (alpha for glycine) position and mask."""
  assert prefix in ['', 'template_']
  protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
      pseudo_beta_fn(
          protein['template_aatype' if prefix else 'all_atom_aatype'],
          protein[prefix + 'all_atom_positions'],
          protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
  return protein


@curry1
def add_constant_field(protein, key, value):
  protein[key] = tf.convert_to_tensor(value)
  return protein


def shaped_categorical(probs, epsilon=1e-10):
  ds = shape_helpers.shape_list(probs)
  num_classes = ds[-1]
  counts = tf.random.categorical(
      tf.reshape(tf.log(probs + epsilon), [-1, num_classes]),
      1,
      dtype=tf.int32)
  return tf.reshape(counts, ds[:-1])


def make_hhblits_profile(protein):
  """Compute the HHblits MSA profile if not already present."""
  if 'hhblits_profile' in protein:
    return protein

  # Compute the profile for every residue (over all MSA sequences).
  protein['hhblits_profile'] = tf.reduce_mean(
      tf.one_hot(protein['msa'], 22), axis=0)
  return protein


@curry1
def make_masked_msa(protein, config, replace_fraction):
  """Create data for BERT on raw MSA."""
  # Add a random amino acid uniformly
  random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32)

  categorical_probs = (
      config.uniform_prob * random_aa +
      config.profile_prob * protein['hhblits_profile'] +
      config.same_prob * tf.one_hot(protein['msa'], 22))

  # Put all remaining probability on [MASK] which is a new column
  pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
  pad_shapes[-1][1] = 1
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  categorical_probs = tf.pad(
      categorical_probs, pad_shapes, constant_values=mask_prob)

  sh = shape_helpers.shape_list(protein['msa'])
  mask_position = tf.random.uniform(sh) < replace_fraction

  bert_msa = shaped_categorical(categorical_probs)
  bert_msa = tf.where(mask_position, bert_msa, protein['msa'])

  # Mix real and masked MSA
  protein['bert_mask'] = tf.cast(mask_position, tf.float32)
  protein['true_msa'] = protein['msa']
  protein['msa'] = bert_msa

  return protein


@curry1
def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res, num_templates=0):
  """Guess at the MSA and sequence dimensions to make fixed size."""

  pad_size_map = {
      NUM_RES: num_res,
      NUM_MSA_SEQ: msa_cluster_size,
      NUM_EXTRA_SEQ: extra_msa_size,
      NUM_TEMPLATES: num_templates,
  }

  for k, v in protein.items():
    # Don't transfer this to the accelerator.
    if k == 'extra_cluster_assignment':
      continue
    shape = v.shape.as_list()
    schema = shape_schema[k]
    assert len(shape) == len(schema), (
        f'Rank mismatch between shape and shape schema for {k}: '
        f'{shape} vs {schema}')
    pad_size = [
        pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
    ]
    padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)]
    if padding:
      protein[k] = tf.pad(
          v, padding, name=f'pad_to_fixed_{k}')
      protein[k].set_shape(pad_size)

  return protein


@curry1
def make_msa_feat(protein):
  """Create and concatenate MSA features."""
  # Whether there is a domain break. Always zero for chains, but keeping
  # for compatibility with domain datasets.
  has_break = tf.clip_by_value(
      tf.cast(protein['between_segment_residues'], tf.float32),
      0, 1)
  aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1)

  target_feat = [
      tf.expand_dims(has_break, axis=-1),
      aatype_1hot,  # Everyone gets the original sequence.
  ]

  msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1)
  has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.)
  deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi)

  msa_feat = [
      msa_1hot,
      tf.expand_dims(has_deletion, axis=-1),
      tf.expand_dims(deletion_value, axis=-1),
  ]

  if 'cluster_profile' in protein:
    deletion_mean_value = (
        tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi))
    msa_feat.extend([
        protein['cluster_profile'],
        tf.expand_dims(deletion_mean_value, axis=-1),
    ])

  if 'extra_deletion_matrix' in protein:
    protein['extra_has_deletion'] = tf.clip_by_value(
        protein['extra_deletion_matrix'], 0., 1.)
    protein['extra_deletion_value'] = tf.atan(
        protein['extra_deletion_matrix'] / 3.) * (2. / np.pi)

  protein['msa_feat'] = tf.concat(msa_feat, axis=-1)
  protein['target_feat'] = tf.concat(target_feat, axis=-1)
  return protein


@curry1
def select_feat(protein, feature_list):
  return {k: v for k, v in protein.items() if k in feature_list}


@curry1
def crop_templates(protein, max_templates):
  for k, v in protein.items():
    if k.startswith('template_'):
      protein[k] = v[:max_templates]
  return protein


@curry1
def random_crop_to_size(protein, crop_size, max_templates, shape_schema,
                        subsample_templates=False):
  """Crop randomly to `crop_size`, or keep as is if shorter than that."""
  seq_length = protein['seq_length']
  if 'template_mask' in protein:
    num_templates = tf.cast(
        shape_helpers.shape_list(protein['template_mask'])[0], tf.int32)
  else:
    num_templates = tf.constant(0, dtype=tf.int32)
  num_res_crop_size = tf.math.minimum(seq_length, crop_size)

  # Ensures that the cropping of residues and templates happens in the same way
  # across ensembling iterations.
  # Do not use for randomness that should vary in ensembling.
  seed_maker = utils.SeedMaker(initial_seed=protein['random_crop_to_size_seed'])

  if subsample_templates:
    templates_crop_start = tf.random.stateless_uniform(
        shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32,
        seed=seed_maker())
  else:
    templates_crop_start = 0

  num_templates_crop_size = tf.math.minimum(
      num_templates - templates_crop_start, max_templates)

  num_res_crop_start = tf.random.stateless_uniform(
      shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
      dtype=tf.int32, seed=seed_maker())

  templates_select_indices = tf.argsort(tf.random.stateless_uniform(
      [num_templates], seed=seed_maker()))

  for k, v in protein.items():
    if k not in shape_schema or (
        'template' not in k and NUM_RES not in shape_schema[k]):
      continue

    # randomly permute the templates before cropping them.
    if k.startswith('template') and subsample_templates:
      v = tf.gather(v, templates_select_indices)

    crop_sizes = []
    crop_starts = []
    for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
                                            shape_helpers.shape_list(v))):
      is_num_res = (dim_size == NUM_RES)
      if i == 0 and k.startswith('template'):
        crop_size = num_templates_crop_size
        crop_start = templates_crop_start
      else:
        crop_start = num_res_crop_start if is_num_res else 0
        crop_size = (num_res_crop_size if is_num_res else
                     (-1 if dim is None else dim))
      crop_sizes.append(crop_size)
      crop_starts.append(crop_start)
    protein[k] = tf.slice(v, crop_starts, crop_sizes)

  protein['seq_length'] = num_res_crop_size
  return protein


def make_atom14_masks(protein):
  """Construct denser atom positions (14 dimensions instead of 37)."""
  restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
  restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
  restype_atom14_mask = []

  for rt in residue_constants.restypes:
    atom_names = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[rt]]

    restype_atom14_to_atom37.append([
        (residue_constants.atom_order[name] if name else 0)
        for name in atom_names
    ])

    atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
    restype_atom37_to_atom14.append([
        (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
        for name in residue_constants.atom_types
    ])

    restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])

  # Add dummy mapping for restype 'UNK'
  restype_atom14_to_atom37.append([0] * 14)
  restype_atom37_to_atom14.append([0] * 37)
  restype_atom14_mask.append([0.] * 14)

  restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
  restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
  restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

  # create the mapping for (residx, atom14) --> atom37, i.e. an array
  # with shape (num_res, 14) containing the atom37 indices for this protein
  residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37,
                                      protein['aatype'])
  residx_atom14_mask = tf.gather(restype_atom14_mask,
                                 protein['aatype'])

  protein['atom14_atom_exists'] = residx_atom14_mask
  protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37

  # create the gather indices for mapping back
  residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14,
                                      protein['aatype'])
  protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14

  # create the corresponding mask
  restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
  for restype, restype_letter in enumerate(residue_constants.restypes):
    restype_name = residue_constants.restype_1to3[restype_letter]
    atom_names = residue_constants.residue_atoms[restype_name]
    for atom_name in atom_names:
      atom_type = residue_constants.atom_order[atom_name]
      restype_atom37_mask[restype, atom_type] = 1

  residx_atom37_mask = tf.gather(restype_atom37_mask,
                                 protein['aatype'])
  protein['atom37_atom_exists'] = residx_atom37_mask

  return protein