comparison docker/alphafold/alphafold/data/pipeline_multimer.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
comparison
equal deleted inserted replaced
0:7ae9d78b06f5 1:6c92e000d684
1 # Copyright 2021 DeepMind Technologies Limited
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Functions for building the features for the AlphaFold multimer model."""
16
17 import collections
18 import contextlib
19 import copy
20 import dataclasses
21 import json
22 import os
23 import tempfile
24 from typing import Mapping, MutableMapping, Sequence
25
26 from absl import logging
27 from alphafold.common import protein
28 from alphafold.common import residue_constants
29 from alphafold.data import feature_processing
30 from alphafold.data import msa_pairing
31 from alphafold.data import parsers
32 from alphafold.data import pipeline
33 from alphafold.data.tools import jackhmmer
34 import numpy as np
35
36 # Internal import (7716).
37
38
@dataclasses.dataclass(frozen=True)
class _FastaChain:
  """Immutable record of one chain taken from the input FASTA file.

  Instances are built by `_make_chain_id_map` from the parallel
  `sequences`/`descriptions` lists produced by `parsers.parse_fasta`.
  """
  # Sequence string of the chain, as returned by parsers.parse_fasta.
  sequence: str
  # Description string of the chain, as returned by parsers.parse_fasta.
  description: str
43
44
def _make_chain_id_map(*,
                       sequences: Sequence[str],
                       descriptions: Sequence[str],
                       ) -> Mapping[str, _FastaChain]:
  """Assigns a PDB-format chain ID to each (sequence, description) pair.

  Args:
    sequences: Chain sequences, in input order.
    descriptions: Chain descriptions, parallel to `sequences`.

  Returns:
    A dict whose keys are taken in order from `protein.PDB_CHAIN_IDS` and whose
    values are `_FastaChain` records built from the matching sequence and
    description.

  Raises:
    ValueError: If `sequences` and `descriptions` differ in length, or if there
      are more than `protein.PDB_MAX_CHAINS` chains.
  """
  if len(sequences) != len(descriptions):
    raise ValueError('sequences and descriptions must have equal length. '
                     f'Got {len(sequences)} != {len(descriptions)}.')
  if len(sequences) > protein.PDB_MAX_CHAINS:
    raise ValueError('Cannot process more chains than the PDB format supports. '
                     f'Got {len(sequences)} chains.')
  # zip() pairs the i-th PDB chain ID with the i-th input chain.
  return {
      pdb_id: _FastaChain(sequence=seq, description=desc)
      for pdb_id, seq, desc in zip(
          protein.PDB_CHAIN_IDS, sequences, descriptions)
  }
62
63
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
  """Exposes `fasta_str` as a temporary `.fasta` file on disk.

  Args:
    fasta_str: Text to write into the temporary file.

  Yields:
    Filesystem path of the temporary file. The file is removed when the `with`
    block exits.
  """
  tmp_file = tempfile.NamedTemporaryFile('w', suffix='.fasta')
  with tmp_file:
    tmp_file.write(fasta_str)
    # seek() flushes the text buffer first, so anything that opens the path
    # by name sees the full contents.
    tmp_file.seek(0)
    yield tmp_file.name
70
71
def convert_monomer_features(
    monomer_features: pipeline.FeatureDict,
    chain_id: str) -> pipeline.FeatureDict:
  """Adapts monomer-pipeline features to the layout used by multimer models.

  Args:
    monomer_features: Feature dict produced by the monomer data pipeline.
    chain_id: Chain ID, stored as the 'auth_chain_id' feature.

  Returns:
    A new dict with:
      * 'auth_chain_id' holding `chain_id` as an object array.
      * 'sequence', 'domain_name', 'num_alignments' and 'seq_length' reduced
        to their first element along the leading axis.
      * 'aatype' converted from a one-hot last axis to int32 indices.
      * 'template_aatype' converted to int32 indices and remapped through
        `residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE`.
      * 'template_all_atom_masks' renamed to 'template_all_atom_mask'.
    All other features are copied through unchanged.
  """
  squeeze_leading_dim = {
      'sequence', 'domain_name', 'num_alignments', 'seq_length'}
  result = {'auth_chain_id': np.asarray(chain_id, dtype=np.object_)}
  for name, value in monomer_features.items():
    if name in squeeze_leading_dim:
      # asarray keeps the value an np.ndarray (0-d) with the original dtype.
      value = np.asarray(value[0], dtype=value.dtype)
    elif name == 'aatype':
      # The multimer model performs the one-hot encoding itself.
      value = np.argmax(value, axis=-1).astype(np.int32)
    elif name == 'template_aatype':
      hhblits_indices = np.argmax(value, axis=-1).astype(np.int32)
      value = np.take(residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE,
                      hhblits_indices, axis=0)
    elif name == 'template_all_atom_masks':
      name = 'template_all_atom_mask'
    result[name] = value
  return result
95
96
def int_id_to_str_id(num: int) -> str:
  """Encodes a positive integer as letters, using reverse spreadsheet naming.

  The least significant letter comes first, e.g. 1 = A, 2 = B, ..., 26 = Z,
  27 = AA, 28 = BA, 29 = CA, ... This is the usual way to encode chain IDs in
  mmCIF files.

  Args:
    num: A positive integer.

  Returns:
    The letter encoding of `num`.

  Raises:
    ValueError: If `num` is not positive.
  """
  if num <= 0:
    raise ValueError(f'Only positive integers allowed, got {num}.')

  letters = []
  remaining = num - 1  # Shift to 0-based so that 0 -> 'A'.
  while remaining >= 0:
    remaining, offset = divmod(remaining, 26)
    letters.append(chr(ord('A') + offset))
    # Each higher "digit" is also 1-based, hence the extra decrement.
    remaining -= 1
  return ''.join(letters)
117
118
def add_assembly_features(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict],
    ) -> MutableMapping[str, pipeline.FeatureDict]:
  """Adds per-residue features that tell chains apart.

  Chains are grouped into entities by identical `str(features['sequence'])`.
  Entity IDs are 1, 2, ... in order of first appearance. Each chain receives
  'asym_id' (a running chain counter starting at 1), 'sym_id' (its 1-based
  index within its entity) and 'entity_id'. Each is a float array of length
  `features['seq_length']` filled with that value.

  Note: the per-chain feature dicts are mutated in place. The returned mapping
  holds the same dict objects, re-keyed.

  Args:
    all_chain_features: Maps chain_id to that chain's feature dict.

  Returns:
    A dict keyed `<entity_letters>_<sym_id>`, where `entity_letters` is
    `int_id_to_str_id(entity_id)`. For example, two chains of a homodimer get
    keys A_1 and A_2, and two chains of a heterodimer get A_1 and B_1.
  """
  # First pass: number distinct sequences and bucket chains by entity.
  entity_id_by_seq = {}
  chains_by_entity = collections.defaultdict(list)
  for features in all_chain_features.values():
    seq_key = str(features['sequence'])
    entity_id = entity_id_by_seq.setdefault(
        seq_key, len(entity_id_by_seq) + 1)
    chains_by_entity[entity_id].append(features)

  # Second pass: assign asym/sym/entity IDs and build the new keys.
  renamed = {}
  asym_counter = 0
  for entity_id, members in chains_by_entity.items():
    for sym_id, features in enumerate(members, start=1):
      asym_counter += 1
      renamed[f'{int_id_to_str_id(entity_id)}_{sym_id}'] = features
      length = features['seq_length']
      features['asym_id'] = asym_counter * np.ones(length)
      features['sym_id'] = sym_id * np.ones(length)
      features['entity_id'] = entity_id * np.ones(length)
  return renamed
156
157
def pad_msa(np_example, min_num_seq):
  """Zero-pads MSA-row features so there are at least `min_num_seq` rows.

  Args:
    np_example: Mapping of feature name to numpy array. 'msa' is always read.
      When padding is needed, 'msa', 'deletion_matrix', 'bert_mask' and
      'msa_mask' must be 2-D, and 'cluster_bias_mask' must be 1-D.
    min_num_seq: Minimum number of rows (sequences) required.

  Returns:
    A shallow copy of `np_example` (the input mapping is not modified). The
    features listed above are zero-padded at the end of axis 0 when 'msa' has
    fewer than `min_num_seq` rows.
  """
  padded = dict(np_example)
  shortfall = min_num_seq - padded['msa'].shape[0]
  if shortfall <= 0:
    return padded
  for key in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
    padded[key] = np.pad(padded[key], ((0, shortfall), (0, 0)))
  padded['cluster_bias_mask'] = np.pad(
      padded['cluster_bias_mask'], ((0, shortfall),))
  return padded
168
169
class DataPipeline:
  """Runs the alignment tools and assembles the input features."""

  def __init__(self,
               monomer_data_pipeline: pipeline.DataPipeline,
               jackhmmer_binary_path: str,
               uniprot_database_path: str,
               max_uniprot_hits: int = 50000,
               use_precomputed_msas: bool = False):
    """Initializes the data pipeline.

    Args:
      monomer_data_pipeline: An instance of pipeline.DataPipeline that runs
        the data pipeline for the monomer AlphaFold system.
      jackhmmer_binary_path: Location of the jackhmmer binary.
      uniprot_database_path: Location of the unclustered UniProt sequences.
        They are searched with jackhmmer and used for MSA pairing.
      max_uniprot_hits: Maximum number of sequences kept from the UniProt MSA.
        The parsed Stockholm MSA is truncated to this many.
      use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
    """
    self._monomer_data_pipeline = monomer_data_pipeline
    self._uniprot_msa_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=uniprot_database_path)
    self._max_uniprot_hits = max_uniprot_hits
    self.use_precomputed_msas = use_precomputed_msas

  def _process_single_chain(
      self,
      chain_id: str,
      sequence: str,
      description: str,
      msa_output_dir: str,
      is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
    """Runs the monomer pipeline on a single chain.

    Args:
      chain_id: Chain ID. Used in the temporary FASTA header and as the name
        of the per-chain MSA subdirectory.
      sequence: The chain's sequence.
      description: The chain's description (only logged here).
      msa_output_dir: Parent directory for the per-chain MSA subdirectory.
      is_homomer_or_monomer: When False, UniProt `*_all_seq` pairing features
        are computed and merged into the result.

    Returns:
      The monomer pipeline's feature dict, plus the pairing features when
      `is_homomer_or_monomer` is False.
    """
    chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
    chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
    # exist_ok=True replaces a racy os.path.exists() check followed by
    # makedirs(), which can fail if the directory appears in between.
    os.makedirs(chain_msa_output_dir, exist_ok=True)
    with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
      logging.info('Running monomer pipeline on chain %s: %s',
                   chain_id, description)
      chain_features = self._monomer_data_pipeline.process(
          input_fasta_path=chain_fasta_path,
          msa_output_dir=chain_msa_output_dir)

      # We only construct the pairing features if there are 2 or more unique
      # sequences. This must run while the temporary FASTA file still exists.
      if not is_homomer_or_monomer:
        all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
                                                          chain_msa_output_dir)
        chain_features.update(all_seq_msa_features)
    return chain_features

  def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
    """Gets MSA features from unclustered UniProt, for pairing.

    Args:
      input_fasta_path: FASTA file to search with jackhmmer.
      msa_output_dir: Directory where 'uniprot_hits.sto' is written (or, with
        `use_precomputed_msas`, possibly read from).

    Returns:
      The MSA features whose names are in `msa_pairing.MSA_FEATURES`, plus the
      UniProt accession and species identifiers, each renamed with an
      '_all_seq' suffix.
    """
    out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
    result = pipeline.run_msa_tool(
        self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
        self.use_precomputed_msas)
    msa = parsers.parse_stockholm(result['sto'])
    msa = msa.truncate(max_seqs=self._max_uniprot_hits)
    all_seq_features = pipeline.make_msa_features([msa])
    valid_feats = msa_pairing.MSA_FEATURES + (
        'msa_uniprot_accession_identifiers',
        'msa_species_identifiers',
    )
    feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
             if k in valid_feats}
    return feats

  def process(self,
              input_fasta_path: str,
              msa_output_dir: str,
              is_prokaryote: bool = False) -> pipeline.FeatureDict:
    """Runs alignment tools on the input sequences and creates features.

    Args:
      input_fasta_path: Path to a FASTA file with one record per chain.
      msa_output_dir: Directory that receives 'chain_id_map.json' and one MSA
        subdirectory per processed chain.
      is_prokaryote: Forwarded to `feature_processing.pair_and_merge`.

    Returns:
      The merged multimer feature dict, with MSA-row features zero-padded to
      at least 512 rows.
    """
    with open(input_fasta_path) as f:
      input_fasta_str = f.read()
    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

    chain_id_map = _make_chain_id_map(sequences=input_seqs,
                                      descriptions=input_descs)
    chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
    with open(chain_id_map_path, 'w') as f:
      chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
                           for chain_id, fasta_chain in chain_id_map.items()}
      json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)

    all_chain_features = {}
    sequence_features = {}
    # Pairing features are only needed when there are >= 2 unique sequences.
    is_homomer_or_monomer = len(set(input_seqs)) == 1
    for chain_id, fasta_chain in chain_id_map.items():
      # A repeated sequence reuses a deep copy of the features computed for
      # its first occurrence, so the monomer pipeline runs once per unique
      # sequence.
      if fasta_chain.sequence in sequence_features:
        all_chain_features[chain_id] = copy.deepcopy(
            sequence_features[fasta_chain.sequence])
        continue
      chain_features = self._process_single_chain(
          chain_id=chain_id,
          sequence=fasta_chain.sequence,
          description=fasta_chain.description,
          msa_output_dir=msa_output_dir,
          is_homomer_or_monomer=is_homomer_or_monomer)

      chain_features = convert_monomer_features(chain_features,
                                                chain_id=chain_id)
      all_chain_features[chain_id] = chain_features
      sequence_features[fasta_chain.sequence] = chain_features

    all_chain_features = add_assembly_features(all_chain_features)

    np_example = feature_processing.pair_and_merge(
        all_chain_features=all_chain_features,
        is_prokaryote=is_prokaryote,
    )

    # Pad MSA to avoid zero-sized extra_msa.
    np_example = pad_msa(np_example, 512)

    return np_example
279
280 np_example = feature_processing.pair_and_merge(
281 all_chain_features=all_chain_features,
282 is_prokaryote=is_prokaryote,
283 )
284
285 # Pad MSA to avoid zero-sized extra_msa.
286 np_example = pad_msa(np_example, 512)
287
288 return np_example