Mercurial > repos > galaxy-australia > alphafold2
annotate docker/alphafold/alphafold/model/modules_multimer.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
rev | line source |
---|---|
1
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1 # Copyright 2021 DeepMind Technologies Limited |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
2 # |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
3 # Licensed under the Apache License, Version 2.0 (the "License"); |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
4 # you may not use this file except in compliance with the License. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
5 # You may obtain a copy of the License at |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
6 # |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
7 # http://www.apache.org/licenses/LICENSE-2.0 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
8 # |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
9 # Unless required by applicable law or agreed to in writing, software |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
10 # distributed under the License is distributed on an "AS IS" BASIS, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
12 # See the License for the specific language governing permissions and |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
13 # limitations under the License. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
14 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
15 """Core modules, which have been refactored in AlphaFold-Multimer. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
16 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
17 The main difference is that MSA sampling pipeline is moved inside the JAX model |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
18 for easier implementation of recycling and ensembling. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
19 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
20 Lower-level modules up to EvoformerIteration are reused from modules.py. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
21 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
22 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
23 import functools |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
24 from typing import Sequence |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
25 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
26 from alphafold.common import residue_constants |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
27 from alphafold.model import all_atom_multimer |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
28 from alphafold.model import common_modules |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
29 from alphafold.model import folding_multimer |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
30 from alphafold.model import geometry |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
31 from alphafold.model import layer_stack |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
32 from alphafold.model import modules |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
33 from alphafold.model import prng |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
34 from alphafold.model import utils |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
35 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
36 import haiku as hk |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
37 import jax |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
38 import jax.numpy as jnp |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
39 import numpy as np |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
40 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
41 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
42 def reduce_fn(x, mode): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
43 if mode == 'none' or mode is None: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
44 return jnp.asarray(x) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
45 elif mode == 'sum': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
46 return jnp.asarray(x).sum() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
47 elif mode == 'mean': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
48 return jnp.mean(jnp.asarray(x)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
49 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
50 raise ValueError('Unsupported reduction option.') |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
51 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
52 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
53 def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
54 """Generate Gumbel Noise of given Shape. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
55 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
56 This generates samples from Gumbel(0, 1). |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
57 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
58 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
59 key: Jax random number key. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
60 shape: Shape of noise to return. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
61 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
62 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
63 Gumbel noise of given shape. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
64 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
65 epsilon = 1e-6 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
66 uniform = utils.padding_consistent_rng(jax.random.uniform) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
67 uniform_noise = uniform( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
68 key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
69 gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
70 return gumbel |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
71 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
72 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
73 def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
74 """Samples from a probability distribution given by 'logits'. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
75 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
76 This uses Gumbel-max trick to implement the sampling in an efficient manner. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
77 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
78 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
79 key: prng key. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
80 logits: Logarithm of probabilities to sample from, probabilities can be |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
81 unnormalized. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
82 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
83 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
84 Sample from logprobs in one-hot form. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
85 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
86 z = gumbel_noise(key, logits.shape) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
87 return jax.nn.one_hot( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
88 jnp.argmax(logits + z, axis=-1), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
89 logits.shape[-1], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
90 dtype=logits.dtype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
91 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
92 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
93 def gumbel_argsort_sample_idx(key: jnp.ndarray, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
94 logits: jnp.ndarray) -> jnp.ndarray: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
95 """Samples with replacement from a distribution given by 'logits'. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
96 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
97 This uses Gumbel trick to implement the sampling an efficient manner. For a |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
98 distribution over k items this samples k times without replacement, so this |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
99 is effectively sampling a random permutation with probabilities over the |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
100 permutations derived from the logprobs. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
101 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
102 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
103 key: prng key. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
104 logits: Logarithm of probabilities to sample from, probabilities can be |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
105 unnormalized. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
106 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
107 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
108 Sample from logprobs in one-hot form. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
109 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
110 z = gumbel_noise(key, logits.shape) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
111 # This construction is equivalent to jnp.argsort, but using a non stable sort, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
112 # since stable sort's aren't supported by jax2tf. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
113 axis = len(logits.shape) - 1 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
114 iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
115 _, perm = jax.lax.sort_key_val( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
116 logits + z, iota, dimension=-1, is_stable=False) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
117 return perm[::-1] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
118 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
119 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
120 def make_masked_msa(batch, key, config, epsilon=1e-6): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
121 """Create data for BERT on raw MSA.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
122 # Add a random amino acid uniformly. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
123 random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
124 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
125 categorical_probs = ( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
126 config.uniform_prob * random_aa + |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
127 config.profile_prob * batch['msa_profile'] + |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
128 config.same_prob * jax.nn.one_hot(batch['msa'], 22)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
129 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
130 # Put all remaining probability on [MASK] which is a new column. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
131 pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
132 pad_shapes[-1][1] = 1 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
133 mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
134 assert mask_prob >= 0. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
135 categorical_probs = jnp.pad( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
136 categorical_probs, pad_shapes, constant_values=mask_prob) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
137 sh = batch['msa'].shape |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
138 key, mask_subkey, gumbel_subkey = key.split(3) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
139 uniform = utils.padding_consistent_rng(jax.random.uniform) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
140 mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
141 mask_position *= batch['msa_mask'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
142 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
143 logits = jnp.log(categorical_probs + epsilon) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
144 bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
145 bert_msa = jnp.where(mask_position, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
146 jnp.argmax(bert_msa, axis=-1), batch['msa']) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
147 bert_msa *= batch['msa_mask'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
148 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
149 # Mix real and masked MSA. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
150 if 'bert_mask' in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
151 batch['bert_mask'] *= mask_position.astype(jnp.float32) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
152 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
153 batch['bert_mask'] = mask_position.astype(jnp.float32) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
154 batch['true_msa'] = batch['msa'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
155 batch['msa'] = bert_msa |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
156 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
157 return batch |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
158 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
159 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
160 def nearest_neighbor_clusters(batch, gap_agreement_weight=0.): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
161 """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
162 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
163 # Determine how much weight we assign to each agreement. In theory, we could |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
164 # use a full blosum matrix here, but right now let's just down-weight gap |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
165 # agreement because it could be spurious. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
166 # Never put weight on agreeing on BERT mask. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
167 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
168 weights = jnp.array( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
169 [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
170 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
171 msa_mask = batch['msa_mask'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
172 msa_one_hot = jax.nn.one_hot(batch['msa'], 23) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
173 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
174 extra_mask = batch['extra_msa_mask'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
175 extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
176 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
177 msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
178 extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
179 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
180 agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
181 weights * msa_one_hot_masked) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
182 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
183 cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
184 cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
185 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
186 cluster_count = jnp.sum(cluster_assignment, axis=-1) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
187 cluster_count += 1. # We always include the sequence itself. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
188 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
189 msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
190 msa_sum += msa_one_hot_masked |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
191 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
192 cluster_profile = msa_sum / cluster_count[:, None, None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
193 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
194 extra_deletion_matrix = batch['extra_deletion_matrix'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
195 deletion_matrix = batch['deletion_matrix'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
196 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
197 del_sum = jnp.einsum('nm, mc->nc', cluster_assignment, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
198 extra_mask * extra_deletion_matrix) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
199 del_sum += deletion_matrix # Original sequence. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
200 cluster_deletion_mean = del_sum / cluster_count[:, None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
201 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
202 return cluster_profile, cluster_deletion_mean |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
203 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
204 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
205 def create_msa_feat(batch): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
206 """Create and concatenate MSA features.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
207 msa_1hot = jax.nn.one_hot(batch['msa'], 23) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
208 deletion_matrix = batch['deletion_matrix'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
209 has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
210 deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
211 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
212 deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) * |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
213 (2. / jnp.pi))[..., None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
214 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
215 msa_feat = [ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
216 msa_1hot, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
217 has_deletion, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
218 deletion_value, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
219 batch['cluster_profile'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
220 deletion_mean_value |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
221 ] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
222 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
223 return jnp.concatenate(msa_feat, axis=-1) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
224 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
225 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
226 def create_extra_msa_feature(batch, num_extra_msa): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
227 """Expand extra_msa into 1hot and concat with other extra msa features. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
228 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
229 We do this as late as possible as the one_hot extra msa can be very large. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
230 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
231 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
232 batch: a dictionary with the following keys: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
233 * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
234 centre. Note - This isn't one-hotted. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
235 * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
236 position. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
237 num_extra_msa: Number of extra msa to use. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
238 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
239 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
240 Concatenated tensor of extra MSA features. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
241 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
242 # 23 = 20 amino acids + 'X' for unknown + gap + bert mask |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
243 extra_msa = batch['extra_msa'][:num_extra_msa] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
244 deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
245 msa_1hot = jax.nn.one_hot(extra_msa, 23) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
246 has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
247 deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
248 extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
249 return jnp.concatenate([msa_1hot, has_deletion, deletion_value], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
250 axis=-1), extra_msa_mask |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
251 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
252 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
253 def sample_msa(key, batch, max_seq): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
254 """Sample MSA randomly, remaining sequences are stored as `extra_*`. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
255 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
256 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
257 key: safe key for random number generation. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
258 batch: batch to sample msa from. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
259 max_seq: number of sequences to sample. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
260 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
261 Protein with sampled msa. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
262 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
263 # Sample uniformly among sequences with at least one non-masked position. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
264 logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
265 # The cluster_bias_mask can be used to preserve the first row (target |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
266 # sequence) for each chain, for example. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
267 if 'cluster_bias_mask' not in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
268 cluster_bias_mask = jnp.pad( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
269 jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
270 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
271 cluster_bias_mask = batch['cluster_bias_mask'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
272 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
273 logits += cluster_bias_mask * 1e6 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
274 index_order = gumbel_argsort_sample_idx(key.get(), logits) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
275 sel_idx = index_order[:max_seq] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
276 extra_idx = index_order[max_seq:] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
277 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
278 for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
279 if k in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
280 batch['extra_' + k] = batch[k][extra_idx] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
281 batch[k] = batch[k][sel_idx] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
282 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
283 return batch |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
284 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
285 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
286 def make_msa_profile(batch): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
287 """Compute the MSA profile.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
288 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
289 # Compute the profile for every residue (over all MSA sequences). |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
290 return utils.mask_mean( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
291 batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
292 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
293 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
294 class AlphaFoldIteration(hk.Module): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
295 """A single recycling iteration of AlphaFold architecture. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
296 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
297 Computes ensembled (averaged) representations from the provided features. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
298 These representations are then passed to the various heads |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
299 that have been requested by the configuration file. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
300 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
301 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
302 def __init__(self, config, global_config, name='alphafold_iteration'): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
303 super().__init__(name=name) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
304 self.config = config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
305 self.global_config = global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
306 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
307 def __call__(self, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
308 batch, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
309 is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
310 return_representations=False, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
311 safe_key=None): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
312 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
313 if is_training: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
314 num_ensemble = np.asarray(self.config.num_ensemble_train) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
315 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
316 num_ensemble = np.asarray(self.config.num_ensemble_eval) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
317 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
318 # Compute representations for each MSA sample and average. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
319 embedding_module = EmbeddingsAndEvoformer( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
320 self.config.embeddings_and_evoformer, self.global_config) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
321 repr_shape = hk.eval_shape( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
322 lambda: embedding_module(batch, is_training)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
323 representations = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
324 k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
325 } |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
326 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
327 def ensemble_body(x, unused_y): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
328 """Add into representations ensemble.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
329 del unused_y |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
330 representations, safe_key = x |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
331 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
332 representations_update = embedding_module( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
333 batch, is_training, safe_key=safe_subkey) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
334 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
335 for k in representations: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
336 if k not in {'msa', 'true_msa', 'bert_mask'}: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
337 representations[k] += representations_update[k] * ( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
338 1. / num_ensemble).astype(representations[k].dtype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
339 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
340 representations[k] = representations_update[k] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
341 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
342 return (representations, safe_key), None |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
343 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
344 (representations, _), _ = hk.scan( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
345 ensemble_body, (representations, safe_key), None, length=num_ensemble) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
346 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
347 self.representations = representations |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
348 self.batch = batch |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
349 self.heads = {} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
350 for head_name, head_config in sorted(self.config.heads.items()): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
351 if not head_config.weight: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
352 continue # Do not instantiate zero-weight heads. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
353 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
354 head_factory = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
355 'masked_msa': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
356 modules.MaskedMsaHead, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
357 'distogram': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
358 modules.DistogramHead, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
359 'structure_module': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
360 folding_multimer.StructureModule, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
361 'predicted_aligned_error': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
362 modules.PredictedAlignedErrorHead, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
363 'predicted_lddt': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
364 modules.PredictedLDDTHead, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
365 'experimentally_resolved': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
366 modules.ExperimentallyResolvedHead, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
367 }[head_name] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
368 self.heads[head_name] = (head_config, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
369 head_factory(head_config, self.global_config)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
370 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
371 structure_module_output = None |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
372 if 'entity_id' in batch and 'all_atom_positions' in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
373 _, fold_module = self.heads['structure_module'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
374 structure_module_output = fold_module(representations, batch, is_training) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
375 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
376 ret = {} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
377 ret['representations'] = representations |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
378 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
379 for name, (head_config, module) in self.heads.items(): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
380 if name == 'structure_module' and structure_module_output is not None: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
381 ret[name] = structure_module_output |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
382 representations['structure_module'] = structure_module_output.pop('act') |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
383 # Skip confidence heads until StructureModule is executed. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
384 elif name in {'predicted_lddt', 'predicted_aligned_error', |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
385 'experimentally_resolved'}: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
386 continue |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
387 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
388 ret[name] = module(representations, batch, is_training) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
389 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
390 # Add confidence heads after StructureModule is executed. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
391 if self.config.heads.get('predicted_lddt.weight', 0.0): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
392 name = 'predicted_lddt' |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
393 head_config, module = self.heads[name] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
394 ret[name] = module(representations, batch, is_training) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
395 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
396 if self.config.heads.experimentally_resolved.weight: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
397 name = 'experimentally_resolved' |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
398 head_config, module = self.heads[name] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
399 ret[name] = module(representations, batch, is_training) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
400 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
401 if self.config.heads.get('predicted_aligned_error.weight', 0.0): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
402 name = 'predicted_aligned_error' |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
403 head_config, module = self.heads[name] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
404 ret[name] = module(representations, batch, is_training) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
405 # Will be used for ipTM computation. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
406 ret[name]['asym_id'] = batch['asym_id'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
407 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
408 return ret |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
409 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
410 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
411 class AlphaFold(hk.Module): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
412 """AlphaFold-Multimer model with recycling. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
413 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
414 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
415 def __init__(self, config, name='alphafold'): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
416 super().__init__(name=name) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
417 self.config = config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
418 self.global_config = config.global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
419 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
420 def __call__( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
421 self, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
422 batch, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
423 is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
424 return_representations=False, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
425 safe_key=None): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
426 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
427 c = self.config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
428 impl = AlphaFoldIteration(c, self.global_config) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
429 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
430 if safe_key is None: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
431 safe_key = prng.SafeKey(hk.next_rng_key()) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
432 elif isinstance(safe_key, jnp.ndarray): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
433 safe_key = prng.SafeKey(safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
434 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
435 assert isinstance(batch, dict) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
436 num_res = batch['aatype'].shape[0] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
437 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
438 def get_prev(ret): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
439 new_prev = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
440 'prev_pos': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
441 ret['structure_module']['final_atom_positions'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
442 'prev_msa_first_row': ret['representations']['msa_first_row'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
443 'prev_pair': ret['representations']['pair'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
444 } |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
445 return jax.tree_map(jax.lax.stop_gradient, new_prev) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
446 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
447 def apply_network(prev, safe_key): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
448 recycled_batch = {**batch, **prev} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
449 return impl( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
450 batch=recycled_batch, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
451 is_training=is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
452 safe_key=safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
453 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
454 if self.config.num_recycle: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
455 emb_config = self.config.embeddings_and_evoformer |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
456 prev = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
457 'prev_pos': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
458 jnp.zeros([num_res, residue_constants.atom_type_num, 3]), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
459 'prev_msa_first_row': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
460 jnp.zeros([num_res, emb_config.msa_channel]), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
461 'prev_pair': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
462 jnp.zeros([num_res, num_res, emb_config.pair_channel]), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
463 } |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
464 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
465 if 'num_iter_recycling' in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
466 # Training time: num_iter_recycling is in batch. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
467 # Value for each ensemble batch is the same, so arbitrarily taking 0-th. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
468 num_iter = batch['num_iter_recycling'][0] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
469 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
470 # Add insurance that even when ensembling, we will not run more |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
471 # recyclings than the model is configured to run. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
472 num_iter = jnp.minimum(num_iter, c.num_recycle) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
473 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
474 # Eval mode or tests: use the maximum number of iterations. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
475 num_iter = c.num_recycle |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
476 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
477 def recycle_body(i, x): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
478 del i |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
479 prev, safe_key = x |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
480 safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
481 ret = apply_network(prev=prev, safe_key=safe_key2) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
482 return get_prev(ret), safe_key1 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
483 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
484 prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
485 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
486 prev = {} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
487 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
488 # Run extra iteration. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
489 ret = apply_network(prev=prev, safe_key=safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
490 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
491 if not return_representations: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
492 del ret['representations'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
493 return ret |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
494 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
495 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
496 class EmbeddingsAndEvoformer(hk.Module): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
497 """Embeds the input data and runs Evoformer. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
498 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
499 Produces the MSA, single and pair representations. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
500 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
501 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
502 def __init__(self, config, global_config, name='evoformer'): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
503 super().__init__(name=name) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
504 self.config = config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
505 self.global_config = global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
506 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
507 def _relative_encoding(self, batch): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
508 """Add relative position encodings. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
509 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
510 For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
511 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
512 When not using 'use_chain_relative' the residue indices are used as is, e.g. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
513 for heteromers relative positions will be computed using the positions in |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
514 the corresponding chains. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
515 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
516 When using 'use_chain_relative' we add an extra bin that denotes |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
517 'different chain'. Furthermore we also provide the relative chain index |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
518 (i.e. sym_id) clipped and one-hotted to the network. And an extra feature |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
519 which denotes whether they belong to the same chain type, i.e. it's 0 if |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
520 they are in different heteromer chains and 1 otherwise. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
521 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
522 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
523 batch: batch. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
524 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
525 Feature embedding using the features as described before. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
526 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
527 c = self.config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
528 rel_feats = [] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
529 pos = batch['residue_index'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
530 asym_id = batch['asym_id'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
531 asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
532 offset = pos[:, None] - pos[None, :] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
533 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
534 clipped_offset = jnp.clip( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
535 offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
536 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
537 if c.use_chain_relative: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
538 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
539 final_offset = jnp.where(asym_id_same, clipped_offset, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
540 (2 * c.max_relative_idx + 1) * |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
541 jnp.ones_like(clipped_offset)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
542 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
543 rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
544 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
545 rel_feats.append(rel_pos) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
546 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
547 entity_id = batch['entity_id'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
548 entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
549 rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
550 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
551 sym_id = batch['sym_id'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
552 rel_sym_id = sym_id[:, None] - sym_id[None, :] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
553 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
554 max_rel_chain = c.max_relative_chain |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
555 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
556 clipped_rel_chain = jnp.clip( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
557 rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
558 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
559 final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
560 (2 * max_rel_chain + 1) * |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
561 jnp.ones_like(clipped_rel_chain)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
562 rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
563 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
564 rel_feats.append(rel_chain) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
565 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
566 else: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
567 rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
568 rel_feats.append(rel_pos) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
569 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
570 rel_feat = jnp.concatenate(rel_feats, axis=-1) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
571 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
572 return common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
573 c.pair_channel, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
574 name='position_activations')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
575 rel_feat) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
576 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
577 def __call__(self, batch, is_training, safe_key=None): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
578 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
579 c = self.config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
580 gc = self.global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
581 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
582 batch = dict(batch) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
583 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
584 if safe_key is None: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
585 safe_key = prng.SafeKey(hk.next_rng_key()) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
586 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
587 output = {} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
588 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
589 batch['msa_profile'] = make_msa_profile(batch) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
590 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
591 target_feat = jax.nn.one_hot(batch['aatype'], 21) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
592 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
593 preprocess_1d = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
594 c.msa_channel, name='preprocess_1d')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
595 target_feat) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
596 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
597 safe_key, sample_key, mask_key = safe_key.split(3) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
598 batch = sample_msa(sample_key, batch, c.num_msa) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
599 batch = make_masked_msa(batch, mask_key, c.masked_msa) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
600 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
601 (batch['cluster_profile'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
602 batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
603 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
604 msa_feat = create_msa_feat(batch) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
605 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
606 preprocess_msa = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
607 c.msa_channel, name='preprocess_msa')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
608 msa_feat) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
609 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
610 msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
611 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
612 left_single = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
613 c.pair_channel, name='left_single')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
614 target_feat) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
615 right_single = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
616 c.pair_channel, name='right_single')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
617 target_feat) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
618 pair_activations = left_single[:, None] + right_single[None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
619 mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
620 mask_2d = mask_2d.astype(jnp.float32) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
621 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
622 if c.recycle_pos and 'prev_pos' in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
623 prev_pseudo_beta = modules.pseudo_beta_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
624 batch['aatype'], batch['prev_pos'], None) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
625 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
626 dgram = modules.dgram_from_positions( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
627 prev_pseudo_beta, **self.config.prev_pos) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
628 pair_activations += common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
629 c.pair_channel, name='prev_pos_linear')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
630 dgram) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
631 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
632 if c.recycle_features: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
633 if 'prev_msa_first_row' in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
634 prev_msa_first_row = hk.LayerNorm( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
635 axis=[-1], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
636 create_scale=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
637 create_offset=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
638 name='prev_msa_first_row_norm')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
639 batch['prev_msa_first_row']) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
640 msa_activations = msa_activations.at[0].add(prev_msa_first_row) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
641 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
642 if 'prev_pair' in batch: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
643 pair_activations += hk.LayerNorm( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
644 axis=[-1], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
645 create_scale=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
646 create_offset=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
647 name='prev_pair_norm')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
648 batch['prev_pair']) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
649 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
650 if c.max_relative_idx: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
651 pair_activations += self._relative_encoding(batch) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
652 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
653 if c.template.enabled: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
654 template_module = TemplateEmbedding(c.template, gc) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
655 template_batch = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
656 'template_aatype': batch['template_aatype'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
657 'template_all_atom_positions': batch['template_all_atom_positions'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
658 'template_all_atom_mask': batch['template_all_atom_mask'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
659 } |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
660 # Construct a mask such that only intra-chain template features are |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
661 # computed, since all templates are for each chain individually. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
662 multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
663 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
664 template_act = template_module( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
665 query_embedding=pair_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
666 template_batch=template_batch, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
667 padding_mask_2d=mask_2d, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
668 multichain_mask_2d=multichain_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
669 is_training=is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
670 safe_key=safe_subkey) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
671 pair_activations += template_act |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
672 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
673 # Extra MSA stack. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
674 (extra_msa_feat, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
675 extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
676 extra_msa_activations = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
677 c.extra_msa_channel, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
678 name='extra_msa_activations')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
679 extra_msa_feat) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
680 extra_msa_mask = extra_msa_mask.astype(jnp.float32) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
681 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
682 extra_evoformer_input = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
683 'msa': extra_msa_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
684 'pair': pair_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
685 } |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
686 extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
687 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
688 extra_evoformer_iteration = modules.EvoformerIteration( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
689 c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
690 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
691 def extra_evoformer_fn(x): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
692 act, safe_key = x |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
693 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
694 extra_evoformer_output = extra_evoformer_iteration( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
695 activations=act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
696 masks=extra_masks, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
697 is_training=is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
698 safe_key=safe_subkey) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
699 return (extra_evoformer_output, safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
700 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
701 if gc.use_remat: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
702 extra_evoformer_fn = hk.remat(extra_evoformer_fn) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
703 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
704 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
705 extra_evoformer_stack = layer_stack.layer_stack( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
706 c.extra_msa_stack_num_block)( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
707 extra_evoformer_fn) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
708 extra_evoformer_output, safe_key = extra_evoformer_stack( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
709 (extra_evoformer_input, safe_subkey)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
710 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
711 pair_activations = extra_evoformer_output['pair'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
712 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
713 # Get the size of the MSA before potentially adding templates, so we |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
714 # can crop out the templates later. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
715 num_msa_sequences = msa_activations.shape[0] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
716 evoformer_input = { |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
717 'msa': msa_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
718 'pair': pair_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
719 } |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
720 evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
721 'pair': mask_2d} |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
722 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
723 if c.template.enabled: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
724 template_features, template_masks = ( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
725 template_embedding_1d(batch=batch, num_channel=c.msa_channel)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
726 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
727 evoformer_input['msa'] = jnp.concatenate( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
728 [evoformer_input['msa'], template_features], axis=0) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
729 evoformer_masks['msa'] = jnp.concatenate( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
730 [evoformer_masks['msa'], template_masks], axis=0) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
731 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
732 evoformer_iteration = modules.EvoformerIteration( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
733 c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
734 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
735 def evoformer_fn(x): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
736 act, safe_key = x |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
737 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
738 evoformer_output = evoformer_iteration( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
739 activations=act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
740 masks=evoformer_masks, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
741 is_training=is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
742 safe_key=safe_subkey) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
743 return (evoformer_output, safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
744 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
745 if gc.use_remat: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
746 evoformer_fn = hk.remat(evoformer_fn) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
747 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
748 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
749 evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
750 evoformer_fn) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
751 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
752 def run_evoformer(evoformer_input): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
753 evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
754 return evoformer_output |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
755 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
756 evoformer_output = run_evoformer(evoformer_input) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
757 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
758 msa_activations = evoformer_output['msa'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
759 pair_activations = evoformer_output['pair'] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
760 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
761 single_activations = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
762 c.seq_channel, name='single_activations')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
763 msa_activations[0]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
764 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
765 output.update({ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
766 'single': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
767 single_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
768 'pair': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
769 pair_activations, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
770 # Crop away template rows such that they are not used in MaskedMsaHead. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
771 'msa': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
772 msa_activations[:num_msa_sequences, :, :], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
773 'msa_first_row': |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
774 msa_activations[0], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
775 }) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
776 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
777 return output |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
778 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
779 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
780 class TemplateEmbedding(hk.Module): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
781 """Embed a set of templates.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
782 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
783 def __init__(self, config, global_config, name='template_embedding'): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
784 super().__init__(name=name) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
785 self.config = config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
786 self.global_config = global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
787 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
788 def __call__(self, query_embedding, template_batch, padding_mask_2d, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
789 multichain_mask_2d, is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
790 safe_key=None): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
791 """Generate an embedding for a set of templates. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
792 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
793 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
794 query_embedding: [num_res, num_res, num_channel] a query tensor that will |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
795 be used to attend over the templates to remove the num_templates |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
796 dimension. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
797 template_batch: A dictionary containing: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
798 `template_aatype`: [num_templates, num_res] aatype for each template. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
799 `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
800 positions for all templates. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
801 `template_all_atom_mask`: [num_templates, num_res, 37] mask for each |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
802 template. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
803 padding_mask_2d: [num_res, num_res] Pair mask for attention operations. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
804 multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
805 are intra-chain, used to mask out residue distance based features |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
806 between chains. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
807 is_training: bool indicating where we are running in training mode. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
808 safe_key: random key generator. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
809 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
810 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
811 An embedding of size [num_res, num_res, num_channels] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
812 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
813 c = self.config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
814 if safe_key is None: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
815 safe_key = prng.SafeKey(hk.next_rng_key()) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
816 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
817 num_templates = template_batch['template_aatype'].shape[0] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
818 num_res, _, query_num_channels = query_embedding.shape |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
819 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
820 # Embed each template separately. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
821 template_embedder = SingleTemplateEmbedding(self.config, self.global_config) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
822 def partial_template_embedder(template_aatype, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
823 template_all_atom_positions, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
824 template_all_atom_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
825 unsafe_key): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
826 safe_key = prng.SafeKey(unsafe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
827 return template_embedder(query_embedding, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
828 template_aatype, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
829 template_all_atom_positions, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
830 template_all_atom_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
831 padding_mask_2d, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
832 multichain_mask_2d, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
833 is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
834 safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
835 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
836 safe_key, unsafe_key = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
837 unsafe_keys = jax.random.split(unsafe_key._key, num_templates) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
838 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
839 def scan_fn(carry, x): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
840 return carry + partial_template_embedder(*x), None |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
841 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
842 scan_init = jnp.zeros((num_res, num_res, c.num_channels), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
843 dtype=query_embedding.dtype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
844 summed_template_embeddings, _ = hk.scan( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
845 scan_fn, scan_init, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
846 (template_batch['template_aatype'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
847 template_batch['template_all_atom_positions'], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
848 template_batch['template_all_atom_mask'], unsafe_keys)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
849 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
850 embedding = summed_template_embeddings / num_templates |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
851 embedding = jax.nn.relu(embedding) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
852 embedding = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
853 query_num_channels, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
854 initializer='relu', |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
855 name='output_linear')(embedding) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
856 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
857 return embedding |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
858 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
859 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
860 class SingleTemplateEmbedding(hk.Module): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
861 """Embed a single template.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
862 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
863 def __init__(self, config, global_config, name='single_template_embedding'): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
864 super().__init__(name=name) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
865 self.config = config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
866 self.global_config = global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
867 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
868 def __call__(self, query_embedding, template_aatype, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
869 template_all_atom_positions, template_all_atom_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
870 padding_mask_2d, multichain_mask_2d, is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
871 safe_key): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
872 """Build the single template embedding graph. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
873 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
874 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
875 query_embedding: (num_res, num_res, num_channels) - embedding of the |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
876 query sequence/msa. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
877 template_aatype: [num_res] aatype for each template. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
878 template_all_atom_positions: [num_res, 37, 3] atom positions for all |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
879 templates. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
880 template_all_atom_mask: [num_res, 37] mask for each template. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
881 padding_mask_2d: Padding mask (Note: this doesn't care if a template |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
882 exists, unlike the template_pseudo_beta_mask). |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
883 multichain_mask_2d: A mask indicating intra-chain residue pairs, used |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
884 to mask out between chain distances/features when templates are for |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
885 single chains. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
886 is_training: Are we in training mode. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
887 safe_key: Random key generator. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
888 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
889 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
890 A template embedding (num_res, num_res, num_channels). |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
891 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
892 gc = self.global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
893 c = self.config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
894 assert padding_mask_2d.dtype == query_embedding.dtype |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
895 dtype = query_embedding.dtype |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
896 num_channels = self.config.num_channels |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
897 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
898 def construct_input(query_embedding, template_aatype, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
899 template_all_atom_positions, template_all_atom_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
900 multichain_mask_2d): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
901 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
902 # Compute distogram feature for the template. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
903 template_positions, pseudo_beta_mask = modules.pseudo_beta_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
904 template_aatype, template_all_atom_positions, template_all_atom_mask) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
905 pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] * |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
906 pseudo_beta_mask[None, :]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
907 pseudo_beta_mask_2d *= multichain_mask_2d |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
908 template_dgram = modules.dgram_from_positions( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
909 template_positions, **self.config.dgram_features) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
910 template_dgram *= pseudo_beta_mask_2d[..., None] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
911 template_dgram = template_dgram.astype(dtype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
912 pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
913 to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
914 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
915 aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
916 to_concat.append((aatype[None, :, :], 1)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
917 to_concat.append((aatype[:, None, :], 1)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
918 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
919 # Compute a feature representing the normalized vector between each |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
920 # backbone affine - i.e. in each residues local frame, what direction are |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
921 # each of the other residues. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
922 raw_atom_pos = template_all_atom_positions |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
923 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
924 atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
925 rigid, backbone_mask = folding_multimer.make_backbone_affine( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
926 atom_pos, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
927 template_all_atom_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
928 template_aatype) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
929 points = rigid.translation |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
930 rigid_vec = rigid[:, None].inverse().apply_to_point(points) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
931 unit_vector = rigid_vec.normalized() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
932 unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
933 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
934 backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
935 backbone_mask_2d *= multichain_mask_2d |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
936 unit_vector = [x*backbone_mask_2d for x in unit_vector] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
937 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
938 # Note that the backbone_mask takes into account C, CA and N (unlike |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
939 # pseudo beta mask which just needs CB) so we add both masks as features. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
940 to_concat.extend([(x, 0) for x in unit_vector]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
941 to_concat.append((backbone_mask_2d, 0)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
942 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
943 query_embedding = hk.LayerNorm( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
944 axis=[-1], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
945 create_scale=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
946 create_offset=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
947 name='query_embedding_norm')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
948 query_embedding) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
949 # Allow the template embedder to see the query embedding. Note this |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
950 # contains the position relative feature, so this is how the network knows |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
951 # which residues are next to each other. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
952 to_concat.append((query_embedding, 1)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
953 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
954 act = 0 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
955 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
956 for i, (x, n_input_dims) in enumerate(to_concat): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
957 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
958 act += common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
959 num_channels, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
960 num_input_dims=n_input_dims, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
961 initializer='relu', |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
962 name=f'template_pair_embedding_{i}')(x) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
963 return act |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
964 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
965 act = construct_input(query_embedding, template_aatype, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
966 template_all_atom_positions, template_all_atom_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
967 multichain_mask_2d) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
968 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
969 template_iteration = TemplateEmbeddingIteration( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
970 c.template_pair_stack, gc, name='template_embedding_iteration') |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
971 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
972 def template_iteration_fn(x): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
973 act, safe_key = x |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
974 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
975 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
976 act = template_iteration( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
977 act=act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
978 pair_mask=padding_mask_2d, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
979 is_training=is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
980 safe_key=safe_subkey) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
981 return (act, safe_key) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
982 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
983 if gc.use_remat: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
984 template_iteration_fn = hk.remat(template_iteration_fn) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
985 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
986 safe_key, safe_subkey = safe_key.split() |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
987 template_stack = layer_stack.layer_stack( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
988 c.template_pair_stack.num_block)( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
989 template_iteration_fn) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
990 act, safe_key = template_stack((act, safe_subkey)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
991 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
992 act = hk.LayerNorm( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
993 axis=[-1], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
994 create_scale=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
995 create_offset=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
996 name='output_layer_norm')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
997 act) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
998 return act |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
999 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1000 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1001 class TemplateEmbeddingIteration(hk.Module): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1002 """Single Iteration of Template Embedding.""" |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1003 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1004 def __init__(self, config, global_config, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1005 name='template_embedding_iteration'): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1006 super().__init__(name=name) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1007 self.config = config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1008 self.global_config = global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1009 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1010 def __call__(self, act, pair_mask, is_training=True, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1011 safe_key=None): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1012 """Build a single iteration of the template embedder. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1013 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1014 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1015 act: [num_res, num_res, num_channel] Input pairwise activations. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1016 pair_mask: [num_res, num_res] padding mask. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1017 is_training: Whether to run in training mode. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1018 safe_key: Safe pseudo-random generator key. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1019 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1020 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1021 [num_res, num_res, num_channel] tensor of activations. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1022 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1023 c = self.config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1024 gc = self.global_config |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1025 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1026 if safe_key is None: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1027 safe_key = prng.SafeKey(hk.next_rng_key()) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1028 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1029 dropout_wrapper_fn = functools.partial( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1030 modules.dropout_wrapper, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1031 is_training=is_training, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1032 global_config=gc) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1033 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1034 safe_key, *sub_keys = safe_key.split(20) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1035 sub_keys = iter(sub_keys) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1036 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1037 act = dropout_wrapper_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1038 modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1039 name='triangle_multiplication_outgoing'), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1040 act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1041 pair_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1042 safe_key=next(sub_keys)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1043 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1044 act = dropout_wrapper_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1045 modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1046 name='triangle_multiplication_incoming'), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1047 act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1048 pair_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1049 safe_key=next(sub_keys)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1050 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1051 act = dropout_wrapper_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1052 modules.TriangleAttention(c.triangle_attention_starting_node, gc, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1053 name='triangle_attention_starting_node'), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1054 act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1055 pair_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1056 safe_key=next(sub_keys)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1057 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1058 act = dropout_wrapper_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1059 modules.TriangleAttention(c.triangle_attention_ending_node, gc, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1060 name='triangle_attention_ending_node'), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1061 act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1062 pair_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1063 safe_key=next(sub_keys)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1064 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1065 act = dropout_wrapper_fn( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1066 modules.Transition(c.pair_transition, gc, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1067 name='pair_transition'), |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1068 act, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1069 pair_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1070 safe_key=next(sub_keys)) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1071 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1072 return act |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1073 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1074 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1075 def template_embedding_1d(batch, num_channel): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1076 """Embed templates into an (num_res, num_templates, num_channels) embedding. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1077 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1078 Args: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1079 batch: A batch containing: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1080 template_aatype, (num_templates, num_res) aatype for the templates. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1081 template_all_atom_positions, (num_templates, num_residues, 37, 3) atom |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1082 positions for the templates. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1083 template_all_atom_mask, (num_templates, num_residues, 37) atom mask for |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1084 each template. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1085 num_channel: The number of channels in the output. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1086 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1087 Returns: |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1088 An embedding of shape (num_templates, num_res, num_channels) and a mask of |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1089 shape (num_templates, num_res). |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1090 """ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1091 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1092 # Embed the templates aatypes. |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1093 aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1094 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1095 num_templates = batch['template_aatype'].shape[0] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1096 all_chi_angles = [] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1097 all_chi_masks = [] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1098 for i in range(num_templates): |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1099 atom_pos = geometry.Vec3Array.from_array( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1100 batch['template_all_atom_positions'][i, :, :, :]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1101 template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1102 atom_pos, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1103 batch['template_all_atom_mask'][i, :, :], |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1104 batch['template_aatype'][i, :]) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1105 all_chi_angles.append(template_chi_angles) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1106 all_chi_masks.append(template_chi_mask) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1107 chi_angles = jnp.stack(all_chi_angles, axis=0) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1108 chi_mask = jnp.stack(all_chi_masks, axis=0) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1109 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1110 template_features = jnp.concatenate([ |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1111 aatype_one_hot, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1112 jnp.sin(chi_angles) * chi_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1113 jnp.cos(chi_angles) * chi_mask, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1114 chi_mask], axis=-1) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1115 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1116 template_mask = chi_mask[:, :, 0] |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1117 |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1118 template_activations = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1119 num_channel, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1120 initializer='relu', |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1121 name='template_single_embedding')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1122 template_features) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1123 template_activations = jax.nn.relu(template_activations) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1124 template_activations = common_modules.Linear( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1125 num_channel, |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1126 initializer='relu', |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1127 name='template_projection')( |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1128 template_activations) |
6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
galaxy-australia
parents:
diff
changeset
|
1129 return template_activations, template_mask |