Mercurial > repos > galaxy-australia > alphafold2
view docker/alphafold/alphafold/model/modules_multimer.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
line wrap: on
line source
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core modules, which have been refactored in AlphaFold-Multimer.

The main difference is that MSA sampling pipeline is moved inside the JAX model
for easier implementation of recycling and ensembling.

Lower-level modules up to EvoformerIteration are reused from modules.py.
"""

import functools
from typing import Sequence

from alphafold.common import residue_constants
from alphafold.model import all_atom_multimer
from alphafold.model import common_modules
from alphafold.model import folding_multimer
from alphafold.model import geometry
from alphafold.model import layer_stack
from alphafold.model import modules
from alphafold.model import prng
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def reduce_fn(x, mode):
  """Reduces `x` according to `mode`.

  Args:
    x: Array-like value; it is converted with `jnp.asarray`.
    mode: 'none' or None for no reduction, 'sum' or 'mean'.

  Returns:
    `x` as a jnp array, or its sum / mean over all elements.

  Raises:
    ValueError: if `mode` is not one of the supported options.
  """
  if mode == 'none' or mode is None:
    return jnp.asarray(x)
  elif mode == 'sum':
    return jnp.asarray(x).sum()
  elif mode == 'mean':
    return jnp.mean(jnp.asarray(x))
  else:
    raise ValueError('Unsupported reduction option.')


def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
  """Generate Gumbel Noise of given Shape.

  This generates samples from Gumbel(0, 1).

  Args:
    key: Jax random number key.
    shape: Shape of noise to return.

  Returns:
    Gumbel noise of given shape.
  """
  # epsilon keeps both logarithms away from log(0).
  epsilon = 1e-6
  uniform = utils.padding_consistent_rng(jax.random.uniform)
  uniform_noise = uniform(
      key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
  gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
  return gumbel


def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
  """Samples from a probability distribution given by 'logits'.

  This uses Gumbel-max trick to implement the sampling in an efficient manner.

  Args:
    key: prng key.
    logits: Logarithm of probabilities to sample from, probabilities can be
      unnormalized.

  Returns:
    Sample from logprobs in one-hot form.
  """
  z = gumbel_noise(key, logits.shape)
  return jax.nn.one_hot(
      jnp.argmax(logits + z, axis=-1),
      logits.shape[-1],
      dtype=logits.dtype)


def gumbel_argsort_sample_idx(key: jnp.ndarray,
                              logits: jnp.ndarray) -> jnp.ndarray:
  """Samples without replacement from a distribution given by 'logits'.

  This uses Gumbel trick to implement the sampling in an efficient manner. For
  a distribution over k items this samples k times without replacement, so this
  is effectively sampling a random permutation with probabilities over the
  permutations derived from the logprobs.

  Args:
    key: prng key.
    logits: Logarithm of probabilities to sample from, probabilities can be
      unnormalized.

  Returns:
    Integer indices along the last axis of `logits`, ordered by decreasing
    Gumbel-perturbed logit (i.e. a sampled permutation, not a one-hot array).
  """
  z = gumbel_noise(key, logits.shape)
  # This construction is equivalent to jnp.argsort, but using a non stable sort,
  # since stable sort's aren't supported by jax2tf.
  axis = len(logits.shape) - 1
  iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
  _, perm = jax.lax.sort_key_val(
      logits + z, iota, dimension=-1, is_stable=False)
  # The sort is ascending along the last axis, so flip that same axis to get
  # descending order. For 1-D logits this is identical to `perm[::-1]`.
  return perm[..., ::-1]


def make_masked_msa(batch, key, config, epsilon=1e-6):
  """Create data for BERT on raw MSA.

  Args:
    batch: dict that is read for 'msa', 'msa_profile' and 'msa_mask' (and
      'bert_mask' if present). It is updated in place.
    key: key object providing `split(n)` and `get()`; presumably a
      `prng.SafeKey` - confirm against callers.
    config: object with `uniform_prob`, `profile_prob`, `same_prob` and
      `replace_fraction` attributes.
    epsilon: small constant added before taking the log of the probabilities.

  Returns:
    The same `batch` dict with 'bert_mask' and 'true_msa' set and 'msa'
    replaced by the corrupted MSA.
  """
  # Add a random amino acid uniformly.
  random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32)

  categorical_probs = (
      config.uniform_prob * random_aa +
      config.profile_prob * batch['msa_profile'] +
      config.same_prob * jax.nn.one_hot(batch['msa'], 22))

  # Put all remaining probability on [MASK] which is a new column.
  pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
  pad_shapes[-1][1] = 1
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  categorical_probs = jnp.pad(
      categorical_probs, pad_shapes, constant_values=mask_prob)
  sh = batch['msa'].shape
  key, mask_subkey, gumbel_subkey = key.split(3)
  uniform = utils.padding_consistent_rng(jax.random.uniform)
  mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction
  mask_position *= batch['msa_mask']

  logits = jnp.log(categorical_probs + epsilon)
  bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits)
  bert_msa = jnp.where(mask_position,
                       jnp.argmax(bert_msa, axis=-1), batch['msa'])
  bert_msa *= batch['msa_mask']

  # Mix real and masked MSA.
  if 'bert_mask' in batch:
    batch['bert_mask'] *= mask_position.astype(jnp.float32)
  else:
    batch['bert_mask'] = mask_position.astype(jnp.float32)
  batch['true_msa'] = batch['msa']
  batch['msa'] = bert_msa

  return batch


def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.

  Args:
    batch: dict read for 'msa', 'msa_mask', 'deletion_matrix', 'extra_msa',
      'extra_msa_mask' and 'extra_deletion_matrix'.
    gap_agreement_weight: weight given to agreement on the gap class.

  Returns:
    Tuple (cluster_profile, cluster_deletion_mean).
  """
  # Determine how much weight we assign to each agreement.  In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask.
  weights = jnp.array(
      [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32)

  msa_mask = batch['msa_mask']
  msa_one_hot = jax.nn.one_hot(batch['msa'], 23)

  extra_mask = batch['extra_msa_mask']
  extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23)

  msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
  extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot

  agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked,
                         weights * msa_one_hot_masked)

  # The large multiplier makes the softmax close to a hard assignment.
  cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0)
  cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask)

  cluster_count = jnp.sum(cluster_assignment, axis=-1)
  cluster_count += 1.  # We always include the sequence itself.

  msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
  msa_sum += msa_one_hot_masked

  cluster_profile = msa_sum / cluster_count[:, None, None]

  extra_deletion_matrix = batch['extra_deletion_matrix']
  deletion_matrix = batch['deletion_matrix']

  del_sum = jnp.einsum('nm, mc->nc', cluster_assignment,
                       extra_mask * extra_deletion_matrix)
  del_sum += deletion_matrix  # Original sequence.
  cluster_deletion_mean = del_sum / cluster_count[:, None]

  return cluster_profile, cluster_deletion_mean


def create_msa_feat(batch):
  """Create and concatenate MSA features.

  Args:
    batch: dict read for 'msa', 'deletion_matrix', 'cluster_profile' and
      'cluster_deletion_mean'.

  Returns:
    The features concatenated along the last axis.
  """
  msa_1hot = jax.nn.one_hot(batch['msa'], 23)
  deletion_matrix = batch['deletion_matrix']
  has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
  deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]

  deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) *
                         (2. / jnp.pi))[..., None]

  msa_feat = [
      msa_1hot,
      has_deletion,
      deletion_value,
      batch['cluster_profile'],
      deletion_mean_value
  ]

  return jnp.concatenate(msa_feat, axis=-1)


def create_extra_msa_feature(batch, num_extra_msa):
  """Expand extra_msa into 1hot and concat with other extra msa features.

  We do this as late as possible as the one_hot extra msa can be very large.

  Args:
    batch: a dictionary with the following keys:
     * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
       centre. Note - This isn't one-hotted.
     * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
       position.
     * 'extra_msa_mask': mask for 'extra_msa'; it is cropped to the same
       number of rows.
    num_extra_msa: Number of extra msa to use.

  Returns:
    Tuple of (concatenated tensor of extra MSA features, cropped
    'extra_msa_mask').
  """
  # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
  extra_msa = batch['extra_msa'][:num_extra_msa]
  deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa]
  msa_1hot = jax.nn.one_hot(extra_msa, 23)
  has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
  deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
  extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa]
  return jnp.concatenate([msa_1hot, has_deletion, deletion_value],
                         axis=-1), extra_msa_mask


def sample_msa(key, batch, max_seq):
  """Sample MSA randomly, remaining sequences are stored as `extra_*`.

  Args:
    key: safe key for random number generation.
    batch: batch to sample msa from.
    max_seq: number of sequences to sample.

  Returns:
    Protein with sampled msa.
  """
  # Sample uniformly among sequences with at least one non-masked position.
  logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6
  # The cluster_bias_mask can be used to preserve the first row (target
  # sequence) for each chain, for example.
  if 'cluster_bias_mask' not in batch:
    cluster_bias_mask = jnp.pad(
        jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.)
  else:
    cluster_bias_mask = batch['cluster_bias_mask']

  logits += cluster_bias_mask * 1e6
  index_order = gumbel_argsort_sample_idx(key.get(), logits)
  sel_idx = index_order[:max_seq]
  extra_idx = index_order[max_seq:]

  for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
    if k in batch:
      batch['extra_' + k] = batch[k][extra_idx]
      batch[k] = batch[k][sel_idx]

  return batch


def make_msa_profile(batch):
  """Compute the MSA profile."""

  # Compute the profile for every residue (over all MSA sequences).
  return utils.mask_mean(
      batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0)


class AlphaFoldIteration(hk.Module):
  """A single recycling iteration of AlphaFold architecture.

  Computes ensembled (averaged) representations from the provided features.
  These representations are then passed to the various heads
  that have been requested by the configuration file.
  """

  def __init__(self, config, global_config, name='alphafold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               batch,
               is_training,
               return_representations=False,
               safe_key=None):
    """Runs the embedding/Evoformer ensemble and then the configured heads.

    Args:
      batch: dict of input features.
      is_training: selects `num_ensemble_train` vs `num_ensemble_eval` and is
        forwarded to the sub-modules.
      return_representations: not read by this method.
      safe_key: key carried through the ensemble scan.
        NOTE(review): `ensemble_body` calls `safe_key.split()` without a None
        check, so None does not appear to be supported here - confirm that
        callers always pass a key.

    Returns:
      Dict with 'representations' plus one entry per executed head.
    """

    if is_training:
      num_ensemble = np.asarray(self.config.num_ensemble_train)
    else:
      num_ensemble = np.asarray(self.config.num_ensemble_eval)

    # Compute representations for each MSA sample and average.
    embedding_module = EmbeddingsAndEvoformer(
        self.config.embeddings_and_evoformer, self.global_config)
    repr_shape = hk.eval_shape(
        lambda: embedding_module(batch, is_training))
    representations = {
        k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items()
    }

    def ensemble_body(x, unused_y):
      """Add into representations ensemble."""
      del unused_y
      representations, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      representations_update = embedding_module(
          batch, is_training, safe_key=safe_subkey)

      for k in representations:
        # These keys are overwritten with the latest sample instead of being
        # averaged over the ensemble.
        if k not in {'msa', 'true_msa', 'bert_mask'}:
          representations[k] += representations_update[k] * (
              1. / num_ensemble).astype(representations[k].dtype)
        else:
          representations[k] = representations_update[k]

      return (representations, safe_key), None

    (representations, _), _ = hk.scan(
        ensemble_body, (representations, safe_key),
        None, length=num_ensemble)

    self.representations = representations
    self.batch = batch
    self.heads = {}
    for head_name, head_config in sorted(self.config.heads.items()):
      if not head_config.weight:
        continue  # Do not instantiate zero-weight heads.

      head_factory = {
          'masked_msa':
              modules.MaskedMsaHead,
          'distogram':
              modules.DistogramHead,
          'structure_module':
              folding_multimer.StructureModule,
          'predicted_aligned_error':
              modules.PredictedAlignedErrorHead,
          'predicted_lddt':
              modules.PredictedLDDTHead,
          'experimentally_resolved':
              modules.ExperimentallyResolvedHead,
      }[head_name]
      self.heads[head_name] = (head_config,
                               head_factory(head_config, self.global_config))

    structure_module_output = None
    if 'entity_id' in batch and 'all_atom_positions' in batch:
      _, fold_module = self.heads['structure_module']
      structure_module_output = fold_module(representations, batch, is_training)

    ret = {}
    ret['representations'] = representations

    for name, (head_config, module) in self.heads.items():
      if name == 'structure_module' and structure_module_output is not None:
        ret[name] = structure_module_output
        representations['structure_module'] = structure_module_output.pop('act')
      # Skip confidence heads until StructureModule is executed.
      elif name in {'predicted_lddt', 'predicted_aligned_error',
                    'experimentally_resolved'}:
        continue
      else:
        ret[name] = module(representations, batch, is_training)

    # Add confidence heads after StructureModule is executed.
    if self.config.heads.get('predicted_lddt.weight', 0.0):
      name = 'predicted_lddt'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)

    if self.config.heads.experimentally_resolved.weight:
      name = 'experimentally_resolved'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)

    if self.config.heads.get('predicted_aligned_error.weight', 0.0):
      name = 'predicted_aligned_error'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)
      # Will be used for ipTM computation.
      ret[name]['asym_id'] = batch['asym_id']

    return ret


class AlphaFold(hk.Module):
  """AlphaFold-Multimer model with recycling.
  """

  def __init__(self, config, name='alphafold'):
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config

  def __call__(
      self,
      batch,
      is_training,
      return_representations=False,
      safe_key=None):
    """Runs the recycling iterations followed by one final iteration.

    Args:
      batch: dict of input features; must contain 'aatype'.
      is_training: forwarded to `AlphaFoldIteration`.
      return_representations: if False, 'representations' is removed from the
        returned dict.
      safe_key: a `prng.SafeKey`, a raw jnp key array, or None (in which case
        a key is drawn from `hk.next_rng_key()`).

    Returns:
      The output dict of the final `AlphaFoldIteration` call.
    """

    c = self.config
    impl = AlphaFoldIteration(c, self.global_config)

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())
    elif isinstance(safe_key, jnp.ndarray):
      safe_key = prng.SafeKey(safe_key)

    assert isinstance(batch, dict)
    num_res = batch['aatype'].shape[0]

    def get_prev(ret):
      new_prev = {
          'prev_pos':
              ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      # `jax.tree_util.tree_map` is the long-standing spelling; the top-level
      # `jax.tree_map` alias is deprecated in recent JAX releases.
      return jax.tree_util.tree_map(jax.lax.stop_gradient, new_prev)

    def apply_network(prev, safe_key):
      recycled_batch = {**batch, **prev}
      return impl(
          batch=recycled_batch,
          is_training=is_training,
          safe_key=safe_key)

    if self.config.num_recycle:
      emb_config = self.config.embeddings_and_evoformer
      prev = {
          'prev_pos':
              jnp.zeros([num_res, residue_constants.atom_type_num, 3]),
          'prev_msa_first_row':
              jnp.zeros([num_res, emb_config.msa_channel]),
          'prev_pair':
              jnp.zeros([num_res, num_res, emb_config.pair_channel]),
      }

      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # Value for each ensemble batch is the same, so arbitrarily taking 0-th.
        num_iter = batch['num_iter_recycling'][0]

        # Add insurance that even when ensembling, we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, c.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = c.num_recycle

      def recycle_body(i, x):
        del i
        prev, safe_key = x
        # Either draw fresh randomness for this iteration or reuse the same
        # key, depending on the config flag.
        if c.resample_msa_in_recycling:
          safe_key1, safe_key2 = safe_key.split()
        else:
          safe_key1, safe_key2 = safe_key.duplicate()
        ret = apply_network(prev=prev, safe_key=safe_key2)
        return get_prev(ret), safe_key1

      prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
    else:
      prev = {}

    # Run extra iteration.
    ret = apply_network(prev=prev, safe_key=safe_key)

    if not return_representations:
      del ret['representations']
    return ret


class EmbeddingsAndEvoformer(hk.Module):
  """Embeds the input data and runs Evoformer.

  Produces the MSA, single and pair representations.
  """

  def __init__(self, config, global_config, name='evoformer'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def _relative_encoding(self, batch):
    """Add relative position encodings.

    For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted.

    When not using 'use_chain_relative' the residue indices are used as is, e.g.
    for heteromers relative positions will be computed using the positions in
    the corresponding chains.

    When using 'use_chain_relative' we add an extra bin that denotes
    'different chain'. Furthermore we also provide the relative chain index
    (i.e. sym_id) clipped and one-hotted to the network. And an extra feature
    which denotes whether they belong to the same chain type, i.e. it's 0 if
    they are in different heteromer chains and 1 otherwise.

    Args:
      batch: batch.
    Returns:
      Feature embedding using the features as described before.
    """
    c = self.config
    rel_feats = []
    pos = batch['residue_index']
    asym_id = batch['asym_id']
    asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
    offset = pos[:, None] - pos[None, :]

    # Bounds are passed positionally (as in the other jnp.clip calls in this
    # file); the `a_min`/`a_max` keyword names are deprecated in newer JAX.
    clipped_offset = jnp.clip(
        offset + c.max_relative_idx, 0, 2 * c.max_relative_idx)

    if c.use_chain_relative:

      final_offset = jnp.where(asym_id_same, clipped_offset,
                               (2 * c.max_relative_idx + 1) *
                               jnp.ones_like(clipped_offset))

      rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2)

      rel_feats.append(rel_pos)

      entity_id = batch['entity_id']
      entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :])
      rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None])

      sym_id = batch['sym_id']
      rel_sym_id = sym_id[:, None] - sym_id[None, :]

      max_rel_chain = c.max_relative_chain

      clipped_rel_chain = jnp.clip(
          rel_sym_id + max_rel_chain, 0, 2 * max_rel_chain)

      final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
                                  (2 * max_rel_chain + 1) *
                                  jnp.ones_like(clipped_rel_chain))
      rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2)

      rel_feats.append(rel_chain)

    else:
      rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1)
      rel_feats.append(rel_pos)

    rel_feat = jnp.concatenate(rel_feats, axis=-1)

    return common_modules.Linear(
        c.pair_channel,
        name='position_activations')(
            rel_feat)

  def __call__(self, batch, is_training, safe_key=None):
    """Embeds the features and runs the extra-MSA and main Evoformer stacks.

    Args:
      batch: dict of input features; a shallow copy is made before it is
        modified.
      is_training: forwarded to the sub-modules.
      safe_key: `prng.SafeKey`; drawn from `hk.next_rng_key()` if None.

    Returns:
      Dict with keys 'single', 'pair', 'msa' and 'msa_first_row'.
    """

    c = self.config
    gc = self.global_config

    batch = dict(batch)

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = {}

    batch['msa_profile'] = make_msa_profile(batch)

    target_feat = jax.nn.one_hot(batch['aatype'], 21)

    preprocess_1d = common_modules.Linear(
        c.msa_channel, name='preprocess_1d')(
            target_feat)

    safe_key, sample_key, mask_key = safe_key.split(3)
    batch = sample_msa(sample_key, batch, c.num_msa)
    batch = make_masked_msa(batch, mask_key, c.masked_msa)

    (batch['cluster_profile'],
     batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)

    msa_feat = create_msa_feat(batch)

    preprocess_msa = common_modules.Linear(
        c.msa_channel, name='preprocess_msa')(
            msa_feat)

    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

    left_single = common_modules.Linear(
        c.pair_channel, name='left_single')(
            target_feat)
    right_single = common_modules.Linear(
        c.pair_channel, name='right_single')(
            target_feat)
    pair_activations = left_single[:, None] + right_single[None]
    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
    mask_2d = mask_2d.astype(jnp.float32)

    if c.recycle_pos and 'prev_pos' in batch:
      prev_pseudo_beta = modules.pseudo_beta_fn(
          batch['aatype'], batch['prev_pos'], None)

      dgram = modules.dgram_from_positions(
          prev_pseudo_beta, **self.config.prev_pos)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='prev_pos_linear')(
              dgram)

    if c.recycle_features:
      if 'prev_msa_first_row' in batch:
        prev_msa_first_row = hk.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='prev_msa_first_row_norm')(
                batch['prev_msa_first_row'])
        msa_activations = msa_activations.at[0].add(prev_msa_first_row)

      if 'prev_pair' in batch:
        pair_activations += hk.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='prev_pair_norm')(
                batch['prev_pair'])

    if c.max_relative_idx:
      pair_activations += self._relative_encoding(batch)

    if c.template.enabled:
      template_module = TemplateEmbedding(c.template, gc)
      template_batch = {
          'template_aatype': batch['template_aatype'],
          'template_all_atom_positions': batch['template_all_atom_positions'],
          'template_all_atom_mask': batch['template_all_atom_mask']
      }
      # Construct a mask such that only intra-chain template features are
      # computed, since all templates are for each chain individually.
      multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
      safe_key, safe_subkey = safe_key.split()
      template_act = template_module(
          query_embedding=pair_activations,
          template_batch=template_batch,
          padding_mask_2d=mask_2d,
          multichain_mask_2d=multichain_mask,
          is_training=is_training,
          safe_key=safe_subkey)
      pair_activations += template_act

    # Extra MSA stack.
    (extra_msa_feat,
     extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
    extra_msa_activations = common_modules.Linear(
        c.extra_msa_channel,
        name='extra_msa_activations')(
            extra_msa_feat)
    extra_msa_mask = extra_msa_mask.astype(jnp.float32)

    extra_evoformer_input = {
        'msa': extra_msa_activations,
        'pair': pair_activations,
    }
    extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}

    extra_evoformer_iteration = modules.EvoformerIteration(
        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')

    def extra_evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      extra_evoformer_output = extra_evoformer_iteration(
          activations=act,
          masks=extra_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (extra_evoformer_output, safe_key)

    if gc.use_remat:
      extra_evoformer_fn = hk.remat(extra_evoformer_fn)

    safe_key, safe_subkey = safe_key.split()
    extra_evoformer_stack = layer_stack.layer_stack(
        c.extra_msa_stack_num_block)(
            extra_evoformer_fn)
    extra_evoformer_output, safe_key = extra_evoformer_stack(
        (extra_evoformer_input, safe_subkey))

    # Only the pair representation of the extra MSA stack is carried forward.
    pair_activations = extra_evoformer_output['pair']

    # Get the size of the MSA before potentially adding templates, so we
    # can crop out the templates later.
    num_msa_sequences = msa_activations.shape[0]
    evoformer_input = {
        'msa': msa_activations,
        'pair': pair_activations,
    }
    evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
                       'pair': mask_2d}

    if c.template.enabled:
      template_features, template_masks = (
          template_embedding_1d(batch=batch, num_channel=c.msa_channel))

      evoformer_input['msa'] = jnp.concatenate(
          [evoformer_input['msa'], template_features], axis=0)
      evoformer_masks['msa'] = jnp.concatenate(
          [evoformer_masks['msa'], template_masks], axis=0)

    evoformer_iteration = modules.EvoformerIteration(
        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

    def evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      evoformer_output = evoformer_iteration(
          activations=act,
          masks=evoformer_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (evoformer_output, safe_key)

    if gc.use_remat:
      evoformer_fn = hk.remat(evoformer_fn)

    safe_key, safe_subkey = safe_key.split()
    evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
        evoformer_fn)

    def run_evoformer(evoformer_input):
      evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
      return evoformer_output

    evoformer_output = run_evoformer(evoformer_input)

    msa_activations = evoformer_output['msa']
    pair_activations = evoformer_output['pair']

    single_activations = common_modules.Linear(
        c.seq_channel, name='single_activations')(
            msa_activations[0])

    output.update({
        'single':
            single_activations,
        'pair':
            pair_activations,
        # Crop away template rows such that they are not used in MaskedMsaHead.
        'msa':
            msa_activations[:num_msa_sequences, :, :],
        'msa_first_row':
            msa_activations[0],
    })

    return output


class TemplateEmbedding(hk.Module):
  """Embed a set of templates."""

  def __init__(self, config, global_config, name='template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_batch, padding_mask_2d,
               multichain_mask_2d, is_training,
               safe_key=None):
    """Generate an embedding for a set of templates.

    Args:
      query_embedding: [num_res, num_res, num_channel] a query tensor that will
        be used to attend over the templates to remove the num_templates
        dimension.
      template_batch: A dictionary containing:
        `template_aatype`: [num_templates, num_res] aatype for each template.
        `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom
          positions for all templates.
        `template_all_atom_mask`: [num_templates, num_res, 37] mask for each
          template.
      padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
      multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs
        are intra-chain, used to mask out residue distance based features
        between chains.
      is_training: bool indicating where we are running in training mode.
      safe_key: random key generator.

    Returns:
      An embedding of size [num_res, num_res, num_channels]
    """
    c = self.config
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    num_templates = template_batch['template_aatype'].shape[0]
    num_res, _, query_num_channels = query_embedding.shape

    # Embed each template separately.
    template_embedder = SingleTemplateEmbedding(self.config, self.global_config)

    def partial_template_embedder(template_aatype,
                                  template_all_atom_positions,
                                  template_all_atom_mask,
                                  unsafe_key):
      safe_key = prng.SafeKey(unsafe_key)
      return template_embedder(query_embedding,
                               template_aatype,
                               template_all_atom_positions,
                               template_all_atom_mask,
                               padding_mask_2d,
                               multichain_mask_2d,
                               is_training,
                               safe_key)

    safe_key, unsafe_key = safe_key.split()
    # Unwrap through the public accessor (as done for keys elsewhere in this
    # file) rather than reading the private `_key` attribute.
    unsafe_keys = jax.random.split(unsafe_key.get(), num_templates)

    def scan_fn(carry, x):
      return carry + partial_template_embedder(*x), None

    scan_init = jnp.zeros((num_res, num_res, c.num_channels),
                          dtype=query_embedding.dtype)
    summed_template_embeddings, _ = hk.scan(
        scan_fn, scan_init,
        (template_batch['template_aatype'],
         template_batch['template_all_atom_positions'],
         template_batch['template_all_atom_mask'],
         unsafe_keys))

    # Average over templates.
    embedding = summed_template_embeddings / num_templates
    embedding = jax.nn.relu(embedding)
    embedding = common_modules.Linear(
        query_num_channels,
        initializer='relu',
        name='output_linear')(embedding)

    return embedding


class SingleTemplateEmbedding(hk.Module):
  """Embed a single template."""

  def __init__(self, config, global_config, name='single_template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_aatype,
               template_all_atom_positions, template_all_atom_mask,
               padding_mask_2d, multichain_mask_2d, is_training,
               safe_key):
    """Build the single template embedding graph.

    Args:
      query_embedding: (num_res, num_res, num_channels) - embedding of the
        query sequence/msa.
      template_aatype: [num_res] aatype for each template.
      template_all_atom_positions: [num_res, 37, 3] atom positions for all
        templates.
      template_all_atom_mask: [num_res, 37] mask for each template.
      padding_mask_2d: Padding mask (Note: this doesn't care if a template
        exists, unlike the template_pseudo_beta_mask).
      multichain_mask_2d: A mask indicating intra-chain residue pairs, used
        to mask out between chain distances/features when templates are for
        single chains.
      is_training: Are we in training mode.
      safe_key: Random key generator.

    Returns:
      A template embedding (num_res, num_res, num_channels).
    """
    gc = self.global_config
    c = self.config
    assert padding_mask_2d.dtype == query_embedding.dtype
    dtype = query_embedding.dtype
    num_channels = self.config.num_channels

    def construct_input(query_embedding, template_aatype,
                        template_all_atom_positions, template_all_atom_mask,
                        multichain_mask_2d):

      # Compute distogram feature for the template.
      template_positions, pseudo_beta_mask = modules.pseudo_beta_fn(
          template_aatype, template_all_atom_positions, template_all_atom_mask)
      pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] *
                             pseudo_beta_mask[None, :])
      pseudo_beta_mask_2d *= multichain_mask_2d
      template_dgram = modules.dgram_from_positions(
          template_positions, **self.config.dgram_features)
      template_dgram *= pseudo_beta_mask_2d[..., None]
      template_dgram = template_dgram.astype(dtype)
      pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
      # Each entry is (feature, num_input_dims for its Linear layer).
      to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)]

      aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype)
      to_concat.append((aatype[None, :, :], 1))
      to_concat.append((aatype[:, None, :], 1))

      # Compute a feature representing the normalized vector between each
      # backbone affine - i.e. in each residues local frame, what direction are
      # each of the other residues.
      raw_atom_pos = template_all_atom_positions

      atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
      rigid, backbone_mask = folding_multimer.make_backbone_affine(
          atom_pos,
          template_all_atom_mask,
          template_aatype)
      points = rigid.translation
      rigid_vec = rigid[:, None].inverse().apply_to_point(points)
      unit_vector = rigid_vec.normalized()
      unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

      backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
      backbone_mask_2d *= multichain_mask_2d
      unit_vector = [x*backbone_mask_2d for x in unit_vector]

      # Note that the backbone_mask takes into account C, CA and N (unlike
      # pseudo beta mask which just needs CB) so we add both masks as features.
      to_concat.extend([(x, 0) for x in unit_vector])
      to_concat.append((backbone_mask_2d, 0))

      query_embedding = hk.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
          name='query_embedding_norm')(
              query_embedding)
      # Allow the template embedder to see the query embedding.  Note this
      # contains the position relative feature, so this is how the network knows
      # which residues are next to each other.
      to_concat.append((query_embedding, 1))

      act = 0

      # Each feature gets its own Linear projection and the results are summed.
      for i, (x, n_input_dims) in enumerate(to_concat):

        act += common_modules.Linear(
            num_channels,
            num_input_dims=n_input_dims,
            initializer='relu',
            name=f'template_pair_embedding_{i}')(x)
      return act

    act = construct_input(query_embedding, template_aatype,
                          template_all_atom_positions, template_all_atom_mask,
                          multichain_mask_2d)

    template_iteration = TemplateEmbeddingIteration(
        c.template_pair_stack, gc, name='template_embedding_iteration')

    def template_iteration_fn(x):
      act, safe_key = x

      safe_key, safe_subkey = safe_key.split()
      act = template_iteration(
          act=act,
          pair_mask=padding_mask_2d,
          is_training=is_training,
          safe_key=safe_subkey)
      return (act, safe_key)

    if gc.use_remat:
      template_iteration_fn = hk.remat(template_iteration_fn)

    safe_key, safe_subkey = safe_key.split()
    template_stack = layer_stack.layer_stack(
        c.template_pair_stack.num_block)(
            template_iteration_fn)
    act, safe_key = template_stack((act, safe_subkey))

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='output_layer_norm')(
            act)
    return act


class TemplateEmbeddingIteration(hk.Module):
  """Single Iteration of Template Embedding."""

  def __init__(self, config, global_config,
               name='template_embedding_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, pair_mask, is_training=True,
               safe_key=None):
    """Build a single iteration of the template embedder.

    Args:
      act: [num_res, num_res, num_channel] Input pairwise activations.
      pair_mask: [num_res, num_res] padding mask.
      is_training: Whether to run in training mode.
      safe_key: Safe pseudo-random generator key.

    Returns:
      [num_res, num_res, num_channel] tensor of activations.
    """
    c = self.config
    gc = self.global_config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    dropout_wrapper_fn = functools.partial(
        modules.dropout_wrapper,
        is_training=is_training,
        global_config=gc)

    # More sub-keys are split off than the five layers below consume.
    safe_key, *sub_keys = safe_key.split(20)
    sub_keys = iter(sub_keys)

    act = dropout_wrapper_fn(
        modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                       name='triangle_multiplication_outgoing'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                       name='triangle_multiplication_incoming'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                  name='triangle_attention_starting_node'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                  name='triangle_attention_ending_node'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.Transition(c.pair_transition, gc,
                           name='pair_transition'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    return act


def template_embedding_1d(batch, num_channel):
  """Embed templates into an (num_templates, num_res, num_channels) embedding.

  Args:
    batch: A batch containing:
      template_aatype, (num_templates, num_res) aatype for the templates.
      template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
        positions for the templates.
      template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
        each template.
    num_channel: The number of channels in the output.

  Returns:
    An embedding of shape (num_templates, num_res, num_channels) and a mask of
    shape (num_templates, num_res).
  """

  # Embed the templates aatypes.
  aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)

  num_templates = batch['template_aatype'].shape[0]
  all_chi_angles = []
  all_chi_masks = []
  for i in range(num_templates):
    atom_pos = geometry.Vec3Array.from_array(
        batch['template_all_atom_positions'][i, :, :, :])
    template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles(
        atom_pos,
        batch['template_all_atom_mask'][i, :, :],
        batch['template_aatype'][i, :])
    all_chi_angles.append(template_chi_angles)
    all_chi_masks.append(template_chi_mask)
  chi_angles = jnp.stack(all_chi_angles, axis=0)
  chi_mask = jnp.stack(all_chi_masks, axis=0)

  template_features = jnp.concatenate([
      aatype_one_hot,
      jnp.sin(chi_angles) * chi_mask,
      jnp.cos(chi_angles) * chi_mask,
      chi_mask], axis=-1)

  # The mask of the first chi angle is used as the per-residue template mask.
  template_mask = chi_mask[:, :, 0]

  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
      name='template_single_embedding')(
          template_features)
  template_activations = jax.nn.relu(template_activations)
  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
      name='template_projection')(
          template_activations)
  return template_activations, template_mask