galaxy-australia/alphafold2: docker/alphafold/alphafold/model/modules_multimer.py @ 1:6c92e000d684 (draft)

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"

author:     galaxy-australia
date:       Tue, 01 Mar 2022 02:53:05 +0000
comparison: 0:7ae9d78b06f5 -> 1:6c92e000d684

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core modules, which have been refactored in AlphaFold-Multimer.

The main difference is that the MSA sampling pipeline is moved inside the JAX
model for easier implementation of recycling and ensembling.

Lower-level modules up to EvoformerIteration are reused from modules.py.
"""

import functools
from typing import Sequence

from alphafold.common import residue_constants
from alphafold.model import all_atom_multimer
from alphafold.model import common_modules
from alphafold.model import folding_multimer
from alphafold.model import geometry
from alphafold.model import layer_stack
from alphafold.model import modules
from alphafold.model import prng
from alphafold.model import utils

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


def reduce_fn(x, mode):
  if mode == 'none' or mode is None:
    return jnp.asarray(x)
  elif mode == 'sum':
    return jnp.asarray(x).sum()
  elif mode == 'mean':
    return jnp.mean(jnp.asarray(x))
  else:
    raise ValueError('Unsupported reduction option.')


def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
  """Generate Gumbel noise of a given shape.

  This generates samples from Gumbel(0, 1).

  Args:
    key: Jax random number key.
    shape: Shape of noise to return.

  Returns:
    Gumbel noise of given shape.
  """
  epsilon = 1e-6
  uniform = utils.padding_consistent_rng(jax.random.uniform)
  uniform_noise = uniform(
      key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
  gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
  return gumbel

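# Illustrative sketch (editor's note, not part of the upstream module): a
# Gumbel(0, 1) draw for a (4, 8) array. The epsilon terms above guard the two
# logs against 0; utils.padding_consistent_rng is, on the editor's reading,
# what keeps the draw unchanged when inputs are zero-padded.
#
#   key = jax.random.PRNGKey(0)
#   noise = gumbel_noise(key, (4, 8))  # float32, shape (4, 8)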

def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
  """Samples from a probability distribution given by 'logits'.

  This uses the Gumbel-max trick to implement the sampling in an efficient
  manner.

  Args:
    key: prng key.
    logits: Logarithm of probabilities to sample from, probabilities can be
      unnormalized.

  Returns:
    Sample from logprobs in one-hot form.
  """
  z = gumbel_noise(key, logits.shape)
  return jax.nn.one_hot(
      jnp.argmax(logits + z, axis=-1),
      logits.shape[-1],
      dtype=logits.dtype)

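# Illustrative sketch (editor's note): the Gumbel-max trick. Adding
# Gumbel(0, 1) noise to logits and taking the argmax yields an exact sample
# from softmax(logits):
#
#   key = jax.random.PRNGKey(0)
#   logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))
#   sample = gumbel_max_sample(key, logits)  # one-hot, e.g. [0., 1., 0.]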

def gumbel_argsort_sample_idx(key: jnp.ndarray,
                              logits: jnp.ndarray) -> jnp.ndarray:
  """Samples without replacement from a distribution given by 'logits'.

  This uses the Gumbel trick to implement the sampling in an efficient manner.
  For a distribution over k items this samples k times without replacement, so
  it effectively samples a random permutation, with probabilities over the
  permutations derived from the logprobs.

  Args:
    key: prng key.
    logits: Logarithm of probabilities to sample from, probabilities can be
      unnormalized.

  Returns:
    A permutation of indices (not one-hot), ordered from the highest perturbed
    logit to the lowest.
  """
  z = gumbel_noise(key, logits.shape)
  # This construction is equivalent to jnp.argsort, but uses a non-stable
  # sort, since stable sorts aren't supported by jax2tf.
  axis = len(logits.shape) - 1
  iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
  _, perm = jax.lax.sort_key_val(
      logits + z, iota, dimension=-1, is_stable=False)
  return perm[::-1]

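# Illustrative sketch (editor's note): with uniform logits this is simply a
# uniformly random permutation; rows whose logits sit near -1e6 (the masking
# convention used in sample_msa below) almost surely sort to the end.
#
#   key = jax.random.PRNGKey(0)
#   perm = gumbel_argsort_sample_idx(key, jnp.zeros(5))  # permutation of 0..4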

def make_masked_msa(batch, key, config, epsilon=1e-6):
  """Create data for BERT on raw MSA."""
  # Add a random amino acid uniformly.
  random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32)

  categorical_probs = (
      config.uniform_prob * random_aa +
      config.profile_prob * batch['msa_profile'] +
      config.same_prob * jax.nn.one_hot(batch['msa'], 22))

  # Put all remaining probability on [MASK] which is a new column.
  pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
  pad_shapes[-1][1] = 1
  mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
  assert mask_prob >= 0.
  categorical_probs = jnp.pad(
      categorical_probs, pad_shapes, constant_values=mask_prob)
  sh = batch['msa'].shape
  key, mask_subkey, gumbel_subkey = key.split(3)
  uniform = utils.padding_consistent_rng(jax.random.uniform)
  mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction
  mask_position *= batch['msa_mask']

  logits = jnp.log(categorical_probs + epsilon)
  bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits)
  bert_msa = jnp.where(mask_position,
                       jnp.argmax(bert_msa, axis=-1), batch['msa'])
  bert_msa *= batch['msa_mask']

  # Mix real and masked MSA.
  if 'bert_mask' in batch:
    batch['bert_mask'] *= mask_position.astype(jnp.float32)
  else:
    batch['bert_mask'] = mask_position.astype(jnp.float32)
  batch['true_msa'] = batch['msa']
  batch['msa'] = bert_msa

  return batch

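# Minimal sketch (editor's note) of the config fields this function reads; the
# values here are hypothetical, the real ones live in alphafold/model/config.py:
#
#   masked_msa_config = ml_collections.ConfigDict({
#       'uniform_prob': 0.1,       # replace with a uniformly random amino acid
#       'profile_prob': 0.1,       # replace with a sample from the MSA profile
#       'same_prob': 0.1,          # keep the original residue
#       'replace_fraction': 0.15,  # fraction of positions that get corrupted
#   })
#   # The remaining mass, 1 - 0.1 - 0.1 - 0.1 = 0.7, goes to the [MASK] column.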

def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
  """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""

  # Determine how much weight we assign to each agreement. In theory, we could
  # use a full blosum matrix here, but right now let's just down-weight gap
  # agreement because it could be spurious.
  # Never put weight on agreeing on BERT mask.

  weights = jnp.array(
      [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32)

  msa_mask = batch['msa_mask']
  msa_one_hot = jax.nn.one_hot(batch['msa'], 23)

  extra_mask = batch['extra_msa_mask']
  extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23)

  msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
  extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot

  agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked,
                         weights * msa_one_hot_masked)

  cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0)
  cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask)

  cluster_count = jnp.sum(cluster_assignment, axis=-1)
  cluster_count += 1.  # We always include the sequence itself.

  msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
  msa_sum += msa_one_hot_masked

  cluster_profile = msa_sum / cluster_count[:, None, None]

  extra_deletion_matrix = batch['extra_deletion_matrix']
  deletion_matrix = batch['deletion_matrix']

  del_sum = jnp.einsum('nm, mc->nc', cluster_assignment,
                       extra_mask * extra_deletion_matrix)
  del_sum += deletion_matrix  # Original sequence.
  cluster_deletion_mean = del_sum / cluster_count[:, None]

  return cluster_profile, cluster_deletion_mean

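# Shape sketch (editor's note) for the agreement einsum above, with num_msa
# cluster centres n, num_extra sequences m, num_res r and 23 classes c:
#
#   extra_one_hot_masked : [m, r, 23]
#   msa_one_hot_masked   : [n, r, 23]
#   agreement ('mrc,nrc->nm') : [n, m]  weighted per-position matches
#
# softmax(1e3 * agreement, axis=0) then assigns each extra sequence (column m)
# almost entirely to its single best-matching cluster centre (row n).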

def create_msa_feat(batch):
  """Create and concatenate MSA features."""
  msa_1hot = jax.nn.one_hot(batch['msa'], 23)
  deletion_matrix = batch['deletion_matrix']
  has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
  deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]

  deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) *
                         (2. / jnp.pi))[..., None]

  msa_feat = [
      msa_1hot,
      has_deletion,
      deletion_value,
      batch['cluster_profile'],
      deletion_mean_value
  ]

  return jnp.concatenate(msa_feat, axis=-1)

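# Width check (editor's note): the concatenated feature has
# 23 (msa_1hot) + 1 (has_deletion) + 1 (deletion_value)
# + 23 (cluster_profile) + 1 (deletion_mean_value) = 49 channels per position.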

def create_extra_msa_feature(batch, num_extra_msa):
  """Expand extra_msa into 1hot and concat with other extra msa features.

  We do this as late as possible as the one_hot extra msa can be very large.

  Args:
    batch: a dictionary with the following keys:
     * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
       centre. Note - This isn't one-hotted.
     * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
       position.
    num_extra_msa: Number of extra msa to use.

  Returns:
    Concatenated tensor of extra MSA features.
  """
  # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
  extra_msa = batch['extra_msa'][:num_extra_msa]
  deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa]
  msa_1hot = jax.nn.one_hot(extra_msa, 23)
  has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
  deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
  extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa]
  return jnp.concatenate([msa_1hot, has_deletion, deletion_value],
                         axis=-1), extra_msa_mask

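# Width check (editor's note): the extra-MSA feature is
# 23 (one-hot) + 1 (has_deletion) + 1 (deletion_value) = 25 channels, versus 49
# for the clustered MSA above, which keeps the much larger extra stack cheap.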

def sample_msa(key, batch, max_seq):
  """Sample MSA randomly; remaining sequences are stored as `extra_*`.

  Args:
    key: safe key for random number generation.
    batch: batch to sample msa from.
    max_seq: number of sequences to sample.

  Returns:
    Protein with sampled msa.
  """
  # Sample uniformly among sequences with at least one non-masked position.
  logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6
  # The cluster_bias_mask can be used to preserve the first row (target
  # sequence) for each chain, for example.
  if 'cluster_bias_mask' not in batch:
    cluster_bias_mask = jnp.pad(
        jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.)
  else:
    cluster_bias_mask = batch['cluster_bias_mask']

  logits += cluster_bias_mask * 1e6
  index_order = gumbel_argsort_sample_idx(key.get(), logits)
  sel_idx = index_order[:max_seq]
  extra_idx = index_order[max_seq:]

  for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
    if k in batch:
      batch['extra_' + k] = batch[k][extra_idx]
      batch[k] = batch[k][sel_idx]

  return batch

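# Worked sketch (editor's note): for five rows where row 3 is fully masked and
# the default cluster_bias_mask [1, 0, 0, 0, 0] applies, the logits are roughly
#
#   [1e6, 0, 0, -1e6, 0]
#
# so the Gumbel argsort almost surely ranks row 0 first (it always becomes a
# cluster centre) and the masked row 3 last (it lands in `extra_*`).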

def make_msa_profile(batch):
  """Compute the MSA profile."""

  # Compute the profile for every residue (over all MSA sequences).
  return utils.mask_mean(
      batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0)

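# Formula sketch (editor's note): mask_mean is a masked average over the
# sequence axis s, so per residue r and class c (20 amino acids + X + gap):
#
#   profile[r, c] = sum_s mask[s, r] * one_hot(msa)[s, r, c]
#                   / (sum_s mask[s, r] + small eps)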

class AlphaFoldIteration(hk.Module):
  """A single recycling iteration of AlphaFold architecture.

  Computes ensembled (averaged) representations from the provided features.
  These representations are then passed to the various heads
  that have been requested by the configuration file.
  """

  def __init__(self, config, global_config, name='alphafold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               batch,
               is_training,
               return_representations=False,
               safe_key=None):

    if is_training:
      num_ensemble = np.asarray(self.config.num_ensemble_train)
    else:
      num_ensemble = np.asarray(self.config.num_ensemble_eval)

    # Compute representations for each MSA sample and average.
    embedding_module = EmbeddingsAndEvoformer(
        self.config.embeddings_and_evoformer, self.global_config)
    repr_shape = hk.eval_shape(
        lambda: embedding_module(batch, is_training))
    representations = {
        k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items()
    }

    def ensemble_body(x, unused_y):
      """Add into representations ensemble."""
      del unused_y
      representations, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      representations_update = embedding_module(
          batch, is_training, safe_key=safe_subkey)

      for k in representations:
        if k not in {'msa', 'true_msa', 'bert_mask'}:
          representations[k] += representations_update[k] * (
              1. / num_ensemble).astype(representations[k].dtype)
        else:
          representations[k] = representations_update[k]

      return (representations, safe_key), None

    (representations, _), _ = hk.scan(
        ensemble_body, (representations, safe_key), None, length=num_ensemble)

    self.representations = representations
    self.batch = batch
    self.heads = {}
    for head_name, head_config in sorted(self.config.heads.items()):
      if not head_config.weight:
        continue  # Do not instantiate zero-weight heads.

      head_factory = {
          'masked_msa': modules.MaskedMsaHead,
          'distogram': modules.DistogramHead,
          'structure_module': folding_multimer.StructureModule,
          'predicted_aligned_error': modules.PredictedAlignedErrorHead,
          'predicted_lddt': modules.PredictedLDDTHead,
          'experimentally_resolved': modules.ExperimentallyResolvedHead,
      }[head_name]
      self.heads[head_name] = (head_config,
                               head_factory(head_config, self.global_config))

    structure_module_output = None
    if 'entity_id' in batch and 'all_atom_positions' in batch:
      _, fold_module = self.heads['structure_module']
      structure_module_output = fold_module(representations, batch,
                                            is_training)

    ret = {}
    ret['representations'] = representations

    for name, (head_config, module) in self.heads.items():
      if name == 'structure_module' and structure_module_output is not None:
        ret[name] = structure_module_output
        representations['structure_module'] = structure_module_output.pop('act')
      # Skip confidence heads until StructureModule is executed.
      elif name in {'predicted_lddt', 'predicted_aligned_error',
                    'experimentally_resolved'}:
        continue
      else:
        ret[name] = module(representations, batch, is_training)

    # Add confidence heads after StructureModule is executed.
    if self.config.heads.get('predicted_lddt.weight', 0.0):
      name = 'predicted_lddt'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)

    if self.config.heads.experimentally_resolved.weight:
      name = 'experimentally_resolved'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)

    if self.config.heads.get('predicted_aligned_error.weight', 0.0):
      name = 'predicted_aligned_error'
      head_config, module = self.heads[name]
      ret[name] = module(representations, batch, is_training)
      # Will be used for ipTM computation.
      ret[name]['asym_id'] = batch['asym_id']

    return ret

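# Note (editor's aside): hk.scan above averages `num_ensemble` passes of
# EmbeddingsAndEvoformer into one set of representations; with an ensemble
# size of 1 (the common multimer inference setting) the body runs once and
# the "average" is just that single pass.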

class AlphaFold(hk.Module):
  """AlphaFold-Multimer model with recycling."""

  def __init__(self, config, name='alphafold'):
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config

  def __call__(
      self,
      batch,
      is_training,
      return_representations=False,
      safe_key=None):

    c = self.config
    impl = AlphaFoldIteration(c, self.global_config)

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())
    elif isinstance(safe_key, jnp.ndarray):
      safe_key = prng.SafeKey(safe_key)

    assert isinstance(batch, dict)
    num_res = batch['aatype'].shape[0]

    def get_prev(ret):
      new_prev = {
          'prev_pos': ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    def apply_network(prev, safe_key):
      recycled_batch = {**batch, **prev}
      return impl(
          batch=recycled_batch,
          is_training=is_training,
          safe_key=safe_key)

    if self.config.num_recycle:
      emb_config = self.config.embeddings_and_evoformer
      prev = {
          'prev_pos':
              jnp.zeros([num_res, residue_constants.atom_type_num, 3]),
          'prev_msa_first_row':
              jnp.zeros([num_res, emb_config.msa_channel]),
          'prev_pair':
              jnp.zeros([num_res, num_res, emb_config.pair_channel]),
      }

      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # Value for each ensemble batch is the same, so arbitrarily taking
        # the 0-th.
        num_iter = batch['num_iter_recycling'][0]

        # Add insurance that even when ensembling, we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, c.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = c.num_recycle

      def recycle_body(i, x):
        del i
        prev, safe_key = x
        safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate()  # pylint: disable=line-too-long
        ret = apply_network(prev=prev, safe_key=safe_key2)
        return get_prev(ret), safe_key1

      prev, safe_key = hk.fori_loop(0, num_iter, recycle_body,
                                    (prev, safe_key))
    else:
      prev = {}

    # Run extra iteration.
    ret = apply_network(prev=prev, safe_key=safe_key)

    if not return_representations:
      del ret['representations']
    return ret

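# Usage sketch (editor's note, hypothetical wiring; the real entry point is
# RunModel in alphafold/model/model.py):
#
#   def _forward(batch, is_training):
#     return AlphaFold(model_config)(batch, is_training=is_training)
#
#   forward = hk.transform(_forward)
#   ret = forward.apply(params, jax.random.PRNGKey(0), batch, False)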

class EmbeddingsAndEvoformer(hk.Module):
  """Embeds the input data and runs Evoformer.

  Produces the MSA, single and pair representations.
  """

  def __init__(self, config, global_config, name='evoformer'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def _relative_encoding(self, batch):
    """Add relative position encodings.

    For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted.

    When not using 'use_chain_relative' the residue indices are used as is,
    e.g. for heteromers relative positions will be computed using the
    positions in the corresponding chains.

    When using 'use_chain_relative' we add an extra bin that denotes
    'different chain'. Furthermore we also provide the relative chain index
    (i.e. sym_id) clipped and one-hotted to the network, plus an extra feature
    which denotes whether the two residues belong to the same chain type,
    i.e. it is 0 if they are in different heteromer chains and 1 otherwise.

    Args:
      batch: batch.

    Returns:
      Feature embedding using the features as described before.
    """
    c = self.config
    rel_feats = []
    pos = batch['residue_index']
    asym_id = batch['asym_id']
    asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
    offset = pos[:, None] - pos[None, :]

    clipped_offset = jnp.clip(
        offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)

    if c.use_chain_relative:

      final_offset = jnp.where(asym_id_same, clipped_offset,
                               (2 * c.max_relative_idx + 1) *
                               jnp.ones_like(clipped_offset))

      rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2)

      rel_feats.append(rel_pos)

      entity_id = batch['entity_id']
      entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :])
      rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None])

      sym_id = batch['sym_id']
      rel_sym_id = sym_id[:, None] - sym_id[None, :]

      max_rel_chain = c.max_relative_chain

      clipped_rel_chain = jnp.clip(
          rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain)

      final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
                                  (2 * max_rel_chain + 1) *
                                  jnp.ones_like(clipped_rel_chain))
      rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2)

      rel_feats.append(rel_chain)

    else:
      rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1)
      rel_feats.append(rel_pos)

    rel_feat = jnp.concatenate(rel_feats, axis=-1)

    return common_modules.Linear(
        c.pair_channel,
        name='position_activations')(
            rel_feat)

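  # Worked example (editor's note) for the clipped offset with
  # max_relative_idx k = 2: i - j is shifted by k and clipped to [0, 2k], so
  #
  #   offset i - j : -5 -2 -1  0  1  2  7
  #   bin          :  0  0  1  2  3  4  4
  #
  # and with use_chain_relative, any inter-chain pair gets the extra
  # 'different chain' bin 2k + 1 = 5 instead.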
  def __call__(self, batch, is_training, safe_key=None):

    c = self.config
    gc = self.global_config

    batch = dict(batch)

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = {}

    batch['msa_profile'] = make_msa_profile(batch)

    target_feat = jax.nn.one_hot(batch['aatype'], 21)

    preprocess_1d = common_modules.Linear(
        c.msa_channel, name='preprocess_1d')(
            target_feat)

    safe_key, sample_key, mask_key = safe_key.split(3)
    batch = sample_msa(sample_key, batch, c.num_msa)
    batch = make_masked_msa(batch, mask_key, c.masked_msa)

    (batch['cluster_profile'],
     batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)

    msa_feat = create_msa_feat(batch)

    preprocess_msa = common_modules.Linear(
        c.msa_channel, name='preprocess_msa')(
            msa_feat)

    msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa

    left_single = common_modules.Linear(
        c.pair_channel, name='left_single')(
            target_feat)
    right_single = common_modules.Linear(
        c.pair_channel, name='right_single')(
            target_feat)
    pair_activations = left_single[:, None] + right_single[None]
    mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
    mask_2d = mask_2d.astype(jnp.float32)

    if c.recycle_pos and 'prev_pos' in batch:
      prev_pseudo_beta = modules.pseudo_beta_fn(
          batch['aatype'], batch['prev_pos'], None)

      dgram = modules.dgram_from_positions(
          prev_pseudo_beta, **self.config.prev_pos)
      pair_activations += common_modules.Linear(
          c.pair_channel, name='prev_pos_linear')(
              dgram)

    if c.recycle_features:
      if 'prev_msa_first_row' in batch:
        prev_msa_first_row = hk.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='prev_msa_first_row_norm')(
                batch['prev_msa_first_row'])
        msa_activations = msa_activations.at[0].add(prev_msa_first_row)

      if 'prev_pair' in batch:
        pair_activations += hk.LayerNorm(
            axis=[-1],
            create_scale=True,
            create_offset=True,
            name='prev_pair_norm')(
                batch['prev_pair'])

    if c.max_relative_idx:
      pair_activations += self._relative_encoding(batch)

    if c.template.enabled:
      template_module = TemplateEmbedding(c.template, gc)
      template_batch = {
          'template_aatype': batch['template_aatype'],
          'template_all_atom_positions': batch['template_all_atom_positions'],
          'template_all_atom_mask': batch['template_all_atom_mask']
      }
      # Construct a mask such that only intra-chain template features are
      # computed, since all templates are for each chain individually.
      multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
      safe_key, safe_subkey = safe_key.split()
      template_act = template_module(
          query_embedding=pair_activations,
          template_batch=template_batch,
          padding_mask_2d=mask_2d,
          multichain_mask_2d=multichain_mask,
          is_training=is_training,
          safe_key=safe_subkey)
      pair_activations += template_act

    # Extra MSA stack.
    (extra_msa_feat,
     extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
    extra_msa_activations = common_modules.Linear(
        c.extra_msa_channel,
        name='extra_msa_activations')(
            extra_msa_feat)
    extra_msa_mask = extra_msa_mask.astype(jnp.float32)

    extra_evoformer_input = {
        'msa': extra_msa_activations,
        'pair': pair_activations,
    }
    extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}

    extra_evoformer_iteration = modules.EvoformerIteration(
        c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')

    def extra_evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      extra_evoformer_output = extra_evoformer_iteration(
          activations=act,
          masks=extra_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (extra_evoformer_output, safe_key)

    if gc.use_remat:
      extra_evoformer_fn = hk.remat(extra_evoformer_fn)

    safe_key, safe_subkey = safe_key.split()
    extra_evoformer_stack = layer_stack.layer_stack(
        c.extra_msa_stack_num_block)(
            extra_evoformer_fn)
    extra_evoformer_output, safe_key = extra_evoformer_stack(
        (extra_evoformer_input, safe_subkey))

    pair_activations = extra_evoformer_output['pair']

    # Get the size of the MSA before potentially adding templates, so we
    # can crop out the templates later.
    num_msa_sequences = msa_activations.shape[0]
    evoformer_input = {
        'msa': msa_activations,
        'pair': pair_activations,
    }
    evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
                       'pair': mask_2d}

    if c.template.enabled:
      template_features, template_masks = (
          template_embedding_1d(batch=batch, num_channel=c.msa_channel))

      evoformer_input['msa'] = jnp.concatenate(
          [evoformer_input['msa'], template_features], axis=0)
      evoformer_masks['msa'] = jnp.concatenate(
          [evoformer_masks['msa'], template_masks], axis=0)

    evoformer_iteration = modules.EvoformerIteration(
        c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')

    def evoformer_fn(x):
      act, safe_key = x
      safe_key, safe_subkey = safe_key.split()
      evoformer_output = evoformer_iteration(
          activations=act,
          masks=evoformer_masks,
          is_training=is_training,
          safe_key=safe_subkey)
      return (evoformer_output, safe_key)

    if gc.use_remat:
      evoformer_fn = hk.remat(evoformer_fn)

    safe_key, safe_subkey = safe_key.split()
    evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
        evoformer_fn)

    def run_evoformer(evoformer_input):
      evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
      return evoformer_output

    evoformer_output = run_evoformer(evoformer_input)

    msa_activations = evoformer_output['msa']
    pair_activations = evoformer_output['pair']

    single_activations = common_modules.Linear(
        c.seq_channel, name='single_activations')(
            msa_activations[0])

    output.update({
        'single': single_activations,
        'pair': pair_activations,
        # Crop away template rows such that they are not used in MaskedMsaHead.
        'msa': msa_activations[:num_msa_sequences, :, :],
        'msa_first_row': msa_activations[0],
    })

    return output


class TemplateEmbedding(hk.Module):
  """Embed a set of templates."""

  def __init__(self, config, global_config, name='template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_batch, padding_mask_2d,
               multichain_mask_2d, is_training,
               safe_key=None):
    """Generate an embedding for a set of templates.

    Args:
      query_embedding: [num_res, num_res, num_channel] a query tensor that
        will be used to attend over the templates to remove the num_templates
        dimension.
      template_batch: A dictionary containing:
        `template_aatype`: [num_templates, num_res] aatype for each template.
        `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom
          positions for all templates.
        `template_all_atom_mask`: [num_templates, num_res, 37] mask for each
          template.
      padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
      multichain_mask_2d: [num_res, num_res] Mask indicating which residue
        pairs are intra-chain, used to mask out residue distance based
        features between chains.
      is_training: bool indicating whether we are running in training mode.
      safe_key: random key generator.

    Returns:
      An embedding of size [num_res, num_res, num_channels].
    """
    c = self.config
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    num_templates = template_batch['template_aatype'].shape[0]
    num_res, _, query_num_channels = query_embedding.shape

    # Embed each template separately.
    template_embedder = SingleTemplateEmbedding(self.config,
                                                self.global_config)

    def partial_template_embedder(template_aatype,
                                  template_all_atom_positions,
                                  template_all_atom_mask,
                                  unsafe_key):
      safe_key = prng.SafeKey(unsafe_key)
      return template_embedder(query_embedding,
                               template_aatype,
                               template_all_atom_positions,
                               template_all_atom_mask,
                               padding_mask_2d,
                               multichain_mask_2d,
                               is_training,
                               safe_key)

    safe_key, unsafe_key = safe_key.split()
    unsafe_keys = jax.random.split(unsafe_key._key, num_templates)

    def scan_fn(carry, x):
      return carry + partial_template_embedder(*x), None

    scan_init = jnp.zeros((num_res, num_res, c.num_channels),
                          dtype=query_embedding.dtype)
    summed_template_embeddings, _ = hk.scan(
        scan_fn, scan_init,
        (template_batch['template_aatype'],
         template_batch['template_all_atom_positions'],
         template_batch['template_all_atom_mask'], unsafe_keys))

    embedding = summed_template_embeddings / num_templates
    embedding = jax.nn.relu(embedding)
    embedding = common_modules.Linear(
        query_num_channels,
        initializer='relu',
        name='output_linear')(embedding)

    return embedding

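# Note (editor's aside): hk.scan above accumulates one embedding per template
# into a running sum, and the mean is taken afterwards, so peak memory does
# not grow with the number of templates (at the cost of evaluating them
# sequentially rather than in one batched pass).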

class SingleTemplateEmbedding(hk.Module):
  """Embed a single template."""

  def __init__(self, config, global_config, name='single_template_embedding'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, query_embedding, template_aatype,
               template_all_atom_positions, template_all_atom_mask,
               padding_mask_2d, multichain_mask_2d, is_training,
               safe_key):
    """Build the single template embedding graph.

    Args:
      query_embedding: (num_res, num_res, num_channels) - embedding of the
        query sequence/msa.
      template_aatype: [num_res] aatype for each template.
      template_all_atom_positions: [num_res, 37, 3] atom positions for all
        templates.
      template_all_atom_mask: [num_res, 37] mask for each template.
      padding_mask_2d: Padding mask (Note: this doesn't care if a template
        exists, unlike the template_pseudo_beta_mask).
      multichain_mask_2d: A mask indicating intra-chain residue pairs, used
        to mask out between chain distances/features when templates are for
        single chains.
      is_training: Whether we are in training mode.
      safe_key: Random key generator.

    Returns:
      A template embedding (num_res, num_res, num_channels).
    """
    gc = self.global_config
    c = self.config
    assert padding_mask_2d.dtype == query_embedding.dtype
    dtype = query_embedding.dtype
    num_channels = self.config.num_channels

    def construct_input(query_embedding, template_aatype,
                        template_all_atom_positions, template_all_atom_mask,
                        multichain_mask_2d):

      # Compute distogram feature for the template.
      template_positions, pseudo_beta_mask = modules.pseudo_beta_fn(
          template_aatype, template_all_atom_positions,
          template_all_atom_mask)
      pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] *
                             pseudo_beta_mask[None, :])
      pseudo_beta_mask_2d *= multichain_mask_2d
      template_dgram = modules.dgram_from_positions(
          template_positions, **self.config.dgram_features)
      template_dgram *= pseudo_beta_mask_2d[..., None]
      template_dgram = template_dgram.astype(dtype)
      pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
      to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)]

      aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype)
      to_concat.append((aatype[None, :, :], 1))
      to_concat.append((aatype[:, None, :], 1))

      # Compute a feature representing the normalized vector between each
      # backbone affine - i.e. in each residue's local frame, in what
      # direction each of the other residues lies.
      raw_atom_pos = template_all_atom_positions

      atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
      rigid, backbone_mask = folding_multimer.make_backbone_affine(
          atom_pos,
          template_all_atom_mask,
          template_aatype)
      points = rigid.translation
      rigid_vec = rigid[:, None].inverse().apply_to_point(points)
      unit_vector = rigid_vec.normalized()
      unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]

      backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
      backbone_mask_2d *= multichain_mask_2d
      unit_vector = [x * backbone_mask_2d for x in unit_vector]

      # Note that the backbone_mask takes into account C, CA and N (unlike
      # the pseudo beta mask which just needs CB) so we add both masks as
      # features.
      to_concat.extend([(x, 0) for x in unit_vector])
      to_concat.append((backbone_mask_2d, 0))

      query_embedding = hk.LayerNorm(
          axis=[-1],
          create_scale=True,
          create_offset=True,
          name='query_embedding_norm')(
              query_embedding)
      # Allow the template embedder to see the query embedding. Note this
      # contains the position relative feature, so this is how the network
      # knows which residues are next to each other.
      to_concat.append((query_embedding, 1))

      act = 0

      for i, (x, n_input_dims) in enumerate(to_concat):
        act += common_modules.Linear(
            num_channels,
            num_input_dims=n_input_dims,
            initializer='relu',
            name=f'template_pair_embedding_{i}')(x)
      return act

    act = construct_input(query_embedding, template_aatype,
                          template_all_atom_positions, template_all_atom_mask,
                          multichain_mask_2d)

    template_iteration = TemplateEmbeddingIteration(
        c.template_pair_stack, gc, name='template_embedding_iteration')

    def template_iteration_fn(x):
      act, safe_key = x

      safe_key, safe_subkey = safe_key.split()
      act = template_iteration(
          act=act,
          pair_mask=padding_mask_2d,
          is_training=is_training,
          safe_key=safe_subkey)
      return (act, safe_key)

    if gc.use_remat:
      template_iteration_fn = hk.remat(template_iteration_fn)

    safe_key, safe_subkey = safe_key.split()
    template_stack = layer_stack.layer_stack(
        c.template_pair_stack.num_block)(
            template_iteration_fn)
    act, safe_key = template_stack((act, safe_subkey))

    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='output_layer_norm')(
            act)
    return act

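# Shape note (editor's reading of common_modules.Linear): in construct_input,
# the integer paired with each feature is passed as num_input_dims, i.e. 1 for
# [num_res, num_res, c] tensors and 0 for bare [num_res, num_res] masks, which
# are then treated as a single scalar input feature per pair.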

class TemplateEmbeddingIteration(hk.Module):
  """Single Iteration of Template Embedding."""

  def __init__(self, config, global_config,
               name='template_embedding_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, act, pair_mask, is_training=True,
               safe_key=None):
    """Build a single iteration of the template embedder.

    Args:
      act: [num_res, num_res, num_channel] Input pairwise activations.
      pair_mask: [num_res, num_res] padding mask.
      is_training: Whether to run in training mode.
      safe_key: Safe pseudo-random generator key.

    Returns:
      [num_res, num_res, num_channel] tensor of activations.
    """
    c = self.config
    gc = self.global_config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    dropout_wrapper_fn = functools.partial(
        modules.dropout_wrapper,
        is_training=is_training,
        global_config=gc)

    safe_key, *sub_keys = safe_key.split(20)
    sub_keys = iter(sub_keys)

    act = dropout_wrapper_fn(
        modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                       name='triangle_multiplication_outgoing'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                       name='triangle_multiplication_incoming'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_starting_node, gc,
                                  name='triangle_attention_starting_node'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.TriangleAttention(c.triangle_attention_ending_node, gc,
                                  name='triangle_attention_ending_node'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    act = dropout_wrapper_fn(
        modules.Transition(c.pair_transition, gc,
                           name='pair_transition'),
        act,
        pair_mask,
        safe_key=next(sub_keys))

    return act


def template_embedding_1d(batch, num_channel):
  """Embed templates into a (num_templates, num_res, num_channels) embedding.

  Args:
    batch: A batch containing:
      template_aatype, (num_templates, num_res) aatype for the templates.
      template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
        positions for the templates.
      template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
        each template.
    num_channel: The number of channels in the output.

  Returns:
    An embedding of shape (num_templates, num_res, num_channels) and a mask of
    shape (num_templates, num_res).
  """

  # Embed the templates aatypes.
  aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)

  num_templates = batch['template_aatype'].shape[0]
  all_chi_angles = []
  all_chi_masks = []
  for i in range(num_templates):
    atom_pos = geometry.Vec3Array.from_array(
        batch['template_all_atom_positions'][i, :, :, :])
    template_chi_angles, template_chi_mask = (
        all_atom_multimer.compute_chi_angles(
            atom_pos,
            batch['template_all_atom_mask'][i, :, :],
            batch['template_aatype'][i, :]))
    all_chi_angles.append(template_chi_angles)
    all_chi_masks.append(template_chi_mask)
  chi_angles = jnp.stack(all_chi_angles, axis=0)
  chi_mask = jnp.stack(all_chi_masks, axis=0)

  template_features = jnp.concatenate([
      aatype_one_hot,
      jnp.sin(chi_angles) * chi_mask,
      jnp.cos(chi_angles) * chi_mask,
      chi_mask], axis=-1)

  template_mask = chi_mask[:, :, 0]

  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
      name='template_single_embedding')(
          template_features)
  template_activations = jax.nn.relu(template_activations)
  template_activations = common_modules.Linear(
      num_channel,
      initializer='relu',
      name='template_projection')(
          template_activations)
  return template_activations, template_mask
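

# Width check (editor's note): per residue the 1-D template feature is
# 22 (aatype one-hot) + 4 sin(chi) + 4 cos(chi) + 4 chi mask = 34 channels,
# assuming the usual four chi angles returned by compute_chi_angles.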