Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/alphafold/data/pipeline_multimer.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Functions for building the features for the AlphaFold multimer model.""" | |
16 | |
17 import collections | |
18 import contextlib | |
19 import copy | |
20 import dataclasses | |
21 import json | |
22 import os | |
23 import tempfile | |
24 from typing import Mapping, MutableMapping, Sequence | |
25 | |
26 from absl import logging | |
27 from alphafold.common import protein | |
28 from alphafold.common import residue_constants | |
29 from alphafold.data import feature_processing | |
30 from alphafold.data import msa_pairing | |
31 from alphafold.data import parsers | |
32 from alphafold.data import pipeline | |
33 from alphafold.data.tools import jackhmmer | |
34 import numpy as np | |
35 | |
36 # Internal import (7716). | |
37 | |
38 | |
39 @dataclasses.dataclass(frozen=True) | |
40 class _FastaChain: | |
41 sequence: str | |
42 description: str | |
43 | |
44 | |
def _make_chain_id_map(*,
                       sequences: Sequence[str],
                       descriptions: Sequence[str],
                       ) -> Mapping[str, _FastaChain]:
  """Makes a mapping from PDB-format chain ID to sequence and description.

  Chain IDs are taken in order from ``protein.PDB_CHAIN_IDS``, so the i-th
  sequence/description pair is assigned the i-th chain ID.

  Args:
    sequences: Chain sequences, in input order.
    descriptions: Descriptions matching ``sequences`` one-to-one.

  Returns:
    A dict from chain ID to ``_FastaChain``, in input order.

  Raises:
    ValueError: If ``sequences`` and ``descriptions`` differ in length, or if
      there are more sequences than ``protein.PDB_MAX_CHAINS``.
  """
  num_sequences = len(sequences)
  num_descriptions = len(descriptions)
  if num_sequences != num_descriptions:
    raise ValueError('sequences and descriptions must have equal length. '
                     f'Got {num_sequences} != {num_descriptions}.')
  if num_sequences > protein.PDB_MAX_CHAINS:
    raise ValueError('Cannot process more chains than the PDB format supports. '
                     f'Got {num_sequences} chains.')
  return {
      chain_id: _FastaChain(sequence=sequence, description=description)
      for chain_id, sequence, description in zip(
          protein.PDB_CHAIN_IDS, sequences, descriptions)
  }
62 | |
63 | |
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
  """Exposes ``fasta_str`` as a temporary ``.fasta`` file on disk.

  Args:
    fasta_str: Text to write into the temporary file.

  Yields:
    The filesystem path of a named temporary file that contains ``fasta_str``.
    The file is removed when the ``with`` block exits (``NamedTemporaryFile``
    default ``delete=True``).
  """
  with tempfile.NamedTemporaryFile('w', suffix='.fasta') as handle:
    handle.write(fasta_str)
    # Seeking also flushes the pending write buffer, so readers that reopen
    # the file by path see the full contents.
    handle.seek(0)
    yield handle.name
70 | |
71 | |
def convert_monomer_features(
    monomer_features: pipeline.FeatureDict,
    chain_id: str) -> pipeline.FeatureDict:
  """Adapts one chain's monomer-pipeline features to the multimer layout.

  The returned dict starts with ``auth_chain_id`` (``chain_id`` as a 0-d object
  array). Every input feature is then copied across, with these changes:

  * ``sequence``, ``domain_name``, ``num_alignments``, ``seq_length``: only
    element 0 is kept, with the original dtype.
  * ``aatype``: argmax over the last axis, as int32. The multimer model does
    the one-hot encoding itself.
  * ``template_aatype``: argmax over the last axis, then remapped through
    ``residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``.
  * ``template_all_atom_masks``: renamed to ``template_all_atom_mask``.

  Args:
    monomer_features: Feature dict produced by the monomer data pipeline.
    chain_id: Chain identifier stored under ``auth_chain_id``.

  Returns:
    A new feature dict; ``monomer_features`` itself is not modified.
  """
  squeeze_first = frozenset(
      ('sequence', 'domain_name', 'num_alignments', 'seq_length'))
  converted = {'auth_chain_id': np.asarray(chain_id, dtype=np.object_)}
  for name, value in monomer_features.items():
    if name in squeeze_first:
      # asarray guarantees an np.ndarray even when indexing yields a scalar.
      value = np.asarray(value[0], dtype=value.dtype)
    elif name == 'aatype':
      value = np.argmax(value, axis=-1).astype(np.int32)
    elif name == 'template_aatype':
      hhblits_index = np.argmax(value, axis=-1).astype(np.int32)
      value = np.take(residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE,
                      hhblits_index, axis=0)
    elif name == 'template_all_atom_masks':
      name = 'template_all_atom_mask'
    converted[name] = value
  return converted
95 | |
96 | |
def int_id_to_str_id(num: int) -> str:
  """Encodes a positive integer as a string using reverse spreadsheet naming.

  The least significant letter comes first: 1 = A, 2 = B, ..., 26 = Z,
  27 = AA, 28 = BA, 29 = CA, ... This is the usual way to encode chain IDs in
  mmCIF files.

  Args:
    num: A positive integer.

  Returns:
    The encoded string of uppercase letters.

  Raises:
    ValueError: If ``num`` is zero or negative.
  """
  if num <= 0:
    raise ValueError(f'Only positive integers allowed, got {num}.')

  letters = []
  remaining = num
  while remaining > 0:
    # Bijective base-26: shift to 0-based before splitting off each digit.
    remaining, digit = divmod(remaining - 1, 26)
    letters.append(chr(ord('A') + digit))
  return ''.join(letters)
117 | |
118 | |
def add_assembly_features(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict],
    ) -> MutableMapping[str, pipeline.FeatureDict]:
  """Add features to distinguish between chains.

  Chains are grouped by identical ``sequence``. Each distinct sequence gets an
  entity ID (1, 2, ... in first-seen order), and each chain within a group gets
  a symmetry ID (1, 2, ...). Chains are then visited group by group, not in
  input order, and each is given a consecutive 1-based asym ID.

  The arrays ``asym_id``, ``sym_id`` and ``entity_id`` are built with
  ``np.ones``, so they are float arrays of length ``seq_length``. They are
  written into each chain's feature dict in place, so the input dicts are
  mutated.

  Args:
    all_chain_features: A dictionary which maps chain_id to a dictionary of
      features for each chain.

  Returns:
    all_chain_features: A dictionary which maps strings of the form
      `<seq_id>_<sym_id>` to the corresponding chain features. Here `<seq_id>`
      is the entity ID encoded with `int_id_to_str_id`. E.g. two chains from a
      homodimer would have keys A_1 and A_2. Two chains from a heterodimer
      would have keys A_1 and B_1.
  """
  # Group the chains by sequence. Only the feature dicts matter here; the
  # original chain-ID keys are not used.
  seq_to_entity_id = {}
  grouped_chains = collections.defaultdict(list)
  for chain_features in all_chain_features.values():
    seq = str(chain_features['sequence'])
    if seq not in seq_to_entity_id:
      seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
    grouped_chains[seq_to_entity_id[seq]].append(chain_features)

  new_all_chain_features = {}
  # Running asym ID, unique across all chains. It is kept distinct from any
  # string chain ID to avoid confusion.
  asym_id = 1
  for entity_id, group_chain_features in grouped_chains.items():
    for sym_id, chain_features in enumerate(group_chain_features, start=1):
      new_all_chain_features[
          f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
      seq_length = chain_features['seq_length']
      chain_features['asym_id'] = asym_id * np.ones(seq_length)
      chain_features['sym_id'] = sym_id * np.ones(seq_length)
      chain_features['entity_id'] = entity_id * np.ones(seq_length)
      asym_id += 1

  return new_all_chain_features
156 | |
157 | |
def pad_msa(np_example, min_num_seq):
  """Zero-pads the MSA-row features up to at least ``min_num_seq`` rows.

  Args:
    np_example: Feature mapping. It must contain ``msa``. When padding is
      needed it must also contain ``deletion_matrix``, ``bert_mask`` and
      ``msa_mask`` (padded as 2-D arrays) and ``cluster_bias_mask``
      (presumably 1-D; padded along its first axis).
    min_num_seq: Minimum number of rows (first axis of ``msa``) to reach.

  Returns:
    A shallow copy of ``np_example``. If ``msa`` had fewer than
    ``min_num_seq`` rows, the five features above are replaced by copies
    padded with zeros at the end of the first axis.
  """
  padded = dict(np_example)
  missing_rows = min_num_seq - padded['msa'].shape[0]
  if missing_rows <= 0:
    return padded
  for key in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
    padded[key] = np.pad(padded[key], ((0, missing_rows), (0, 0)))
  padded['cluster_bias_mask'] = np.pad(
      padded['cluster_bias_mask'], ((0, missing_rows),))
  return padded
168 | |
169 | |
class DataPipeline:
  """Runs the alignment tools and assembles the input features."""

  def __init__(self,
               monomer_data_pipeline: pipeline.DataPipeline,
               jackhmmer_binary_path: str,
               uniprot_database_path: str,
               max_uniprot_hits: int = 50000,
               use_precomputed_msas: bool = False):
    """Initializes the data pipeline.

    Args:
      monomer_data_pipeline: An instance of pipeline.DataPipeline that runs
        the data pipeline for the monomer AlphaFold system.
      jackhmmer_binary_path: Location of the jackhmmer binary.
      uniprot_database_path: Location of the unclustered uniprot sequences, that
        will be searched with jackhmmer and used for MSA pairing.
      max_uniprot_hits: The maximum number of hits to return from uniprot.
      use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
    """
    self._monomer_data_pipeline = monomer_data_pipeline
    self._uniprot_msa_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=uniprot_database_path)
    self._max_uniprot_hits = max_uniprot_hits
    self.use_precomputed_msas = use_precomputed_msas

  def _process_single_chain(
      self,
      chain_id: str,
      sequence: str,
      description: str,
      msa_output_dir: str,
      is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
    """Runs the monomer pipeline on a single chain.

    Args:
      chain_id: Chain identifier. It is used in the temporary FASTA header and
        as the name of the per-chain MSA sub-directory.
      sequence: The chain's sequence.
      description: Description of the chain; used only for logging.
      msa_output_dir: Parent directory for the per-chain MSA directory.
      is_homomer_or_monomer: When False, UniProt ``*_all_seq`` pairing features
        are computed and merged into the result.

    Returns:
      The monomer pipeline's feature dict for this chain, plus the pairing
      features when ``is_homomer_or_monomer`` is False.
    """
    chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
    chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
    # exist_ok replaces the racy "os.path.exists then os.makedirs" pattern.
    os.makedirs(chain_msa_output_dir, exist_ok=True)
    with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
      logging.info('Running monomer pipeline on chain %s: %s',
                   chain_id, description)
      chain_features = self._monomer_data_pipeline.process(
          input_fasta_path=chain_fasta_path,
          msa_output_dir=chain_msa_output_dir)

      # We only construct the pairing features if there are 2 or more unique
      # sequences. This must happen inside the `with` block, because the
      # temporary FASTA file is deleted when the block exits.
      if not is_homomer_or_monomer:
        all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
                                                          chain_msa_output_dir)
        chain_features.update(all_seq_msa_features)
    return chain_features

  def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
    """Get MSA features for unclustered uniprot, for pairing.

    The UniProt search output goes to ``uniprot_hits.sto`` in
    ``msa_output_dir``. ``use_precomputed_msas`` is passed through to
    ``pipeline.run_msa_tool``, which presumably reuses an existing file
    (confirm there). The parsed MSA is truncated to ``max_uniprot_hits``
    sequences.

    Args:
      input_fasta_path: FASTA file holding the query chain.
      msa_output_dir: Directory that receives ``uniprot_hits.sto``.

    Returns:
      A dict of the MSA features named in ``msa_pairing.MSA_FEATURES``, plus
      the UniProt accession and species identifiers, each key suffixed with
      ``_all_seq``.
    """
    out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
    result = pipeline.run_msa_tool(
        self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
        self.use_precomputed_msas)
    msa = parsers.parse_stockholm(result['sto'])
    msa = msa.truncate(max_seqs=self._max_uniprot_hits)
    all_seq_features = pipeline.make_msa_features([msa])
    valid_feats = msa_pairing.MSA_FEATURES + (
        'msa_uniprot_accession_identifiers',
        'msa_species_identifiers',
    )
    feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
             if k in valid_feats}
    return feats

  def process(self,
              input_fasta_path: str,
              msa_output_dir: str,
              is_prokaryote: bool = False) -> pipeline.FeatureDict:
    """Runs alignment tools on the input sequences and creates features.

    Each FASTA record becomes one chain and receives a PDB chain ID in order.
    A ``chain_id_map.json`` file, mapping each chain ID to its sequence and
    description, is written to ``msa_output_dir``. The monomer pipeline runs
    once per distinct sequence. Chains that repeat a sequence get a deep copy
    of the first such chain's features, with their own ``auth_chain_id``.

    Args:
      input_fasta_path: Path to a FASTA file with one record per chain.
      msa_output_dir: Directory for ``chain_id_map.json`` and the per-chain MSA
        sub-directories. It is created if it does not already exist.
      is_prokaryote: Forwarded to ``feature_processing.pair_and_merge``.

    Returns:
      The merged multimer feature dict, padded by ``pad_msa`` to at least 512
      MSA rows.
    """
    with open(input_fasta_path) as f:
      input_fasta_str = f.read()
    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

    chain_id_map = _make_chain_id_map(sequences=input_seqs,
                                      descriptions=input_descs)
    # Create the output directory up front, so that writing chain_id_map.json
    # does not fail when the caller has not created it yet.
    os.makedirs(msa_output_dir, exist_ok=True)
    chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
    with open(chain_id_map_path, 'w') as f:
      chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
                           for chain_id, fasta_chain in chain_id_map.items()}
      json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)

    all_chain_features = {}
    sequence_features = {}
    # Pairing features are only needed when there are >= 2 distinct sequences.
    is_homomer_or_monomer = len(set(input_seqs)) == 1
    for chain_id, fasta_chain in chain_id_map.items():
      if fasta_chain.sequence in sequence_features:
        # Reuse the features already computed for this sequence. The deep copy
        # stops later in-place edits (e.g. add_assembly_features) from being
        # shared between chains. 'auth_chain_id' is reset because the copied
        # value still names the chain the features were computed for.
        chain_features = copy.deepcopy(sequence_features[fasta_chain.sequence])
        chain_features['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
        all_chain_features[chain_id] = chain_features
        continue
      chain_features = self._process_single_chain(
          chain_id=chain_id,
          sequence=fasta_chain.sequence,
          description=fasta_chain.description,
          msa_output_dir=msa_output_dir,
          is_homomer_or_monomer=is_homomer_or_monomer)

      chain_features = convert_monomer_features(chain_features,
                                                chain_id=chain_id)
      all_chain_features[chain_id] = chain_features
      sequence_features[fasta_chain.sequence] = chain_features

    all_chain_features = add_assembly_features(all_chain_features)

    np_example = feature_processing.pair_and_merge(
        all_chain_features=all_chain_features,
        is_prokaryote=is_prokaryote,
    )

    # Pad MSA to avoid zero-sized extra_msa.
    np_example = pad_msa(np_example, 512)

    return np_example