comparison docker/alphafold/alphafold/model/modules_multimer.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
1 # Copyright 2021 DeepMind Technologies Limited
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Core modules, which have been refactored in AlphaFold-Multimer.
16
17 The main difference is that the MSA sampling pipeline is moved inside the JAX
18 model for easier implementation of recycling and ensembling.
19
20 Lower-level modules up to EvoformerIteration are reused from modules.py.
21 """
22
23 import functools
24 from typing import Sequence
25
26 from alphafold.common import residue_constants
27 from alphafold.model import all_atom_multimer
28 from alphafold.model import common_modules
29 from alphafold.model import folding_multimer
30 from alphafold.model import geometry
31 from alphafold.model import layer_stack
32 from alphafold.model import modules
33 from alphafold.model import prng
34 from alphafold.model import utils
35
36 import haiku as hk
37 import jax
38 import jax.numpy as jnp
39 import numpy as np
40
41
42 def reduce_fn(x, mode):
43 if mode == 'none' or mode is None:
44 return jnp.asarray(x)
45 elif mode == 'sum':
46 return jnp.asarray(x).sum()
47 elif mode == 'mean':
48 return jnp.mean(jnp.asarray(x))
49 else:
50 raise ValueError('Unsupported reduction option.')
51
52
53 def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
54 """Generate Gumbel Noise of given Shape.
55
56 This generates samples from Gumbel(0, 1).
57
58 Args:
59 key: Jax random number key.
60 shape: Shape of noise to return.
61
62 Returns:
63 Gumbel noise of given shape.
64 """
65 epsilon = 1e-6
66 uniform = utils.padding_consistent_rng(jax.random.uniform)
67 uniform_noise = uniform(
68 key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
69 gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
70 return gumbel
71
72
73 def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
74 """Samples from a probability distribution given by 'logits'.
75
76 This uses the Gumbel-max trick to implement the sampling in an efficient manner.
77
78 Args:
79 key: prng key.
80 logits: Logarithm of probabilities to sample from, probabilities can be
81 unnormalized.
82
83 Returns:
84 Sample from logprobs in one-hot form.
85 """
86 z = gumbel_noise(key, logits.shape)
87 return jax.nn.one_hot(
88 jnp.argmax(logits + z, axis=-1),
89 logits.shape[-1],
90 dtype=logits.dtype)
91
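As a quick sanity check on the Gumbel-max trick used by gumbel_max_sample, the following standalone sketch (illustrative only; the toy logits and sample count are made up, and it re-derives the noise inline rather than calling the helpers above) draws repeated samples and compares the empirical frequencies with softmax(logits):

import jax
import jax.numpy as jnp

def demo_gumbel_max(key, logits):
    # Adding Gumbel(0, 1) noise and taking the argmax yields a categorical
    # sample with probabilities softmax(logits).
    u = jax.random.uniform(key, logits.shape, minval=0., maxval=1.)
    g = -jnp.log(-jnp.log(u + 1e-6) + 1e-6)
    return jnp.argmax(logits + g, axis=-1)

logits = jnp.log(jnp.array([0.1, 0.2, 0.7]))
keys = jax.random.split(jax.random.PRNGKey(0), 10000)
samples = jax.vmap(lambda k: demo_gumbel_max(k, logits))(keys)
print(jnp.bincount(samples, length=3) / samples.shape[0])  # roughly [0.1, 0.2, 0.7]
print(jax.nn.softmax(logits))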
92
93 def gumbel_argsort_sample_idx(key: jnp.ndarray,
94 logits: jnp.ndarray) -> jnp.ndarray:
95 """Samples with replacement from a distribution given by 'logits'.
96
97 This uses Gumbel trick to implement the sampling an efficient manner. For a
98 distribution over k items this samples k times without replacement, so this
99 is effectively sampling a random permutation with probabilities over the
100 permutations derived from the logprobs.
101
102 Args:
103 key: prng key.
104 logits: Logarithm of probabilities to sample from, probabilities can be
105 unnormalized.
106
107 Returns:
108 Permutation of indices sampled without replacement according to the logprobs.
109 """
110 z = gumbel_noise(key, logits.shape)
111 # This construction is equivalent to jnp.argsort, but uses a non-stable sort,
112 # since stable sorts aren't supported by jax2tf.
113 axis = len(logits.shape) - 1
114 iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
115 _, perm = jax.lax.sort_key_val(
116 logits + z, iota, dimension=-1, is_stable=False)
117 return perm[::-1]
118
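A standalone sketch of the argsort variant (toy values; it uses jnp.argsort for brevity rather than the jax2tf-friendly non-stable sort above): adding Gumbel noise and sorting in descending order gives a random permutation in which high-logit indices tend to come first, while heavily down-weighted indices, such as masked-out rows with a -1e6 logit, end up last.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([0., 0., 10., -1e6])  # last entry is effectively masked out
u = jax.random.uniform(key, logits.shape)
z = -jnp.log(-jnp.log(u + 1e-6) + 1e-6)
perm = jnp.argsort(logits + z)[::-1]  # descending order of noisy logits
print(perm)  # index 2 almost always first, index 3 last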
119
120 def make_masked_msa(batch, key, config, epsilon=1e-6):
121 """Create data for BERT on raw MSA."""
122 # Add a random amino acid uniformly.
123 random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32)
124
125 categorical_probs = (
126 config.uniform_prob * random_aa +
127 config.profile_prob * batch['msa_profile'] +
128 config.same_prob * jax.nn.one_hot(batch['msa'], 22))
129
130 # Put all remaining probability on [MASK] which is a new column.
131 pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
132 pad_shapes[-1][1] = 1
133 mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
134 assert mask_prob >= 0.
135 categorical_probs = jnp.pad(
136 categorical_probs, pad_shapes, constant_values=mask_prob)
137 sh = batch['msa'].shape
138 key, mask_subkey, gumbel_subkey = key.split(3)
139 uniform = utils.padding_consistent_rng(jax.random.uniform)
140 mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction
141 mask_position *= batch['msa_mask']
142
143 logits = jnp.log(categorical_probs + epsilon)
144 bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits)
145 bert_msa = jnp.where(mask_position,
146 jnp.argmax(bert_msa, axis=-1), batch['msa'])
147 bert_msa *= batch['msa_mask']
148
149 # Mix real and masked MSA.
150 if 'bert_mask' in batch:
151 batch['bert_mask'] *= mask_position.astype(jnp.float32)
152 else:
153 batch['bert_mask'] = mask_position.astype(jnp.float32)
154 batch['true_msa'] = batch['msa']
155 batch['msa'] = bert_msa
156
157 return batch
158
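For intuition, the per-position corruption distribution built above mixes a uniform residue, the MSA profile, and the true residue, and puts the leftover mass on a new [MASK] column. A standalone sketch with illustrative probabilities (not necessarily the shipped config values):

import jax
import jax.numpy as jnp

uniform_prob, profile_prob, same_prob = 0.1, 0.1, 0.1
mask_prob = 1. - uniform_prob - profile_prob - same_prob  # mass on the [MASK] column

random_aa = jnp.array([0.05] * 20 + [0., 0.])   # uniform over the 20 amino acids
msa_profile = jnp.full((22,), 1. / 22)          # toy per-position profile
same = jax.nn.one_hot(3, 22)                    # toy true residue

probs = uniform_prob * random_aa + profile_prob * msa_profile + same_prob * same
probs = jnp.concatenate([probs, jnp.array([mask_prob])])
print(probs.shape, probs.sum())  # (23,) 1.0 - a valid distribution over 22 classes + [MASK]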
159
160 def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
161 """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
162
163 # Determine how much weight we assign to each agreement. In theory, we could
164 # use a full blosum matrix here, but right now let's just down-weight gap
165 # agreement because it could be spurious.
166 # Never put weight on agreeing on BERT mask.
167
168 weights = jnp.array(
169 [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32)
170
171 msa_mask = batch['msa_mask']
172 msa_one_hot = jax.nn.one_hot(batch['msa'], 23)
173
174 extra_mask = batch['extra_msa_mask']
175 extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23)
176
177 msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
178 extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
179
180 agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked,
181 weights * msa_one_hot_masked)
182
183 cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0)
184 cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask)
185
186 cluster_count = jnp.sum(cluster_assignment, axis=-1)
187 cluster_count += 1. # We always include the sequence itself.
188
189 msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
190 msa_sum += msa_one_hot_masked
191
192 cluster_profile = msa_sum / cluster_count[:, None, None]
193
194 extra_deletion_matrix = batch['extra_deletion_matrix']
195 deletion_matrix = batch['deletion_matrix']
196
197 del_sum = jnp.einsum('nm, mc->nc', cluster_assignment,
198 extra_mask * extra_deletion_matrix)
199 del_sum += deletion_matrix # Original sequence.
200 cluster_deletion_mean = del_sum / cluster_count[:, None]
201
202 return cluster_profile, cluster_deletion_mean
203
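A standalone sketch of the agreement/assignment einsums on toy shapes (random inputs, illustrative sizes): the first contraction scores how well each extra sequence matches each sampled cluster centre, and the sharp softmax then performs an approximately hard nearest-neighbour assignment.

import jax
import jax.numpy as jnp

num_msa, num_extra, num_res, num_classes = 4, 6, 10, 23
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
msa_one_hot = jax.nn.one_hot(jax.random.randint(k1, (num_msa, num_res), 0, 21), num_classes)
extra_one_hot = jax.nn.one_hot(jax.random.randint(k2, (num_extra, num_res), 0, 21), num_classes)
weights = jnp.array([1.] * 21 + [0., 0.])  # ignore gap/mask agreement in this toy

# agreement[n, m] = weighted count of positions where extra sequence m matches centre n.
agreement = jnp.einsum('mrc,nrc->nm', extra_one_hot, weights * msa_one_hot)
assignment = jax.nn.softmax(1e3 * agreement, axis=0)  # ~hard assignment to the best centre
print(agreement.shape)         # (4, 6)
print(assignment.sum(axis=0))  # each extra sequence distributes ~1.0 across the centres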
204
205 def create_msa_feat(batch):
206 """Create and concatenate MSA features."""
207 msa_1hot = jax.nn.one_hot(batch['msa'], 23)
208 deletion_matrix = batch['deletion_matrix']
209 has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
210 deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
211
212 deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) *
213 (2. / jnp.pi))[..., None]
214
215 msa_feat = [
216 msa_1hot,
217 has_deletion,
218 deletion_value,
219 batch['cluster_profile'],
220 deletion_mean_value
221 ]
222
223 return jnp.concatenate(msa_feat, axis=-1)
224
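The concatenated MSA feature therefore has 23 + 1 + 1 + 23 + 1 = 49 channels. A minimal standalone shape check (zero-filled toy batch, assuming the module's imports and create_msa_feat above are in scope):

import jax.numpy as jnp

num_seq, num_res = 5, 12
toy_batch = {
    'msa': jnp.zeros((num_seq, num_res), dtype=jnp.int32),
    'deletion_matrix': jnp.zeros((num_seq, num_res)),
    'cluster_deletion_mean': jnp.zeros((num_seq, num_res)),
    'cluster_profile': jnp.zeros((num_seq, num_res, 23)),
}
print(create_msa_feat(toy_batch).shape)  # (5, 12, 49)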
225
226 def create_extra_msa_feature(batch, num_extra_msa):
227 """Expand extra_msa into 1hot and concat with other extra msa features.
228
229 We do this as late as possible because the one-hot extra MSA can be very large.
230
231 Args:
232 batch: a dictionary with the following keys:
233 * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
234 centre. Note - This isn't one-hotted.
235 * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
236 position.
237 num_extra_msa: Number of extra msa to use.
238
239 Returns:
240 Concatenated tensor of extra MSA features.
241 """
242 # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
243 extra_msa = batch['extra_msa'][:num_extra_msa]
244 deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa]
245 msa_1hot = jax.nn.one_hot(extra_msa, 23)
246 has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
247 deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
248 extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa]
249 return jnp.concatenate([msa_1hot, has_deletion, deletion_value],
250 axis=-1), extra_msa_mask
251
252
253 def sample_msa(key, batch, max_seq):
254 """Sample MSA randomly, remaining sequences are stored as `extra_*`.
255
256 Args:
257 key: safe key for random number generation.
258 batch: batch to sample msa from.
259 max_seq: number of sequences to sample.
260 Returns:
261 Batch with sampled MSA; the remainder is moved to `extra_*` features.
262 """
263 # Sample uniformly among sequences with at least one non-masked position.
264 logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6
265 # The cluster_bias_mask can be used to preserve the first row (target
266 # sequence) for each chain, for example.
267 if 'cluster_bias_mask' not in batch:
268 cluster_bias_mask = jnp.pad(
269 jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.)
270 else:
271 cluster_bias_mask = batch['cluster_bias_mask']
272
273 logits += cluster_bias_mask * 1e6
274 index_order = gumbel_argsort_sample_idx(key.get(), logits)
275 sel_idx = index_order[:max_seq]
276 extra_idx = index_order[max_seq:]
277
278 for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
279 if k in batch:
280 batch['extra_' + k] = batch[k][extra_idx]
281 batch[k] = batch[k][sel_idx]
282
283 return batch
284
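A small standalone illustration of the cluster bias (toy sizes): adding 1e6 to the first row's logit via the default cluster_bias_mask above makes it effectively certain that the target sequence stays in the sampled MSA rather than being moved to the extra set.

import jax.numpy as jnp

num_seq = 5
msa_mask = jnp.ones((num_seq, 10))                              # all rows have unmasked positions
logits = (jnp.clip(msa_mask.sum(axis=-1), 0., 1.) - 1.) * 1e6   # 0 for usable rows
cluster_bias_mask = jnp.pad(jnp.zeros(num_seq - 1), (1, 0), constant_values=1.)
logits += cluster_bias_mask * 1e6
print(logits)  # [1e6, 0, 0, 0, 0]: row 0 (the target) dominates any Gumbel noise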
285
286 def make_msa_profile(batch):
287 """Compute the MSA profile."""
288
289 # Compute the profile for every residue (over all MSA sequences).
290 return utils.mask_mean(
291 batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0)
292
293
294 class AlphaFoldIteration(hk.Module):
295 """A single recycling iteration of AlphaFold architecture.
296
297 Computes ensembled (averaged) representations from the provided features.
298 These representations are then passed to the various heads
299 that have been requested by the configuration file.
300 """
301
302 def __init__(self, config, global_config, name='alphafold_iteration'):
303 super().__init__(name=name)
304 self.config = config
305 self.global_config = global_config
306
307 def __call__(self,
308 batch,
309 is_training,
310 return_representations=False,
311 safe_key=None):
312
313 if is_training:
314 num_ensemble = np.asarray(self.config.num_ensemble_train)
315 else:
316 num_ensemble = np.asarray(self.config.num_ensemble_eval)
317
318 # Compute representations for each MSA sample and average.
319 embedding_module = EmbeddingsAndEvoformer(
320 self.config.embeddings_and_evoformer, self.global_config)
321 repr_shape = hk.eval_shape(
322 lambda: embedding_module(batch, is_training))
323 representations = {
324 k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items()
325 }
326
327 def ensemble_body(x, unused_y):
328 """Add into representations ensemble."""
329 del unused_y
330 representations, safe_key = x
331 safe_key, safe_subkey = safe_key.split()
332 representations_update = embedding_module(
333 batch, is_training, safe_key=safe_subkey)
334
335 for k in representations:
336 if k not in {'msa', 'true_msa', 'bert_mask'}:
337 representations[k] += representations_update[k] * (
338 1. / num_ensemble).astype(representations[k].dtype)
339 else:
340 representations[k] = representations_update[k]
341
342 return (representations, safe_key), None
343
344 (representations, _), _ = hk.scan(
345 ensemble_body, (representations, safe_key), None, length=num_ensemble)
346
347 self.representations = representations
348 self.batch = batch
349 self.heads = {}
350 for head_name, head_config in sorted(self.config.heads.items()):
351 if not head_config.weight:
352 continue # Do not instantiate zero-weight heads.
353
354 head_factory = {
355 'masked_msa':
356 modules.MaskedMsaHead,
357 'distogram':
358 modules.DistogramHead,
359 'structure_module':
360 folding_multimer.StructureModule,
361 'predicted_aligned_error':
362 modules.PredictedAlignedErrorHead,
363 'predicted_lddt':
364 modules.PredictedLDDTHead,
365 'experimentally_resolved':
366 modules.ExperimentallyResolvedHead,
367 }[head_name]
368 self.heads[head_name] = (head_config,
369 head_factory(head_config, self.global_config))
370
371 structure_module_output = None
372 if 'entity_id' in batch and 'all_atom_positions' in batch:
373 _, fold_module = self.heads['structure_module']
374 structure_module_output = fold_module(representations, batch, is_training)
375
376 ret = {}
377 ret['representations'] = representations
378
379 for name, (head_config, module) in self.heads.items():
380 if name == 'structure_module' and structure_module_output is not None:
381 ret[name] = structure_module_output
382 representations['structure_module'] = structure_module_output.pop('act')
383 # Skip confidence heads until StructureModule is executed.
384 elif name in {'predicted_lddt', 'predicted_aligned_error',
385 'experimentally_resolved'}:
386 continue
387 else:
388 ret[name] = module(representations, batch, is_training)
389
390 # Add confidence heads after StructureModule is executed.
391 if self.config.heads.get('predicted_lddt.weight', 0.0):
392 name = 'predicted_lddt'
393 head_config, module = self.heads[name]
394 ret[name] = module(representations, batch, is_training)
395
396 if self.config.heads.experimentally_resolved.weight:
397 name = 'experimentally_resolved'
398 head_config, module = self.heads[name]
399 ret[name] = module(representations, batch, is_training)
400
401 if self.config.heads.get('predicted_aligned_error.weight', 0.0):
402 name = 'predicted_aligned_error'
403 head_config, module = self.heads[name]
404 ret[name] = module(representations, batch, is_training)
405 # Will be used for ipTM computation.
406 ret[name]['asym_id'] = batch['asym_id']
407
408 return ret
409
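The ensembling loop above accumulates each embedding pass as a running average inside hk.scan. The following standalone toy (illustrative values, plain jax.lax.scan instead of hk.scan) shows the same accumulation pattern:

import jax
import jax.numpy as jnp

num_ensemble = 4
updates = jnp.arange(num_ensemble, dtype=jnp.float32)  # stand-ins for per-sample representations

def body(acc, update):
    # Each ensemble member contributes update / num_ensemble to the accumulator.
    return acc + update * (1. / num_ensemble), None

mean, _ = jax.lax.scan(body, jnp.zeros(()), updates)
print(mean, updates.mean())  # both 1.5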
410
411 class AlphaFold(hk.Module):
412 """AlphaFold-Multimer model with recycling.
413 """
414
415 def __init__(self, config, name='alphafold'):
416 super().__init__(name=name)
417 self.config = config
418 self.global_config = config.global_config
419
420 def __call__(
421 self,
422 batch,
423 is_training,
424 return_representations=False,
425 safe_key=None):
426
427 c = self.config
428 impl = AlphaFoldIteration(c, self.global_config)
429
430 if safe_key is None:
431 safe_key = prng.SafeKey(hk.next_rng_key())
432 elif isinstance(safe_key, jnp.ndarray):
433 safe_key = prng.SafeKey(safe_key)
434
435 assert isinstance(batch, dict)
436 num_res = batch['aatype'].shape[0]
437
438 def get_prev(ret):
439 new_prev = {
440 'prev_pos':
441 ret['structure_module']['final_atom_positions'],
442 'prev_msa_first_row': ret['representations']['msa_first_row'],
443 'prev_pair': ret['representations']['pair'],
444 }
445 return jax.tree_map(jax.lax.stop_gradient, new_prev)
446
447 def apply_network(prev, safe_key):
448 recycled_batch = {**batch, **prev}
449 return impl(
450 batch=recycled_batch,
451 is_training=is_training,
452 safe_key=safe_key)
453
454 if self.config.num_recycle:
455 emb_config = self.config.embeddings_and_evoformer
456 prev = {
457 'prev_pos':
458 jnp.zeros([num_res, residue_constants.atom_type_num, 3]),
459 'prev_msa_first_row':
460 jnp.zeros([num_res, emb_config.msa_channel]),
461 'prev_pair':
462 jnp.zeros([num_res, num_res, emb_config.pair_channel]),
463 }
464
465 if 'num_iter_recycling' in batch:
466 # Training time: num_iter_recycling is in batch.
467 # Value for each ensemble batch is the same, so arbitrarily taking 0-th.
468 num_iter = batch['num_iter_recycling'][0]
469
470 # Make sure that, even when ensembling, we do not run more recycling
471 # iterations than the model is configured to run.
472 num_iter = jnp.minimum(num_iter, c.num_recycle)
473 else:
474 # Eval mode or tests: use the maximum number of iterations.
475 num_iter = c.num_recycle
476
477 def recycle_body(i, x):
478 del i
479 prev, safe_key = x
480 safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long
481 ret = apply_network(prev=prev, safe_key=safe_key2)
482 return get_prev(ret), safe_key1
483
484 prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
485 else:
486 prev = {}
487
488 # Run extra iteration.
489 ret = apply_network(prev=prev, safe_key=safe_key)
490
491 if not return_representations:
492 del ret['representations']
493 return ret
494
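For context, this Haiku module is normally wrapped with hk.transform before use; a rough sketch of how that looks (the real wrapper is alphafold.model.model.RunModel; model_config and batch here are placeholders, and the exact feature keys required depend on the data pipeline):

import haiku as hk
import jax

def _forward(batch):
    # model_config is a placeholder for a multimer ConfigDict, e.g. obtained
    # from alphafold.model.config.model_config(...).
    model = AlphaFold(model_config.model)
    return model(batch, is_training=False)

forward = hk.transform(_forward)
rng = jax.random.PRNGKey(0)
# params = forward.init(rng, batch)            # or pretrained parameters loaded from a checkpoint
# result = forward.apply(params, rng, batch)   # dict with 'structure_module', 'distogram', ...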
495
496 class EmbeddingsAndEvoformer(hk.Module):
497 """Embeds the input data and runs Evoformer.
498
499 Produces the MSA, single and pair representations.
500 """
501
502 def __init__(self, config, global_config, name='evoformer'):
503 super().__init__(name=name)
504 self.config = config
505 self.global_config = global_config
506
507 def _relative_encoding(self, batch):
508 """Add relative position encodings.
509
510 For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted.
511
512 When 'use_chain_relative' is not used, the residue indices are used as is,
513 e.g. for heteromers relative positions are computed from the positions in
514 the corresponding chains.
515
516 When 'use_chain_relative' is used, we add an extra bin that denotes
517 'different chain'. Furthermore, we also provide the network with the
518 relative chain index (i.e. sym_id), clipped and one-hotted, and an extra
519 feature which denotes whether the two residues belong to the same entity,
520 i.e. it is 0 if they are in different heteromer chains and 1 otherwise.
521
522 Args:
523 batch: batch.
524 Returns:
525 Feature embedding using the features as described before.
526 """
527 c = self.config
528 rel_feats = []
529 pos = batch['residue_index']
530 asym_id = batch['asym_id']
531 asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
532 offset = pos[:, None] - pos[None, :]
533
534 clipped_offset = jnp.clip(
535 offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
536
537 if c.use_chain_relative:
538
539 final_offset = jnp.where(asym_id_same, clipped_offset,
540 (2 * c.max_relative_idx + 1) *
541 jnp.ones_like(clipped_offset))
542
543 rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2)
544
545 rel_feats.append(rel_pos)
546
547 entity_id = batch['entity_id']
548 entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :])
549 rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None])
550
551 sym_id = batch['sym_id']
552 rel_sym_id = sym_id[:, None] - sym_id[None, :]
553
554 max_rel_chain = c.max_relative_chain
555
556 clipped_rel_chain = jnp.clip(
557 rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain)
558
559 final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
560 (2 * max_rel_chain + 1) *
561 jnp.ones_like(clipped_rel_chain))
562 rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2)
563
564 rel_feats.append(rel_chain)
565
566 else:
567 rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1)
568 rel_feats.append(rel_pos)
569
570 rel_feat = jnp.concatenate(rel_feats, axis=-1)
571
572 return common_modules.Linear(
573 c.pair_channel,
574 name='position_activations')(
575 rel_feat)
576
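A standalone toy illustration of the clipping/one-hot scheme implemented above (max_relative_idx=2 is a made-up value; the shipped config uses a larger cutoff), showing how inter-chain pairs collapse into the single extra 'different chain' bin when use_chain_relative is on:

import jax
import jax.numpy as jnp

max_relative_idx = 2
residue_index = jnp.array([0, 1, 2, 0, 1])  # two chains, numbering restarts
asym_id = jnp.array([0, 0, 0, 1, 1])

offset = residue_index[:, None] - residue_index[None, :]
clipped = jnp.clip(offset + max_relative_idx, 0, 2 * max_relative_idx)
same_chain = asym_id[:, None] == asym_id[None, :]

# All inter-chain pairs share one extra bin (index 2 * max_relative_idx + 1).
final_offset = jnp.where(same_chain, clipped, 2 * max_relative_idx + 1)
rel_pos = jax.nn.one_hot(final_offset, 2 * max_relative_idx + 2)
print(rel_pos.shape)  # (5, 5, 6)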
577 def __call__(self, batch, is_training, safe_key=None):
578
579 c = self.config
580 gc = self.global_config
581
582 batch = dict(batch)
583
584 if safe_key is None:
585 safe_key = prng.SafeKey(hk.next_rng_key())
586
587 output = {}
588
589 batch['msa_profile'] = make_msa_profile(batch)
590
591 target_feat = jax.nn.one_hot(batch['aatype'], 21)
592
593 preprocess_1d = common_modules.Linear(
594 c.msa_channel, name='preprocess_1d')(
595 target_feat)
596
597 safe_key, sample_key, mask_key = safe_key.split(3)
598 batch = sample_msa(sample_key, batch, c.num_msa)
599 batch = make_masked_msa(batch, mask_key, c.masked_msa)
600
601 (batch['cluster_profile'],
602 batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
603
604 msa_feat = create_msa_feat(batch)
605
606 preprocess_msa = common_modules.Linear(
607 c.msa_channel, name='preprocess_msa')(
608 msa_feat)
609
610 msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
611
612 left_single = common_modules.Linear(
613 c.pair_channel, name='left_single')(
614 target_feat)
615 right_single = common_modules.Linear(
616 c.pair_channel, name='right_single')(
617 target_feat)
618 pair_activations = left_single[:, None] + right_single[None]
619 mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
620 mask_2d = mask_2d.astype(jnp.float32)
621
622 if c.recycle_pos and 'prev_pos' in batch:
623 prev_pseudo_beta = modules.pseudo_beta_fn(
624 batch['aatype'], batch['prev_pos'], None)
625
626 dgram = modules.dgram_from_positions(
627 prev_pseudo_beta, **self.config.prev_pos)
628 pair_activations += common_modules.Linear(
629 c.pair_channel, name='prev_pos_linear')(
630 dgram)
631
632 if c.recycle_features:
633 if 'prev_msa_first_row' in batch:
634 prev_msa_first_row = hk.LayerNorm(
635 axis=[-1],
636 create_scale=True,
637 create_offset=True,
638 name='prev_msa_first_row_norm')(
639 batch['prev_msa_first_row'])
640 msa_activations = msa_activations.at[0].add(prev_msa_first_row)
641
642 if 'prev_pair' in batch:
643 pair_activations += hk.LayerNorm(
644 axis=[-1],
645 create_scale=True,
646 create_offset=True,
647 name='prev_pair_norm')(
648 batch['prev_pair'])
649
650 if c.max_relative_idx:
651 pair_activations += self._relative_encoding(batch)
652
653 if c.template.enabled:
654 template_module = TemplateEmbedding(c.template, gc)
655 template_batch = {
656 'template_aatype': batch['template_aatype'],
657 'template_all_atom_positions': batch['template_all_atom_positions'],
658 'template_all_atom_mask': batch['template_all_atom_mask']
659 }
660 # Construct a mask such that only intra-chain template features are
661 # computed, since templates are provided for each chain individually.
662 multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
663 safe_key, safe_subkey = safe_key.split()
664 template_act = template_module(
665 query_embedding=pair_activations,
666 template_batch=template_batch,
667 padding_mask_2d=mask_2d,
668 multichain_mask_2d=multichain_mask,
669 is_training=is_training,
670 safe_key=safe_subkey)
671 pair_activations += template_act
672
673 # Extra MSA stack.
674 (extra_msa_feat,
675 extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
676 extra_msa_activations = common_modules.Linear(
677 c.extra_msa_channel,
678 name='extra_msa_activations')(
679 extra_msa_feat)
680 extra_msa_mask = extra_msa_mask.astype(jnp.float32)
681
682 extra_evoformer_input = {
683 'msa': extra_msa_activations,
684 'pair': pair_activations,
685 }
686 extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
687
688 extra_evoformer_iteration = modules.EvoformerIteration(
689 c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
690
691 def extra_evoformer_fn(x):
692 act, safe_key = x
693 safe_key, safe_subkey = safe_key.split()
694 extra_evoformer_output = extra_evoformer_iteration(
695 activations=act,
696 masks=extra_masks,
697 is_training=is_training,
698 safe_key=safe_subkey)
699 return (extra_evoformer_output, safe_key)
700
701 if gc.use_remat:
702 extra_evoformer_fn = hk.remat(extra_evoformer_fn)
703
704 safe_key, safe_subkey = safe_key.split()
705 extra_evoformer_stack = layer_stack.layer_stack(
706 c.extra_msa_stack_num_block)(
707 extra_evoformer_fn)
708 extra_evoformer_output, safe_key = extra_evoformer_stack(
709 (extra_evoformer_input, safe_subkey))
710
711 pair_activations = extra_evoformer_output['pair']
712
713 # Get the size of the MSA before potentially adding templates, so we
714 # can crop out the templates later.
715 num_msa_sequences = msa_activations.shape[0]
716 evoformer_input = {
717 'msa': msa_activations,
718 'pair': pair_activations,
719 }
720 evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
721 'pair': mask_2d}
722
723 if c.template.enabled:
724 template_features, template_masks = (
725 template_embedding_1d(batch=batch, num_channel=c.msa_channel))
726
727 evoformer_input['msa'] = jnp.concatenate(
728 [evoformer_input['msa'], template_features], axis=0)
729 evoformer_masks['msa'] = jnp.concatenate(
730 [evoformer_masks['msa'], template_masks], axis=0)
731
732 evoformer_iteration = modules.EvoformerIteration(
733 c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
734
735 def evoformer_fn(x):
736 act, safe_key = x
737 safe_key, safe_subkey = safe_key.split()
738 evoformer_output = evoformer_iteration(
739 activations=act,
740 masks=evoformer_masks,
741 is_training=is_training,
742 safe_key=safe_subkey)
743 return (evoformer_output, safe_key)
744
745 if gc.use_remat:
746 evoformer_fn = hk.remat(evoformer_fn)
747
748 safe_key, safe_subkey = safe_key.split()
749 evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
750 evoformer_fn)
751
752 def run_evoformer(evoformer_input):
753 evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
754 return evoformer_output
755
756 evoformer_output = run_evoformer(evoformer_input)
757
758 msa_activations = evoformer_output['msa']
759 pair_activations = evoformer_output['pair']
760
761 single_activations = common_modules.Linear(
762 c.seq_channel, name='single_activations')(
763 msa_activations[0])
764
765 output.update({
766 'single':
767 single_activations,
768 'pair':
769 pair_activations,
770 # Crop away template rows such that they are not used in MaskedMsaHead.
771 'msa':
772 msa_activations[:num_msa_sequences, :, :],
773 'msa_first_row':
774 msa_activations[0],
775 })
776
777 return output
778
779
780 class TemplateEmbedding(hk.Module):
781 """Embed a set of templates."""
782
783 def __init__(self, config, global_config, name='template_embedding'):
784 super().__init__(name=name)
785 self.config = config
786 self.global_config = global_config
787
788 def __call__(self, query_embedding, template_batch, padding_mask_2d,
789 multichain_mask_2d, is_training,
790 safe_key=None):
791 """Generate an embedding for a set of templates.
792
793 Args:
794 query_embedding: [num_res, num_res, num_channel] a query tensor that will
795 be used to attend over the templates to remove the num_templates
796 dimension.
797 template_batch: A dictionary containing:
798 `template_aatype`: [num_templates, num_res] aatype for each template.
799 `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom
800 positions for all templates.
801 `template_all_atom_mask`: [num_templates, num_res, 37] mask for each
802 template.
803 padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
804 multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs
805 are intra-chain, used to mask out residue distance-based features
806 between chains.
807 is_training: bool indicating whether we are running in training mode.
808 safe_key: random key generator.
809
810 Returns:
811 An embedding of size [num_res, num_res, num_channels]
812 """
813 c = self.config
814 if safe_key is None:
815 safe_key = prng.SafeKey(hk.next_rng_key())
816
817 num_templates = template_batch['template_aatype'].shape[0]
818 num_res, _, query_num_channels = query_embedding.shape
819
820 # Embed each template separately.
821 template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
822 def partial_template_embedder(template_aatype,
823 template_all_atom_positions,
824 template_all_atom_mask,
825 unsafe_key):
826 safe_key = prng.SafeKey(unsafe_key)
827 return template_embedder(query_embedding,
828 template_aatype,
829 template_all_atom_positions,
830 template_all_atom_mask,
831 padding_mask_2d,
832 multichain_mask_2d,
833 is_training,
834 safe_key)
835
836 safe_key, unsafe_key = safe_key.split()
837 unsafe_keys = jax.random.split(unsafe_key._key, num_templates)
838
839 def scan_fn(carry, x):
840 return carry + partial_template_embedder(*x), None
841
842 scan_init = jnp.zeros((num_res, num_res, c.num_channels),
843 dtype=query_embedding.dtype)
844 summed_template_embeddings, _ = hk.scan(
845 scan_fn, scan_init,
846 (template_batch['template_aatype'],
847 template_batch['template_all_atom_positions'],
848 template_batch['template_all_atom_mask'], unsafe_keys))
849
850 embedding = summed_template_embeddings / num_templates
851 embedding = jax.nn.relu(embedding)
852 embedding = common_modules.Linear(
853 query_num_channels,
854 initializer='relu',
855 name='output_linear')(embedding)
856
857 return embedding
858
859
860 class SingleTemplateEmbedding(hk.Module):
861 """Embed a single template."""
862
863 def __init__(self, config, global_config, name='single_template_embedding'):
864 super().__init__(name=name)
865 self.config = config
866 self.global_config = global_config
867
868 def __call__(self, query_embedding, template_aatype,
869 template_all_atom_positions, template_all_atom_mask,
870 padding_mask_2d, multichain_mask_2d, is_training,
871 safe_key):
872 """Build the single template embedding graph.
873
874 Args:
875 query_embedding: (num_res, num_res, num_channels) - embedding of the
876 query sequence/msa.
877 template_aatype: [num_res] aatype for the template.
878 template_all_atom_positions: [num_res, 37, 3] atom positions for the
879 template.
880 template_all_atom_mask: [num_res, 37] atom mask for the template.
881 padding_mask_2d: Padding mask (Note: this doesn't care if a template
882 exists, unlike the template_pseudo_beta_mask).
883 multichain_mask_2d: A mask indicating intra-chain residue pairs, used
884 to mask out between-chain distances/features when templates are for
885 single chains.
886 is_training: Whether we are in training mode.
887 safe_key: Random key generator.
888
889 Returns:
890 A template embedding (num_res, num_res, num_channels).
891 """
892 gc = self.global_config
893 c = self.config
894 assert padding_mask_2d.dtype == query_embedding.dtype
895 dtype = query_embedding.dtype
896 num_channels = self.config.num_channels
897
898 def construct_input(query_embedding, template_aatype,
899 template_all_atom_positions, template_all_atom_mask,
900 multichain_mask_2d):
901
902 # Compute distogram feature for the template.
903 template_positions, pseudo_beta_mask = modules.pseudo_beta_fn(
904 template_aatype, template_all_atom_positions, template_all_atom_mask)
905 pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] *
906 pseudo_beta_mask[None, :])
907 pseudo_beta_mask_2d *= multichain_mask_2d
908 template_dgram = modules.dgram_from_positions(
909 template_positions, **self.config.dgram_features)
910 template_dgram *= pseudo_beta_mask_2d[..., None]
911 template_dgram = template_dgram.astype(dtype)
912 pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
913 to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)]
914
915 aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype)
916 to_concat.append((aatype[None, :, :], 1))
917 to_concat.append((aatype[:, None, :], 1))
918
919 # Compute a feature representing the normalized vector between each pair of
920 # backbone frames - i.e. in each residue's local frame, in what direction
921 # each of the other residues lies.
922 raw_atom_pos = template_all_atom_positions
923
924 atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
925 rigid, backbone_mask = folding_multimer.make_backbone_affine(
926 atom_pos,
927 template_all_atom_mask,
928 template_aatype)
929 points = rigid.translation
930 rigid_vec = rigid[:, None].inverse().apply_to_point(points)
931 unit_vector = rigid_vec.normalized()
932 unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
933
934 backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
935 backbone_mask_2d *= multichain_mask_2d
936 unit_vector = [x*backbone_mask_2d for x in unit_vector]
937
938 # Note that the backbone_mask takes into account C, CA and N (unlike
939 # pseudo beta mask which just needs CB) so we add both masks as features.
940 to_concat.extend([(x, 0) for x in unit_vector])
941 to_concat.append((backbone_mask_2d, 0))
942
943 query_embedding = hk.LayerNorm(
944 axis=[-1],
945 create_scale=True,
946 create_offset=True,
947 name='query_embedding_norm')(
948 query_embedding)
949 # Allow the template embedder to see the query embedding. Note this
950 # contains the relative position feature, so this is how the network knows
951 # which residues are next to each other.
952 to_concat.append((query_embedding, 1))
953
954 act = 0
955
956 for i, (x, n_input_dims) in enumerate(to_concat):
957
958 act += common_modules.Linear(
959 num_channels,
960 num_input_dims=n_input_dims,
961 initializer='relu',
962 name=f'template_pair_embedding_{i}')(x)
963 return act
964
965 act = construct_input(query_embedding, template_aatype,
966 template_all_atom_positions, template_all_atom_mask,
967 multichain_mask_2d)
968
969 template_iteration = TemplateEmbeddingIteration(
970 c.template_pair_stack, gc, name='template_embedding_iteration')
971
972 def template_iteration_fn(x):
973 act, safe_key = x
974
975 safe_key, safe_subkey = safe_key.split()
976 act = template_iteration(
977 act=act,
978 pair_mask=padding_mask_2d,
979 is_training=is_training,
980 safe_key=safe_subkey)
981 return (act, safe_key)
982
983 if gc.use_remat:
984 template_iteration_fn = hk.remat(template_iteration_fn)
985
986 safe_key, safe_subkey = safe_key.split()
987 template_stack = layer_stack.layer_stack(
988 c.template_pair_stack.num_block)(
989 template_iteration_fn)
990 act, safe_key = template_stack((act, safe_subkey))
991
992 act = hk.LayerNorm(
993 axis=[-1],
994 create_scale=True,
995 create_offset=True,
996 name='output_layer_norm')(
997 act)
998 return act
999
1000
1001 class TemplateEmbeddingIteration(hk.Module):
1002 """Single Iteration of Template Embedding."""
1003
1004 def __init__(self, config, global_config,
1005 name='template_embedding_iteration'):
1006 super().__init__(name=name)
1007 self.config = config
1008 self.global_config = global_config
1009
1010 def __call__(self, act, pair_mask, is_training=True,
1011 safe_key=None):
1012 """Build a single iteration of the template embedder.
1013
1014 Args:
1015 act: [num_res, num_res, num_channel] Input pairwise activations.
1016 pair_mask: [num_res, num_res] padding mask.
1017 is_training: Whether to run in training mode.
1018 safe_key: Safe pseudo-random generator key.
1019
1020 Returns:
1021 [num_res, num_res, num_channel] tensor of activations.
1022 """
1023 c = self.config
1024 gc = self.global_config
1025
1026 if safe_key is None:
1027 safe_key = prng.SafeKey(hk.next_rng_key())
1028
1029 dropout_wrapper_fn = functools.partial(
1030 modules.dropout_wrapper,
1031 is_training=is_training,
1032 global_config=gc)
1033
1034 safe_key, *sub_keys = safe_key.split(20)
1035 sub_keys = iter(sub_keys)
1036
1037 act = dropout_wrapper_fn(
1038 modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
1039 name='triangle_multiplication_outgoing'),
1040 act,
1041 pair_mask,
1042 safe_key=next(sub_keys))
1043
1044 act = dropout_wrapper_fn(
1045 modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
1046 name='triangle_multiplication_incoming'),
1047 act,
1048 pair_mask,
1049 safe_key=next(sub_keys))
1050
1051 act = dropout_wrapper_fn(
1052 modules.TriangleAttention(c.triangle_attention_starting_node, gc,
1053 name='triangle_attention_starting_node'),
1054 act,
1055 pair_mask,
1056 safe_key=next(sub_keys))
1057
1058 act = dropout_wrapper_fn(
1059 modules.TriangleAttention(c.triangle_attention_ending_node, gc,
1060 name='triangle_attention_ending_node'),
1061 act,
1062 pair_mask,
1063 safe_key=next(sub_keys))
1064
1065 act = dropout_wrapper_fn(
1066 modules.Transition(c.pair_transition, gc,
1067 name='pair_transition'),
1068 act,
1069 pair_mask,
1070 safe_key=next(sub_keys))
1071
1072 return act
1073
1074
1075 def template_embedding_1d(batch, num_channel):
1076 """Embed templates into an (num_res, num_templates, num_channels) embedding.
1077
1078 Args:
1079 batch: A batch containing:
1080 template_aatype, (num_templates, num_res) aatype for the templates.
1081 template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
1082 positions for the templates.
1083 template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
1084 each template.
1085 num_channel: The number of channels in the output.
1086
1087 Returns:
1088 An embedding of shape (num_templates, num_res, num_channels) and a mask of
1089 shape (num_templates, num_res).
1090 """
1091
1092 # Embed the templates' aatypes.
1093 aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
1094
1095 num_templates = batch['template_aatype'].shape[0]
1096 all_chi_angles = []
1097 all_chi_masks = []
1098 for i in range(num_templates):
1099 atom_pos = geometry.Vec3Array.from_array(
1100 batch['template_all_atom_positions'][i, :, :, :])
1101 template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles(
1102 atom_pos,
1103 batch['template_all_atom_mask'][i, :, :],
1104 batch['template_aatype'][i, :])
1105 all_chi_angles.append(template_chi_angles)
1106 all_chi_masks.append(template_chi_mask)
1107 chi_angles = jnp.stack(all_chi_angles, axis=0)
1108 chi_mask = jnp.stack(all_chi_masks, axis=0)
1109
1110 template_features = jnp.concatenate([
1111 aatype_one_hot,
1112 jnp.sin(chi_angles) * chi_mask,
1113 jnp.cos(chi_angles) * chi_mask,
1114 chi_mask], axis=-1)
1115
1116 template_mask = chi_mask[:, :, 0]
1117
1118 template_activations = common_modules.Linear(
1119 num_channel,
1120 initializer='relu',
1121 name='template_single_embedding')(
1122 template_features)
1123 template_activations = jax.nn.relu(template_activations)
1124 template_activations = common_modules.Linear(
1125 num_channel,
1126 initializer='relu',
1127 name='template_projection')(
1128 template_activations)
1129 return template_activations, template_mask
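The per-residue template feature built above concatenates the 22-way aatype one-hot with the sin and cos of the chi angles and the chi mask; assuming compute_chi_angles returns the 4 side-chain chi angles per residue, that is 22 + 4 + 4 + 4 = 34 channels. A zero-filled standalone shape check with toy sizes:

import jax.numpy as jnp

num_templates, num_res, num_chi = 2, 8, 4   # 4 side-chain chi angles per residue (assumed)
aatype_one_hot = jnp.zeros((num_templates, num_res, 22))
chi_angles = jnp.zeros((num_templates, num_res, num_chi))
chi_mask = jnp.zeros((num_templates, num_res, num_chi))

template_features = jnp.concatenate(
    [aatype_one_hot,
     jnp.sin(chi_angles) * chi_mask,
     jnp.cos(chi_angles) * chi_mask,
     chi_mask], axis=-1)
print(template_features.shape)   # (2, 8, 34): 22 + 4 + 4 + 4 channels
print(chi_mask[:, :, 0].shape)   # (2, 8): the template_mask returned above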