annotate docker/alphafold/alphafold/model/modules_multimer.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
1
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1 # Copyright 2021 DeepMind Technologies Limited
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
2 #
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
3 # Licensed under the Apache License, Version 2.0 (the "License");
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
4 # you may not use this file except in compliance with the License.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
5 # You may obtain a copy of the License at
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
6 #
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
7 # http://www.apache.org/licenses/LICENSE-2.0
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
8 #
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
9 # Unless required by applicable law or agreed to in writing, software
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
10 # distributed under the License is distributed on an "AS IS" BASIS,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
12 # See the License for the specific language governing permissions and
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
13 # limitations under the License.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
14
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
15 """Core modules, which have been refactored in AlphaFold-Multimer.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
16
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
17 The main difference is that MSA sampling pipeline is moved inside the JAX model
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
18 for easier implementation of recycling and ensembling.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
19
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
20 Lower-level modules up to EvoformerIteration are reused from modules.py.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
21 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
22
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
23 import functools
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
24 from typing import Sequence
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
25
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
26 from alphafold.common import residue_constants
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
27 from alphafold.model import all_atom_multimer
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
28 from alphafold.model import common_modules
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
29 from alphafold.model import folding_multimer
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
30 from alphafold.model import geometry
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
31 from alphafold.model import layer_stack
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
32 from alphafold.model import modules
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
33 from alphafold.model import prng
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
34 from alphafold.model import utils
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
35
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
36 import haiku as hk
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
37 import jax
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
38 import jax.numpy as jnp
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
39 import numpy as np
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
40
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
41
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
42 def reduce_fn(x, mode):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
43 if mode == 'none' or mode is None:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
44 return jnp.asarray(x)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
45 elif mode == 'sum':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
46 return jnp.asarray(x).sum()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
47 elif mode == 'mean':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
48 return jnp.mean(jnp.asarray(x))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
49 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
50 raise ValueError('Unsupported reduction option.')
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
51
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
52
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
53 def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
54 """Generate Gumbel Noise of given Shape.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
55
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
56 This generates samples from Gumbel(0, 1).
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
57
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
58 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
59 key: Jax random number key.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
60 shape: Shape of noise to return.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
61
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
62 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
63 Gumbel noise of given shape.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
64 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
65 epsilon = 1e-6
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
66 uniform = utils.padding_consistent_rng(jax.random.uniform)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
67 uniform_noise = uniform(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
68 key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
69 gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
70 return gumbel
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
71
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
72
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
73 def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
74 """Samples from a probability distribution given by 'logits'.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
75
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
76 This uses Gumbel-max trick to implement the sampling in an efficient manner.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
77
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
78 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
79 key: prng key.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
80 logits: Logarithm of probabilities to sample from, probabilities can be
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
81 unnormalized.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
82
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
83 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
84 Sample from logprobs in one-hot form.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
85 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
86 z = gumbel_noise(key, logits.shape)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
87 return jax.nn.one_hot(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
88 jnp.argmax(logits + z, axis=-1),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
89 logits.shape[-1],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
90 dtype=logits.dtype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
91
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
92
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
93 def gumbel_argsort_sample_idx(key: jnp.ndarray,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
94 logits: jnp.ndarray) -> jnp.ndarray:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
95 """Samples with replacement from a distribution given by 'logits'.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
96
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
97 This uses Gumbel trick to implement the sampling an efficient manner. For a
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
98 distribution over k items this samples k times without replacement, so this
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
99 is effectively sampling a random permutation with probabilities over the
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
100 permutations derived from the logprobs.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
101
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
102 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
103 key: prng key.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
104 logits: Logarithm of probabilities to sample from, probabilities can be
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
105 unnormalized.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
106
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
107 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
108 Sample from logprobs in one-hot form.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
109 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
110 z = gumbel_noise(key, logits.shape)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
111 # This construction is equivalent to jnp.argsort, but using a non stable sort,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
112 # since stable sort's aren't supported by jax2tf.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
113 axis = len(logits.shape) - 1
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
114 iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
115 _, perm = jax.lax.sort_key_val(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
116 logits + z, iota, dimension=-1, is_stable=False)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
117 return perm[::-1]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
118
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
119
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
120 def make_masked_msa(batch, key, config, epsilon=1e-6):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
121 """Create data for BERT on raw MSA."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
122 # Add a random amino acid uniformly.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
123 random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
124
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
125 categorical_probs = (
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
126 config.uniform_prob * random_aa +
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
127 config.profile_prob * batch['msa_profile'] +
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
128 config.same_prob * jax.nn.one_hot(batch['msa'], 22))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
129
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
130 # Put all remaining probability on [MASK] which is a new column.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
131 pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
132 pad_shapes[-1][1] = 1
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
133 mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
134 assert mask_prob >= 0.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
135 categorical_probs = jnp.pad(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
136 categorical_probs, pad_shapes, constant_values=mask_prob)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
137 sh = batch['msa'].shape
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
138 key, mask_subkey, gumbel_subkey = key.split(3)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
139 uniform = utils.padding_consistent_rng(jax.random.uniform)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
140 mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
141 mask_position *= batch['msa_mask']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
142
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
143 logits = jnp.log(categorical_probs + epsilon)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
144 bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
145 bert_msa = jnp.where(mask_position,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
146 jnp.argmax(bert_msa, axis=-1), batch['msa'])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
147 bert_msa *= batch['msa_mask']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
148
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
149 # Mix real and masked MSA.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
150 if 'bert_mask' in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
151 batch['bert_mask'] *= mask_position.astype(jnp.float32)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
152 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
153 batch['bert_mask'] = mask_position.astype(jnp.float32)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
154 batch['true_msa'] = batch['msa']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
155 batch['msa'] = bert_msa
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
156
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
157 return batch
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
158
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
159
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
160 def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
161 """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
162
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
163 # Determine how much weight we assign to each agreement. In theory, we could
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
164 # use a full blosum matrix here, but right now let's just down-weight gap
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
165 # agreement because it could be spurious.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
166 # Never put weight on agreeing on BERT mask.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
167
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
168 weights = jnp.array(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
169 [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
170
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
171 msa_mask = batch['msa_mask']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
172 msa_one_hot = jax.nn.one_hot(batch['msa'], 23)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
173
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
174 extra_mask = batch['extra_msa_mask']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
175 extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
176
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
177 msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
178 extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
179
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
180 agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
181 weights * msa_one_hot_masked)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
182
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
183 cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
184 cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
185
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
186 cluster_count = jnp.sum(cluster_assignment, axis=-1)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
187 cluster_count += 1. # We always include the sequence itself.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
188
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
189 msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
190 msa_sum += msa_one_hot_masked
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
191
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
192 cluster_profile = msa_sum / cluster_count[:, None, None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
193
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
194 extra_deletion_matrix = batch['extra_deletion_matrix']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
195 deletion_matrix = batch['deletion_matrix']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
196
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
197 del_sum = jnp.einsum('nm, mc->nc', cluster_assignment,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
198 extra_mask * extra_deletion_matrix)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
199 del_sum += deletion_matrix # Original sequence.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
200 cluster_deletion_mean = del_sum / cluster_count[:, None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
201
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
202 return cluster_profile, cluster_deletion_mean
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
203
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
204
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
205 def create_msa_feat(batch):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
206 """Create and concatenate MSA features."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
207 msa_1hot = jax.nn.one_hot(batch['msa'], 23)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
208 deletion_matrix = batch['deletion_matrix']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
209 has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
210 deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
211
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
212 deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) *
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
213 (2. / jnp.pi))[..., None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
214
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
215 msa_feat = [
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
216 msa_1hot,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
217 has_deletion,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
218 deletion_value,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
219 batch['cluster_profile'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
220 deletion_mean_value
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
221 ]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
222
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
223 return jnp.concatenate(msa_feat, axis=-1)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
224
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
225
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
226 def create_extra_msa_feature(batch, num_extra_msa):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
227 """Expand extra_msa into 1hot and concat with other extra msa features.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
228
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
229 We do this as late as possible as the one_hot extra msa can be very large.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
230
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
231 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
232 batch: a dictionary with the following keys:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
233 * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
234 centre. Note - This isn't one-hotted.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
235 * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
236 position.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
237 num_extra_msa: Number of extra msa to use.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
238
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
239 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
240 Concatenated tensor of extra MSA features.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
241 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
242 # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
243 extra_msa = batch['extra_msa'][:num_extra_msa]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
244 deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
245 msa_1hot = jax.nn.one_hot(extra_msa, 23)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
246 has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
247 deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
248 extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
249 return jnp.concatenate([msa_1hot, has_deletion, deletion_value],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
250 axis=-1), extra_msa_mask
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
251
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
252
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
253 def sample_msa(key, batch, max_seq):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
254 """Sample MSA randomly, remaining sequences are stored as `extra_*`.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
255
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
256 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
257 key: safe key for random number generation.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
258 batch: batch to sample msa from.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
259 max_seq: number of sequences to sample.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
260 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
261 Protein with sampled msa.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
262 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
263 # Sample uniformly among sequences with at least one non-masked position.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
264 logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
265 # The cluster_bias_mask can be used to preserve the first row (target
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
266 # sequence) for each chain, for example.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
267 if 'cluster_bias_mask' not in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
268 cluster_bias_mask = jnp.pad(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
269 jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
270 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
271 cluster_bias_mask = batch['cluster_bias_mask']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
272
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
273 logits += cluster_bias_mask * 1e6
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
274 index_order = gumbel_argsort_sample_idx(key.get(), logits)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
275 sel_idx = index_order[:max_seq]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
276 extra_idx = index_order[max_seq:]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
277
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
278 for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
279 if k in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
280 batch['extra_' + k] = batch[k][extra_idx]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
281 batch[k] = batch[k][sel_idx]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
282
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
283 return batch
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
284
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
285
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
286 def make_msa_profile(batch):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
287 """Compute the MSA profile."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
288
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
289 # Compute the profile for every residue (over all MSA sequences).
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
290 return utils.mask_mean(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
291 batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
292
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
293
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
294 class AlphaFoldIteration(hk.Module):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
295 """A single recycling iteration of AlphaFold architecture.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
296
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
297 Computes ensembled (averaged) representations from the provided features.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
298 These representations are then passed to the various heads
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
299 that have been requested by the configuration file.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
300 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
301
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
302 def __init__(self, config, global_config, name='alphafold_iteration'):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
303 super().__init__(name=name)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
304 self.config = config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
305 self.global_config = global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
306
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
307 def __call__(self,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
308 batch,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
309 is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
310 return_representations=False,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
311 safe_key=None):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
312
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
313 if is_training:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
314 num_ensemble = np.asarray(self.config.num_ensemble_train)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
315 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
316 num_ensemble = np.asarray(self.config.num_ensemble_eval)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
317
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
318 # Compute representations for each MSA sample and average.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
319 embedding_module = EmbeddingsAndEvoformer(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
320 self.config.embeddings_and_evoformer, self.global_config)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
321 repr_shape = hk.eval_shape(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
322 lambda: embedding_module(batch, is_training))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
323 representations = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
324 k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
325 }
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
326
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
327 def ensemble_body(x, unused_y):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
328 """Add into representations ensemble."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
329 del unused_y
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
330 representations, safe_key = x
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
331 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
332 representations_update = embedding_module(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
333 batch, is_training, safe_key=safe_subkey)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
334
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
335 for k in representations:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
336 if k not in {'msa', 'true_msa', 'bert_mask'}:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
337 representations[k] += representations_update[k] * (
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
338 1. / num_ensemble).astype(representations[k].dtype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
339 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
340 representations[k] = representations_update[k]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
341
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
342 return (representations, safe_key), None
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
343
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
344 (representations, _), _ = hk.scan(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
345 ensemble_body, (representations, safe_key), None, length=num_ensemble)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
346
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
347 self.representations = representations
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
348 self.batch = batch
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
349 self.heads = {}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
350 for head_name, head_config in sorted(self.config.heads.items()):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
351 if not head_config.weight:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
352 continue # Do not instantiate zero-weight heads.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
353
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
354 head_factory = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
355 'masked_msa':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
356 modules.MaskedMsaHead,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
357 'distogram':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
358 modules.DistogramHead,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
359 'structure_module':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
360 folding_multimer.StructureModule,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
361 'predicted_aligned_error':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
362 modules.PredictedAlignedErrorHead,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
363 'predicted_lddt':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
364 modules.PredictedLDDTHead,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
365 'experimentally_resolved':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
366 modules.ExperimentallyResolvedHead,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
367 }[head_name]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
368 self.heads[head_name] = (head_config,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
369 head_factory(head_config, self.global_config))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
370
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
371 structure_module_output = None
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
372 if 'entity_id' in batch and 'all_atom_positions' in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
373 _, fold_module = self.heads['structure_module']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
374 structure_module_output = fold_module(representations, batch, is_training)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
375
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
376 ret = {}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
377 ret['representations'] = representations
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
378
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
379 for name, (head_config, module) in self.heads.items():
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
380 if name == 'structure_module' and structure_module_output is not None:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
381 ret[name] = structure_module_output
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
382 representations['structure_module'] = structure_module_output.pop('act')
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
383 # Skip confidence heads until StructureModule is executed.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
384 elif name in {'predicted_lddt', 'predicted_aligned_error',
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
385 'experimentally_resolved'}:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
386 continue
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
387 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
388 ret[name] = module(representations, batch, is_training)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
389
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
390 # Add confidence heads after StructureModule is executed.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
391 if self.config.heads.get('predicted_lddt.weight', 0.0):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
392 name = 'predicted_lddt'
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
393 head_config, module = self.heads[name]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
394 ret[name] = module(representations, batch, is_training)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
395
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
396 if self.config.heads.experimentally_resolved.weight:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
397 name = 'experimentally_resolved'
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
398 head_config, module = self.heads[name]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
399 ret[name] = module(representations, batch, is_training)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
400
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
401 if self.config.heads.get('predicted_aligned_error.weight', 0.0):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
402 name = 'predicted_aligned_error'
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
403 head_config, module = self.heads[name]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
404 ret[name] = module(representations, batch, is_training)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
405 # Will be used for ipTM computation.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
406 ret[name]['asym_id'] = batch['asym_id']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
407
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
408 return ret
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
409
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
410
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
411 class AlphaFold(hk.Module):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
412 """AlphaFold-Multimer model with recycling.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
413 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
414
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
415 def __init__(self, config, name='alphafold'):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
416 super().__init__(name=name)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
417 self.config = config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
418 self.global_config = config.global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
419
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
420 def __call__(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
421 self,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
422 batch,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
423 is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
424 return_representations=False,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
425 safe_key=None):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
426
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
427 c = self.config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
428 impl = AlphaFoldIteration(c, self.global_config)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
429
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
430 if safe_key is None:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
431 safe_key = prng.SafeKey(hk.next_rng_key())
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
432 elif isinstance(safe_key, jnp.ndarray):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
433 safe_key = prng.SafeKey(safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
434
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
435 assert isinstance(batch, dict)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
436 num_res = batch['aatype'].shape[0]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
437
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
438 def get_prev(ret):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
439 new_prev = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
440 'prev_pos':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
441 ret['structure_module']['final_atom_positions'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
442 'prev_msa_first_row': ret['representations']['msa_first_row'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
443 'prev_pair': ret['representations']['pair'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
444 }
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
445 return jax.tree_map(jax.lax.stop_gradient, new_prev)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
446
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
447 def apply_network(prev, safe_key):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
448 recycled_batch = {**batch, **prev}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
449 return impl(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
450 batch=recycled_batch,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
451 is_training=is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
452 safe_key=safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
453
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
454 if self.config.num_recycle:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
455 emb_config = self.config.embeddings_and_evoformer
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
456 prev = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
457 'prev_pos':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
458 jnp.zeros([num_res, residue_constants.atom_type_num, 3]),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
459 'prev_msa_first_row':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
460 jnp.zeros([num_res, emb_config.msa_channel]),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
461 'prev_pair':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
462 jnp.zeros([num_res, num_res, emb_config.pair_channel]),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
463 }
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
464
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
465 if 'num_iter_recycling' in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
466 # Training time: num_iter_recycling is in batch.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
467 # Value for each ensemble batch is the same, so arbitrarily taking 0-th.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
468 num_iter = batch['num_iter_recycling'][0]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
469
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
470 # Add insurance that even when ensembling, we will not run more
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
471 # recyclings than the model is configured to run.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
472 num_iter = jnp.minimum(num_iter, c.num_recycle)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
473 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
474 # Eval mode or tests: use the maximum number of iterations.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
475 num_iter = c.num_recycle
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
476
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
477 def recycle_body(i, x):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
478 del i
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
479 prev, safe_key = x
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
480 safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
481 ret = apply_network(prev=prev, safe_key=safe_key2)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
482 return get_prev(ret), safe_key1
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
483
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
484 prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
485 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
486 prev = {}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
487
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
488 # Run extra iteration.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
489 ret = apply_network(prev=prev, safe_key=safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
490
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
491 if not return_representations:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
492 del ret['representations']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
493 return ret
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
494
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
495
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
496 class EmbeddingsAndEvoformer(hk.Module):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
497 """Embeds the input data and runs Evoformer.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
498
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
499 Produces the MSA, single and pair representations.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
500 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
501
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
502 def __init__(self, config, global_config, name='evoformer'):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
503 super().__init__(name=name)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
504 self.config = config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
505 self.global_config = global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
506
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
507 def _relative_encoding(self, batch):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
508 """Add relative position encodings.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
509
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
510 For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
511
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
512 When not using 'use_chain_relative' the residue indices are used as is, e.g.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
513 for heteromers relative positions will be computed using the positions in
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
514 the corresponding chains.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
515
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
516 When using 'use_chain_relative' we add an extra bin that denotes
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
517 'different chain'. Furthermore we also provide the relative chain index
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
518 (i.e. sym_id) clipped and one-hotted to the network. And an extra feature
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
519 which denotes whether they belong to the same chain type, i.e. it's 0 if
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
520 they are in different heteromer chains and 1 otherwise.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
521
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
522 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
523 batch: batch.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
524 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
525 Feature embedding using the features as described before.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
526 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
527 c = self.config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
528 rel_feats = []
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
529 pos = batch['residue_index']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
530 asym_id = batch['asym_id']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
531 asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
532 offset = pos[:, None] - pos[None, :]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
533
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
534 clipped_offset = jnp.clip(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
535 offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
536
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
537 if c.use_chain_relative:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
538
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
539 final_offset = jnp.where(asym_id_same, clipped_offset,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
540 (2 * c.max_relative_idx + 1) *
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
541 jnp.ones_like(clipped_offset))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
542
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
543 rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
544
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
545 rel_feats.append(rel_pos)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
546
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
547 entity_id = batch['entity_id']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
548 entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
549 rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
550
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
551 sym_id = batch['sym_id']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
552 rel_sym_id = sym_id[:, None] - sym_id[None, :]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
553
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
554 max_rel_chain = c.max_relative_chain
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
555
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
556 clipped_rel_chain = jnp.clip(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
557 rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
558
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
559 final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
560 (2 * max_rel_chain + 1) *
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
561 jnp.ones_like(clipped_rel_chain))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
562 rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
563
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
564 rel_feats.append(rel_chain)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
565
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
566 else:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
567 rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
568 rel_feats.append(rel_pos)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
569
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
570 rel_feat = jnp.concatenate(rel_feats, axis=-1)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
571
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
572 return common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
573 c.pair_channel,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
574 name='position_activations')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
575 rel_feat)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
576
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
577 def __call__(self, batch, is_training, safe_key=None):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
578
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
579 c = self.config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
580 gc = self.global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
581
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
582 batch = dict(batch)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
583
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
584 if safe_key is None:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
585 safe_key = prng.SafeKey(hk.next_rng_key())
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
586
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
587 output = {}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
588
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
589 batch['msa_profile'] = make_msa_profile(batch)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
590
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
591 target_feat = jax.nn.one_hot(batch['aatype'], 21)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
592
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
593 preprocess_1d = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
594 c.msa_channel, name='preprocess_1d')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
595 target_feat)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
596
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
597 safe_key, sample_key, mask_key = safe_key.split(3)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
598 batch = sample_msa(sample_key, batch, c.num_msa)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
599 batch = make_masked_msa(batch, mask_key, c.masked_msa)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
600
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
601 (batch['cluster_profile'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
602 batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
603
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
604 msa_feat = create_msa_feat(batch)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
605
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
606 preprocess_msa = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
607 c.msa_channel, name='preprocess_msa')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
608 msa_feat)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
609
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
610 msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
611
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
612 left_single = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
613 c.pair_channel, name='left_single')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
614 target_feat)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
615 right_single = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
616 c.pair_channel, name='right_single')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
617 target_feat)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
618 pair_activations = left_single[:, None] + right_single[None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
619 mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
620 mask_2d = mask_2d.astype(jnp.float32)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
621
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
622 if c.recycle_pos and 'prev_pos' in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
623 prev_pseudo_beta = modules.pseudo_beta_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
624 batch['aatype'], batch['prev_pos'], None)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
625
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
626 dgram = modules.dgram_from_positions(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
627 prev_pseudo_beta, **self.config.prev_pos)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
628 pair_activations += common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
629 c.pair_channel, name='prev_pos_linear')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
630 dgram)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
631
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
632 if c.recycle_features:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
633 if 'prev_msa_first_row' in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
634 prev_msa_first_row = hk.LayerNorm(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
635 axis=[-1],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
636 create_scale=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
637 create_offset=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
638 name='prev_msa_first_row_norm')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
639 batch['prev_msa_first_row'])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
640 msa_activations = msa_activations.at[0].add(prev_msa_first_row)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
641
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
642 if 'prev_pair' in batch:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
643 pair_activations += hk.LayerNorm(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
644 axis=[-1],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
645 create_scale=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
646 create_offset=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
647 name='prev_pair_norm')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
648 batch['prev_pair'])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
649
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
650 if c.max_relative_idx:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
651 pair_activations += self._relative_encoding(batch)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
652
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
653 if c.template.enabled:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
654 template_module = TemplateEmbedding(c.template, gc)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
655 template_batch = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
656 'template_aatype': batch['template_aatype'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
657 'template_all_atom_positions': batch['template_all_atom_positions'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
658 'template_all_atom_mask': batch['template_all_atom_mask']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
659 }
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
660 # Construct a mask such that only intra-chain template features are
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
661 # computed, since all templates are for each chain individually.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
662 multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
663 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
664 template_act = template_module(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
665 query_embedding=pair_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
666 template_batch=template_batch,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
667 padding_mask_2d=mask_2d,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
668 multichain_mask_2d=multichain_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
669 is_training=is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
670 safe_key=safe_subkey)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
671 pair_activations += template_act
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
672
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
673 # Extra MSA stack.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
674 (extra_msa_feat,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
675 extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
676 extra_msa_activations = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
677 c.extra_msa_channel,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
678 name='extra_msa_activations')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
679 extra_msa_feat)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
680 extra_msa_mask = extra_msa_mask.astype(jnp.float32)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
681
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
682 extra_evoformer_input = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
683 'msa': extra_msa_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
684 'pair': pair_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
685 }
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
686 extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
687
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
688 extra_evoformer_iteration = modules.EvoformerIteration(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
689 c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack')
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
690
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
691 def extra_evoformer_fn(x):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
692 act, safe_key = x
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
693 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
694 extra_evoformer_output = extra_evoformer_iteration(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
695 activations=act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
696 masks=extra_masks,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
697 is_training=is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
698 safe_key=safe_subkey)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
699 return (extra_evoformer_output, safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
700
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
701 if gc.use_remat:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
702 extra_evoformer_fn = hk.remat(extra_evoformer_fn)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
703
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
704 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
705 extra_evoformer_stack = layer_stack.layer_stack(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
706 c.extra_msa_stack_num_block)(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
707 extra_evoformer_fn)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
708 extra_evoformer_output, safe_key = extra_evoformer_stack(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
709 (extra_evoformer_input, safe_subkey))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
710
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
711 pair_activations = extra_evoformer_output['pair']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
712
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
713 # Get the size of the MSA before potentially adding templates, so we
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
714 # can crop out the templates later.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
715 num_msa_sequences = msa_activations.shape[0]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
716 evoformer_input = {
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
717 'msa': msa_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
718 'pair': pair_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
719 }
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
720 evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
721 'pair': mask_2d}
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
722
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
723 if c.template.enabled:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
724 template_features, template_masks = (
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
725 template_embedding_1d(batch=batch, num_channel=c.msa_channel))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
726
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
727 evoformer_input['msa'] = jnp.concatenate(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
728 [evoformer_input['msa'], template_features], axis=0)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
729 evoformer_masks['msa'] = jnp.concatenate(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
730 [evoformer_masks['msa'], template_masks], axis=0)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
731
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
732 evoformer_iteration = modules.EvoformerIteration(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
733 c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration')
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
734
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
735 def evoformer_fn(x):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
736 act, safe_key = x
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
737 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
738 evoformer_output = evoformer_iteration(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
739 activations=act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
740 masks=evoformer_masks,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
741 is_training=is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
742 safe_key=safe_subkey)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
743 return (evoformer_output, safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
744
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
745 if gc.use_remat:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
746 evoformer_fn = hk.remat(evoformer_fn)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
747
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
748 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
749 evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
750 evoformer_fn)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
751
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
752 def run_evoformer(evoformer_input):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
753 evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
754 return evoformer_output
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
755
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
756 evoformer_output = run_evoformer(evoformer_input)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
757
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
758 msa_activations = evoformer_output['msa']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
759 pair_activations = evoformer_output['pair']
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
760
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
761 single_activations = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
762 c.seq_channel, name='single_activations')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
763 msa_activations[0])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
764
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
765 output.update({
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
766 'single':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
767 single_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
768 'pair':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
769 pair_activations,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
770 # Crop away template rows such that they are not used in MaskedMsaHead.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
771 'msa':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
772 msa_activations[:num_msa_sequences, :, :],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
773 'msa_first_row':
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
774 msa_activations[0],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
775 })
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
776
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
777 return output
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
778
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
779
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
780 class TemplateEmbedding(hk.Module):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
781 """Embed a set of templates."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
782
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
783 def __init__(self, config, global_config, name='template_embedding'):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
784 super().__init__(name=name)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
785 self.config = config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
786 self.global_config = global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
787
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
788 def __call__(self, query_embedding, template_batch, padding_mask_2d,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
789 multichain_mask_2d, is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
790 safe_key=None):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
791 """Generate an embedding for a set of templates.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
792
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
793 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
794 query_embedding: [num_res, num_res, num_channel] a query tensor that will
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
795 be used to attend over the templates to remove the num_templates
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
796 dimension.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
797 template_batch: A dictionary containing:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
798 `template_aatype`: [num_templates, num_res] aatype for each template.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
799 `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
800 positions for all templates.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
801 `template_all_atom_mask`: [num_templates, num_res, 37] mask for each
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
802 template.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
803 padding_mask_2d: [num_res, num_res] Pair mask for attention operations.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
804 multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
805 are intra-chain, used to mask out residue distance based features
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
806 between chains.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
807 is_training: bool indicating where we are running in training mode.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
808 safe_key: random key generator.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
809
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
810 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
811 An embedding of size [num_res, num_res, num_channels]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
812 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
813 c = self.config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
814 if safe_key is None:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
815 safe_key = prng.SafeKey(hk.next_rng_key())
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
816
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
817 num_templates = template_batch['template_aatype'].shape[0]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
818 num_res, _, query_num_channels = query_embedding.shape
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
819
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
820 # Embed each template separately.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
821 template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
822 def partial_template_embedder(template_aatype,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
823 template_all_atom_positions,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
824 template_all_atom_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
825 unsafe_key):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
826 safe_key = prng.SafeKey(unsafe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
827 return template_embedder(query_embedding,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
828 template_aatype,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
829 template_all_atom_positions,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
830 template_all_atom_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
831 padding_mask_2d,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
832 multichain_mask_2d,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
833 is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
834 safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
835
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
836 safe_key, unsafe_key = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
837 unsafe_keys = jax.random.split(unsafe_key._key, num_templates)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
838
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
839 def scan_fn(carry, x):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
840 return carry + partial_template_embedder(*x), None
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
841
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
842 scan_init = jnp.zeros((num_res, num_res, c.num_channels),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
843 dtype=query_embedding.dtype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
844 summed_template_embeddings, _ = hk.scan(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
845 scan_fn, scan_init,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
846 (template_batch['template_aatype'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
847 template_batch['template_all_atom_positions'],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
848 template_batch['template_all_atom_mask'], unsafe_keys))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
849
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
850 embedding = summed_template_embeddings / num_templates
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
851 embedding = jax.nn.relu(embedding)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
852 embedding = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
853 query_num_channels,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
854 initializer='relu',
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
855 name='output_linear')(embedding)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
856
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
857 return embedding
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
858
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
859
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
860 class SingleTemplateEmbedding(hk.Module):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
861 """Embed a single template."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
862
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
863 def __init__(self, config, global_config, name='single_template_embedding'):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
864 super().__init__(name=name)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
865 self.config = config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
866 self.global_config = global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
867
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
868 def __call__(self, query_embedding, template_aatype,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
869 template_all_atom_positions, template_all_atom_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
870 padding_mask_2d, multichain_mask_2d, is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
871 safe_key):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
872 """Build the single template embedding graph.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
873
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
874 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
875 query_embedding: (num_res, num_res, num_channels) - embedding of the
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
876 query sequence/msa.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
877 template_aatype: [num_res] aatype for each template.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
878 template_all_atom_positions: [num_res, 37, 3] atom positions for all
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
879 templates.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
880 template_all_atom_mask: [num_res, 37] mask for each template.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
881 padding_mask_2d: Padding mask (Note: this doesn't care if a template
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
882 exists, unlike the template_pseudo_beta_mask).
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
883 multichain_mask_2d: A mask indicating intra-chain residue pairs, used
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
884 to mask out between chain distances/features when templates are for
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
885 single chains.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
886 is_training: Are we in training mode.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
887 safe_key: Random key generator.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
888
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
889 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
890 A template embedding (num_res, num_res, num_channels).
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
891 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
892 gc = self.global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
893 c = self.config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
894 assert padding_mask_2d.dtype == query_embedding.dtype
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
895 dtype = query_embedding.dtype
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
896 num_channels = self.config.num_channels
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
897
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
898 def construct_input(query_embedding, template_aatype,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
899 template_all_atom_positions, template_all_atom_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
900 multichain_mask_2d):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
901
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
902 # Compute distogram feature for the template.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
903 template_positions, pseudo_beta_mask = modules.pseudo_beta_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
904 template_aatype, template_all_atom_positions, template_all_atom_mask)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
905 pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] *
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
906 pseudo_beta_mask[None, :])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
907 pseudo_beta_mask_2d *= multichain_mask_2d
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
908 template_dgram = modules.dgram_from_positions(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
909 template_positions, **self.config.dgram_features)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
910 template_dgram *= pseudo_beta_mask_2d[..., None]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
911 template_dgram = template_dgram.astype(dtype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
912 pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
913 to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
914
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
915 aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
916 to_concat.append((aatype[None, :, :], 1))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
917 to_concat.append((aatype[:, None, :], 1))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
918
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
919 # Compute a feature representing the normalized vector between each
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
920 # backbone affine - i.e. in each residues local frame, what direction are
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
921 # each of the other residues.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
922 raw_atom_pos = template_all_atom_positions
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
923
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
924 atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
925 rigid, backbone_mask = folding_multimer.make_backbone_affine(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
926 atom_pos,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
927 template_all_atom_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
928 template_aatype)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
929 points = rigid.translation
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
930 rigid_vec = rigid[:, None].inverse().apply_to_point(points)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
931 unit_vector = rigid_vec.normalized()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
932 unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
933
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
934 backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
935 backbone_mask_2d *= multichain_mask_2d
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
936 unit_vector = [x*backbone_mask_2d for x in unit_vector]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
937
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
938 # Note that the backbone_mask takes into account C, CA and N (unlike
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
939 # pseudo beta mask which just needs CB) so we add both masks as features.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
940 to_concat.extend([(x, 0) for x in unit_vector])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
941 to_concat.append((backbone_mask_2d, 0))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
942
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
943 query_embedding = hk.LayerNorm(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
944 axis=[-1],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
945 create_scale=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
946 create_offset=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
947 name='query_embedding_norm')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
948 query_embedding)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
949 # Allow the template embedder to see the query embedding. Note this
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
950 # contains the position relative feature, so this is how the network knows
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
951 # which residues are next to each other.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
952 to_concat.append((query_embedding, 1))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
953
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
954 act = 0
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
955
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
956 for i, (x, n_input_dims) in enumerate(to_concat):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
957
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
958 act += common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
959 num_channels,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
960 num_input_dims=n_input_dims,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
961 initializer='relu',
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
962 name=f'template_pair_embedding_{i}')(x)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
963 return act
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
964
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
965 act = construct_input(query_embedding, template_aatype,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
966 template_all_atom_positions, template_all_atom_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
967 multichain_mask_2d)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
968
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
969 template_iteration = TemplateEmbeddingIteration(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
970 c.template_pair_stack, gc, name='template_embedding_iteration')
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
971
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
972 def template_iteration_fn(x):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
973 act, safe_key = x
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
974
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
975 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
976 act = template_iteration(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
977 act=act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
978 pair_mask=padding_mask_2d,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
979 is_training=is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
980 safe_key=safe_subkey)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
981 return (act, safe_key)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
982
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
983 if gc.use_remat:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
984 template_iteration_fn = hk.remat(template_iteration_fn)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
985
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
986 safe_key, safe_subkey = safe_key.split()
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
987 template_stack = layer_stack.layer_stack(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
988 c.template_pair_stack.num_block)(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
989 template_iteration_fn)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
990 act, safe_key = template_stack((act, safe_subkey))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
991
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
992 act = hk.LayerNorm(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
993 axis=[-1],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
994 create_scale=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
995 create_offset=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
996 name='output_layer_norm')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
997 act)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
998 return act
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
999
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1000
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1001 class TemplateEmbeddingIteration(hk.Module):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1002 """Single Iteration of Template Embedding."""
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1003
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1004 def __init__(self, config, global_config,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1005 name='template_embedding_iteration'):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1006 super().__init__(name=name)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1007 self.config = config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1008 self.global_config = global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1009
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1010 def __call__(self, act, pair_mask, is_training=True,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1011 safe_key=None):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1012 """Build a single iteration of the template embedder.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1013
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1014 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1015 act: [num_res, num_res, num_channel] Input pairwise activations.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1016 pair_mask: [num_res, num_res] padding mask.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1017 is_training: Whether to run in training mode.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1018 safe_key: Safe pseudo-random generator key.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1019
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1020 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1021 [num_res, num_res, num_channel] tensor of activations.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1022 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1023 c = self.config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1024 gc = self.global_config
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1025
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1026 if safe_key is None:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1027 safe_key = prng.SafeKey(hk.next_rng_key())
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1028
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1029 dropout_wrapper_fn = functools.partial(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1030 modules.dropout_wrapper,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1031 is_training=is_training,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1032 global_config=gc)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1033
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1034 safe_key, *sub_keys = safe_key.split(20)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1035 sub_keys = iter(sub_keys)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1036
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1037 act = dropout_wrapper_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1038 modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1039 name='triangle_multiplication_outgoing'),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1040 act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1041 pair_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1042 safe_key=next(sub_keys))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1043
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1044 act = dropout_wrapper_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1045 modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1046 name='triangle_multiplication_incoming'),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1047 act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1048 pair_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1049 safe_key=next(sub_keys))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1050
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1051 act = dropout_wrapper_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1052 modules.TriangleAttention(c.triangle_attention_starting_node, gc,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1053 name='triangle_attention_starting_node'),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1054 act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1055 pair_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1056 safe_key=next(sub_keys))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1057
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1058 act = dropout_wrapper_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1059 modules.TriangleAttention(c.triangle_attention_ending_node, gc,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1060 name='triangle_attention_ending_node'),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1061 act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1062 pair_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1063 safe_key=next(sub_keys))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1064
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1065 act = dropout_wrapper_fn(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1066 modules.Transition(c.pair_transition, gc,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1067 name='pair_transition'),
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1068 act,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1069 pair_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1070 safe_key=next(sub_keys))
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1071
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1072 return act
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1073
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1074
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1075 def template_embedding_1d(batch, num_channel):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1076 """Embed templates into an (num_res, num_templates, num_channels) embedding.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1077
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1078 Args:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1079 batch: A batch containing:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1080 template_aatype, (num_templates, num_res) aatype for the templates.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1081 template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1082 positions for the templates.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1083 template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1084 each template.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1085 num_channel: The number of channels in the output.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1086
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1087 Returns:
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1088 An embedding of shape (num_templates, num_res, num_channels) and a mask of
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1089 shape (num_templates, num_res).
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1090 """
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1091
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1092 # Embed the templates aatypes.
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1093 aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1094
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1095 num_templates = batch['template_aatype'].shape[0]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1096 all_chi_angles = []
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1097 all_chi_masks = []
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1098 for i in range(num_templates):
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1099 atom_pos = geometry.Vec3Array.from_array(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1100 batch['template_all_atom_positions'][i, :, :, :])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1101 template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1102 atom_pos,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1103 batch['template_all_atom_mask'][i, :, :],
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1104 batch['template_aatype'][i, :])
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1105 all_chi_angles.append(template_chi_angles)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1106 all_chi_masks.append(template_chi_mask)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1107 chi_angles = jnp.stack(all_chi_angles, axis=0)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1108 chi_mask = jnp.stack(all_chi_masks, axis=0)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1109
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1110 template_features = jnp.concatenate([
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1111 aatype_one_hot,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1112 jnp.sin(chi_angles) * chi_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1113 jnp.cos(chi_angles) * chi_mask,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1114 chi_mask], axis=-1)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1115
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1116 template_mask = chi_mask[:, :, 0]
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1117
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1118 template_activations = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1119 num_channel,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1120 initializer='relu',
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1121 name='template_single_embedding')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1122 template_features)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1123 template_activations = jax.nn.relu(template_activations)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1124 template_activations = common_modules.Linear(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1125 num_channel,
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1126 initializer='relu',
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1127 name='template_projection')(
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1128 template_activations)
6c92e000d684 "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff changeset
1129 return template_activations, template_mask