Mercurial > repos > galaxy-australia > alphafold2

comparison: docker/alphafold/alphafold/data/msa_pairing.py @ 1:6c92e000d684 (draft)
comparing changesets 0:7ae9d78b06f5 and 1:6c92e000d684
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"

author:   galaxy-australia
date:     Tue, 01 Mar 2022 02:53:05 +0000
parents:  (none)
children: (none)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pairing logic for multimer data pipeline."""

import collections
import functools
import re
import string
from typing import Any, Dict, Iterable, List, Sequence

from alphafold.common import residue_constants
from alphafold.data import pipeline
import numpy as np
import pandas as pd
import scipy.linalg

ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
ALPHANUM_ACCESSION_ID_MAP = {
    chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
}  # A-Z,0-9
NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)}  # 0-9

MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
                  'msa_mask_all_seq': 1,
                  'deletion_matrix_all_seq': 0,
                  'deletion_matrix_int_all_seq': 0,
                  'msa': MSA_GAP_IDX,
                  'msa_mask': 1,
                  'deletion_matrix': 0,
                  'deletion_matrix_int': 0}

MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
                'all_atom_mask', 'seq_mask', 'between_segment_residues',
                'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
                'sym_id', 'entity_mask', 'deletion_mean',
                'prediction_atom_mask',
                'literature_positions', 'atom_indices_to_group_indices',
                'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
                     'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')


domain_name_pattern = re.compile(
    r'''^(?P<pdb>[a-z\d]{4})
    \{(?P<bioassembly>[\d+(\+\d+)?])\}
    (?P<chain>[a-zA-Z\d]+)
    \{(?P<transform_index>\d+)\}$
    ''', re.VERBOSE)

def create_paired_features(
    chains: Iterable[pipeline.FeatureDict],
    prokaryotic: bool,
    ) -> List[pipeline.FeatureDict]:
  """Returns the original chains with paired NUM_SEQ features.

  Args:
    chains: A list of feature dictionaries for each chain.
    prokaryotic: Whether the target complex is from a prokaryotic organism.
      Used to determine the distance metric for pairing.

  Returns:
    A list of feature dictionaries with sequence features including only
    rows to be paired.
  """
  chains = list(chains)
  chain_keys = chains[0].keys()

  if len(chains) < 2:
    return chains
  else:
    updated_chains = []
    paired_chains_to_paired_row_indices = pair_sequences(
        chains, prokaryotic)
    paired_rows = reorder_paired_rows(
        paired_chains_to_paired_row_indices)

    for chain_num, chain in enumerate(chains):
      new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
      for feature_name in chain_keys:
        if feature_name.endswith('_all_seq'):
          feats_padded = pad_features(chain[feature_name], feature_name)
          new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
      new_chain['num_alignments_all_seq'] = np.asarray(
          len(paired_rows[:, chain_num]))
      updated_chains.append(new_chain)
    return updated_chains

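# --- Editor's sketch (not part of the upstream file): how this entry point is
# typically called. `chain_a` and `chain_b` stand for per-chain FeatureDicts
# produced by the multimer data pipeline, each carrying '*_all_seq' MSA
# features; the helper name is hypothetical.
def _example_create_paired_features(chain_a, chain_b):
  paired_chains = create_paired_features([chain_a, chain_b],
                                         prokaryotic=False)
  # Every '*_all_seq' feature now holds one row per paired alignment, in the
  # same row order across both chains (with a gap-padding row standing in
  # where a chain had no partner sequence).
  return paired_chains
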
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
  """Add a 'padding' row at the end of the features list.

  The padding row will be selected as a 'paired' row in the case of partial
  alignment - for the chain that doesn't have paired alignment.

  Args:
    feature: The feature to be padded.
    feature_name: The name of the feature to be padded.

  Returns:
    The feature with an additional padding row.
  """
  assert feature.dtype != np.dtype(np.string_)
  if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
                      'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
    num_res = feature.shape[1]
    padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                     feature.dtype)
  elif feature_name in ('msa_uniprot_accession_identifiers_all_seq',
                        'msa_species_identifiers_all_seq'):
    padding = [b'']
  else:
    return feature
  feats_padded = np.concatenate([feature, padding], axis=0)
  return feats_padded

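# --- Editor's sketch (not part of the upstream file): a minimal check of the
# padding behaviour; the array contents are made up for illustration.
def _example_pad_features():
  msa = np.zeros([5, 10], dtype=np.int32)  # 5 alignment rows, 10 residues
  padded = pad_features(msa, 'msa_all_seq')
  # One gap-token row is appended; index -1 selects it for chains that have
  # no partner sequence in a given pairing.
  assert padded.shape == (6, 10)
  assert (padded[-1] == MSA_GAP_IDX).all()
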
def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
  """Makes dataframe with msa features needed for msa pairing."""
  chain_msa = chain_features['msa_all_seq']
  query_seq = chain_msa[0]
  per_seq_similarity = np.sum(
      query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
  per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
  msa_df = pd.DataFrame({
      'msa_species_identifiers':
          chain_features['msa_species_identifiers_all_seq'],
      'msa_uniprot_accession_identifiers':
          chain_features['msa_uniprot_accession_identifiers_all_seq'],
      'msa_row':
          np.arange(len(
              chain_features['msa_uniprot_accession_identifiers_all_seq'])),
      'msa_similarity': per_seq_similarity,
      'gap': per_seq_gap
  })
  return msa_df


def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
  """Creates mapping from species to msa dataframe of that species."""
  species_lookup = {}
  for species, species_df in msa_df.groupby('msa_species_identifiers'):
    species_lookup[species] = species_df
  return species_lookup

@functools.lru_cache(maxsize=65536)
def encode_accession(accession_id: str) -> int:
  """Map accession codes to the serial order in which they were assigned."""
  alpha = ALPHA_ACCESSION_ID_MAP  # A-Z
  alphanum = ALPHANUM_ACCESSION_ID_MAP  # A-Z,0-9
  num = NUM_ACCESSION_ID_MAP  # 0-9

  coding = 0

  # This is based on the uniprot accession id format
  # https://www.uniprot.org/help/accession_numbers
  if accession_id[0] in {'O', 'P', 'Q'}:
    bases = (alpha, num, alphanum, alphanum, alphanum, num)
  elif len(accession_id) == 6:
    bases = (alpha, num, alpha, alphanum, alphanum, num)
  elif len(accession_id) == 10:
    bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum,
             alphanum, num)

  product = 1
  for place, base in zip(reversed(accession_id), reversed(bases)):
    coding += base[place] * product
    product *= len(base)

  return coding

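# --- Editor's sketch (not part of the upstream file): the encoding preserves
# assignment order, so nearby accessions get nearby codes. The accession ids
# below are hypothetical.
def _example_encode_accession():
  a = encode_accession('Q12345')
  b = encode_accession('Q12346')
  # Consecutive 6-character accessions differ by exactly 1 in the encoding,
  # which is what _calc_id_diff relies on below.
  assert b - a == 1
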
def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
  return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode()))


def _find_all_accession_matches(accession_id_lists: List[List[bytes]],
                                diff_cutoff: int = 20
                                ) -> List[List[Any]]:
  """Finds accession id matches across the chains based on their difference."""
  all_accession_tuples = []
  current_tuple = []
  tokens_used_in_answer = set()

  def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
    return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))

  def _all_tokens_not_used_before() -> bool:
    return all((s not in tokens_used_in_answer for s in current_tuple))

  def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
    if level == len(accession_id_lists) - 1:
      if _all_tokens_not_used_before():
        all_accession_tuples.append(list(current_tuple))
        for s in current_tuple:
          tokens_used_in_answer.add(s)
      return

    if level == -1:
      new_list = accession_id_lists[level+1]
    else:
      new_list = [(_calc_id_diff(accession_id, s), s) for
                  s in accession_id_lists[level+1]]
      new_list = sorted(new_list)
      new_list = [s for d, s in new_list]

    for s in new_list:
      if (_matches_all_in_current_tuple(s, diff_cutoff) and
          s not in tokens_used_in_answer):
        current_tuple.append(s)
        dfs(level + 1, s)
        current_tuple.pop()
  dfs(-1, '')
  return all_accession_tuples

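# --- Editor's sketch (not part of the upstream file): with the default
# diff_cutoff of 20, only the two nearby (hypothetical) accessions are
# grouped into a cross-chain match; the distant one is left unmatched.
def _example_find_accession_matches():
  chain_a_ids = [b'Q12345', b'P99999']
  chain_b_ids = [b'Q12346']
  matches = _find_all_accession_matches([chain_a_ids, chain_b_ids])
  assert matches == [[b'Q12345', b'Q12346']]
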
def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
  matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
  return matched_df.iloc[0]


def _match_rows_by_genetic_distance(
    this_species_msa_dfs: List[pd.DataFrame],
    cutoff: int = 20) -> List[List[int]]:
  """Finds MSA sequence pairings across chains within a genetic distance cutoff.

  The genetic distance between two sequences is approximated by taking the
  difference in their UniProt accession ids.

  Args:
    this_species_msa_dfs: a list of dataframes containing MSA features for
      sequences for a specific species. If species is missing for a chain, the
      dataframe is set to None.
    cutoff: the genetic distance cutoff.

  Returns:
    A list of lists, each containing M indices corresponding to paired MSA
    rows, where M is the number of chains.
  """
  num_examples = len(this_species_msa_dfs)  # N

  accession_id_lists = []  # M
  match_index_to_chain_index = {}
  for chain_index, species_df in enumerate(this_species_msa_dfs):
    if species_df is not None:
      accession_id_lists.append(
          list(species_df.msa_uniprot_accession_identifiers.values))
      # Keep track of which of the this_species_msa_dfs are not None.
      match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index

  all_accession_id_matches = _find_all_accession_matches(
      accession_id_lists, cutoff)  # [k, M]

  all_paired_msa_rows = []  # [k, N]
  for accession_id_match in all_accession_id_matches:
    paired_msa_rows = []
    for match_index, accession_id in enumerate(accession_id_match):
      # Map back to chain index.
      chain_index = match_index_to_chain_index[match_index]
      seq_series = _accession_row(
          this_species_msa_dfs[chain_index], accession_id)

      if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
          seq_series.gap > SEQUENCE_GAP_CUTOFF):
        continue
      else:
        paired_msa_rows.append(seq_series.msa_row)
    # If a sequence is skipped based on sequence similarity to the respective
    # target sequence or a gap cutoff, the lengths of accession_id_match and
    # paired_msa_rows will be different. Skip this match.
    if len(paired_msa_rows) == len(accession_id_match):
      paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
      matched_chain_indices = list(match_index_to_chain_index.values())
      paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
      all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
  return all_paired_msa_rows

def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
                                       ) -> List[List[int]]:
  """Finds MSA sequence pairings across chains based on sequence similarity.

  Each chain's MSA sequences are first sorted by their sequence similarity to
  their respective target sequence. The sequences are then paired, starting
  from the sequences most similar to their target sequence.

  Args:
    this_species_msa_dfs: a list of dataframes containing MSA features for
      sequences for a specific species.

  Returns:
    A list of lists, each containing M indices corresponding to paired MSA
    rows, where M is the number of chains.
  """
  all_paired_msa_rows = []

  num_seqs = [len(species_df) for species_df in this_species_msa_dfs
              if species_df is not None]
  take_num_seqs = np.min(num_seqs)

  sort_by_similarity = (
      lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))

  for species_df in this_species_msa_dfs:
    if species_df is not None:
      species_df_sorted = sort_by_similarity(species_df)
      msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
    else:
      msa_rows = [-1] * take_num_seqs  # take the last 'padding' row
    all_paired_msa_rows.append(msa_rows)
  all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
  return all_paired_msa_rows

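# --- Editor's sketch (not part of the upstream file): two hypothetical
# single-species dataframes containing only the columns this function reads.
def _example_match_rows_by_sequence_similarity():
  df_a = pd.DataFrame({'msa_row': [2, 7], 'msa_similarity': [0.9, 0.5]})
  df_b = pd.DataFrame({'msa_row': [4], 'msa_similarity': [0.8]})
  paired = _match_rows_by_sequence_similarity([df_a, df_b])
  # Each chain's most similar row is paired first; chain A's second row is
  # dropped because chain B has only one sequence for this species.
  assert [list(rows) for rows in paired] == [[2, 4]]
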
def pair_sequences(examples: List[pipeline.FeatureDict],
                   prokaryotic: bool) -> Dict[int, np.ndarray]:
  """Returns indices for paired MSA sequences across chains."""

  num_examples = len(examples)

  all_chain_species_dict = []
  common_species = set()
  for chain_features in examples:
    msa_df = _make_msa_df(chain_features)
    species_dict = _create_species_dict(msa_df)
    all_chain_species_dict.append(species_dict)
    common_species.update(set(species_dict))

  common_species = sorted(common_species)
  common_species.remove(b'')  # Remove target sequence species.

  all_paired_msa_rows = [np.zeros(len(examples), int)]
  all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
  all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

  for species in common_species:
    if not species:
      continue
    this_species_msa_dfs = []
    species_dfs_present = 0
    for species_dict in all_chain_species_dict:
      if species in species_dict:
        this_species_msa_dfs.append(species_dict[species])
        species_dfs_present += 1
      else:
        this_species_msa_dfs.append(None)

    # Skip species that are present in only one chain.
    if species_dfs_present <= 1:
      continue

    if np.any(
        np.array([len(species_df) for species_df in
                  this_species_msa_dfs if
                  isinstance(species_df, pd.DataFrame)]) > 600):
      continue

    # In prokaryotes (and some eukaryotes), interacting genes are often
    # co-located on the chromosome into operons. Because of that we can assume
    # that if two proteins' intergenic distance is less than a threshold, the
    # two proteins will form an interacting pair.
    # In most eukaryotes, a single protein's MSA can contain many paralogs.
    # Two genes may interact even if they are not close by genomic distance.
    # In the case of eukaryotes, some methods pair MSA sequences using a
    # sequence similarity method.
    # See Jinbo Xu's work:
    # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
    if prokaryotic:
      paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs)

      if not paired_msa_rows:
        continue
    else:
      paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
    all_paired_msa_rows.extend(paired_msa_rows)
    all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
  all_paired_msa_rows_dict = {
      num_examples: np.array(paired_msa_rows) for
      num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
  }
  return all_paired_msa_rows_dict

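# --- Editor's sketch (not part of the upstream file): how the two matching
# strategies above are typically driven. `chains` stands for per-chain
# FeatureDicts carrying '*_all_seq' features; the helper name is hypothetical.
def _example_pair_sequences(chains):
  rows_by_count = pair_sequences(chains, prokaryotic=False)
  # Keys are the number of chains covered by each pairing; the key
  # len(chains) also holds the all-zero row that pairs the query sequences.
  return reorder_paired_rows(rows_by_count)
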
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
                        ) -> np.ndarray:
  """Creates a list of indices of paired MSA rows across chains.

  Args:
    all_paired_msa_rows_dict: a mapping from the number of paired chains to the
      paired indices.

  Returns:
    a list of lists, each containing indices of paired MSA rows across chains.
    The paired-index lists are ordered by:
      1) the number of chains in the paired alignment, i.e., all-chain pairings
         will come first.
      2) e-values
  """
  all_paired_msa_rows = []

  for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
    paired_rows = all_paired_msa_rows_dict[num_pairings]
    paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
    paired_rows_sort_index = np.argsort(paired_rows_product)
    all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

  return np.array(all_paired_msa_rows)

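# --- Editor's sketch (not part of the upstream file): made-up pairings for a
# 2-chain complex, keyed by how many chains each pairing spans.
def _example_reorder_paired_rows():
  rows_by_count = {
      2: np.array([[3, 5], [1, 2]]),
      1: np.array([[4, -1]]),
  }
  ordered = reorder_paired_rows(rows_by_count)
  # All-chain pairings come first; within each group, rows are sorted by the
  # absolute product of their indices.
  assert ordered.tolist() == [[1, 2], [3, 5], [4, -1]]
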
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
  """Like scipy.linalg.block_diag but with an optional padding value."""
  ones_arrs = [np.ones_like(x) for x in arrs]
  off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
  diag = scipy.linalg.block_diag(*arrs)
  diag += (off_diag_mask * pad_value).astype(diag.dtype)
  return diag

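# --- Editor's sketch (not part of the upstream file): a tiny concrete case
# showing pad_value filling the off-diagonal blocks.
def _example_block_diag():
  out = block_diag(np.array([[1, 2]]), np.array([[3]]), pad_value=-1)
  assert out.tolist() == [[1, 2, -1],
                          [-1, -1, 3]]
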
def _correct_post_merged_feats(
    np_example: pipeline.FeatureDict,
    np_chains_list: Sequence[pipeline.FeatureDict],
    pair_msa_sequences: bool) -> pipeline.FeatureDict:
  """Adds features that need to be computed/recomputed post merging."""

  np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
                                        dtype=np.int32)
  np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
                                            dtype=np.int32)

  if not pair_msa_sequences:
    # Generate a bias that is 1 for the first row of every block in the
    # block diagonal MSA - i.e. make sure the cluster stack always includes
    # the query sequences for each chain (since the first row is the query
    # sequence).
    cluster_bias_masks = []
    for chain in np_chains_list:
      mask = np.zeros(chain['msa'].shape[0])
      mask[0] = 1
      cluster_bias_masks.append(mask)
    np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

    # Initialize Bert mask with masked out off diagonals.
    msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
                 for x in np_chains_list]

    np_example['bert_mask'] = block_diag(
        *msa_masks, pad_value=0)
  else:
    np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
    np_example['cluster_bias_mask'][0] = 1

    # Initialize Bert mask with masked out off diagonals.
    msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
                 x in np_chains_list]
    msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
                         x in np_chains_list]

    msa_mask_block_diag = block_diag(
        *msa_masks, pad_value=0)
    msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
    np_example['bert_mask'] = np.concatenate(
        [msa_mask_all_seq, msa_mask_block_diag], axis=0)
  return np_example

def _pad_templates(chains: Sequence[pipeline.FeatureDict],
                   max_templates: int) -> Sequence[pipeline.FeatureDict]:
  """For each chain pad the number of templates to a fixed size.

  Args:
    chains: A list of protein chains.
    max_templates: Each chain will be padded to have this many templates.

  Returns:
    The list of chains, updated to have template features padded to
    max_templates.
  """
  for chain in chains:
    for k, v in chain.items():
      if k in TEMPLATE_FEATURES:
        padding = np.zeros_like(v.shape)
        padding[0] = max_templates - v.shape[0]
        padding = [(0, p) for p in padding]
        chain[k] = np.pad(v, padding, mode='constant')
  return chains

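# --- Editor's sketch (not part of the upstream file): only the leading
# (num_templates) axis is padded; the shapes here are illustrative.
def _example_pad_templates():
  chain = {'template_aatype': np.zeros([1, 10, 22], dtype=np.float32)}
  (padded,) = _pad_templates([chain], max_templates=4)
  assert padded['template_aatype'].shape == (4, 10, 22)
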
def _merge_features_from_multiple_chains(
    chains: Sequence[pipeline.FeatureDict],
    pair_msa_sequences: bool) -> pipeline.FeatureDict:
  """Merge features from multiple chains.

  Args:
    chains: A list of feature dictionaries that we want to merge.
    pair_msa_sequences: Whether to concatenate MSA features along the
      num_res dimension (if True), or to block diagonalize them (if False).

  Returns:
    A feature dictionary for the merged example.
  """
  merged_example = {}
  for feature_name in chains[0]:
    feats = [x[feature_name] for x in chains]
    feature_name_split = feature_name.split('_all_seq')[0]
    if feature_name_split in MSA_FEATURES:
      if pair_msa_sequences or '_all_seq' in feature_name:
        merged_example[feature_name] = np.concatenate(feats, axis=1)
      else:
        merged_example[feature_name] = block_diag(
            *feats, pad_value=MSA_PAD_VALUES[feature_name])
    elif feature_name_split in SEQ_FEATURES:
      merged_example[feature_name] = np.concatenate(feats, axis=0)
    elif feature_name_split in TEMPLATE_FEATURES:
      merged_example[feature_name] = np.concatenate(feats, axis=1)
    elif feature_name_split in CHAIN_FEATURES:
      merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32)
    else:
      merged_example[feature_name] = feats[0]
  return merged_example

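# --- Editor's sketch (not part of the upstream file): minimal chains with one
# sequence feature and one MSA feature, merged without pairing.
def _example_merge_features():
  chain_1 = {'aatype': np.zeros([5]), 'msa': np.zeros([3, 5], dtype=np.int32)}
  chain_2 = {'aatype': np.zeros([4]), 'msa': np.zeros([2, 4], dtype=np.int32)}
  merged = _merge_features_from_multiple_chains(
      [chain_1, chain_2], pair_msa_sequences=False)
  # Sequence features concatenate along residues; unpaired MSA features are
  # block-diagonalised, with MSA_PAD_VALUES filling the off-diagonal blocks.
  assert merged['aatype'].shape == (9,)
  assert merged['msa'].shape == (5, 9)
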
def _merge_homomers_dense_msa(
    chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
  """Merge all identical chains, making the resulting MSA dense.

  Args:
    chains: An iterable of features for each chain.

  Returns:
    A list of feature dictionaries. All features with the same entity_id
    will be merged - MSA features will be concatenated along the num_res
    dimension - making them dense.
  """
  entity_chains = collections.defaultdict(list)
  for chain in chains:
    entity_id = chain['entity_id'][0]
    entity_chains[entity_id].append(chain)

  grouped_chains = []
  for entity_id in sorted(entity_chains):
    chains = entity_chains[entity_id]
    grouped_chains.append(chains)
  chains = [
      _merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
      for chains in grouped_chains]
  return chains


def _concatenate_paired_and_unpaired_features(
    example: pipeline.FeatureDict) -> pipeline.FeatureDict:
  """Merges paired and block-diagonalised features."""
  features = MSA_FEATURES
  for feature_name in features:
    if feature_name in example:
      feat = example[feature_name]
      feat_all_seq = example[feature_name + '_all_seq']
      merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
      example[feature_name] = merged_feat
  example['num_alignments'] = np.array(example['msa'].shape[0],
                                       dtype=np.int32)
  return example

def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
                         pair_msa_sequences: bool,
                         max_templates: int) -> pipeline.FeatureDict:
  """Merges features for multiple chains to single FeatureDict.

  Args:
    np_chains_list: List of FeatureDicts for each chain.
    pair_msa_sequences: Whether to merge paired MSAs.
    max_templates: The maximum number of templates to include.

  Returns:
    Single FeatureDict for entire complex.
  """
  np_chains_list = _pad_templates(
      np_chains_list, max_templates=max_templates)
  np_chains_list = _merge_homomers_dense_msa(np_chains_list)
  # Unpaired MSA features will always be block-diagonalised; paired MSA
  # features will be concatenated.
  np_example = _merge_features_from_multiple_chains(
      np_chains_list, pair_msa_sequences=False)
  if pair_msa_sequences:
    np_example = _concatenate_paired_and_unpaired_features(np_example)
  np_example = _correct_post_merged_feats(
      np_example=np_example,
      np_chains_list=np_chains_list,
      pair_msa_sequences=pair_msa_sequences)

  return np_example

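# --- Editor's sketch (not part of the upstream file): one plausible way the
# functions in this module fit together for a multimer example. The call
# order mirrors the docstrings above, but the function name, arguments, and
# max_templates value are hypothetical.
def _example_build_multimer_example(np_chains, is_prokaryote):
  np_chains = create_paired_features(np_chains, prokaryotic=is_prokaryote)
  np_chains = deduplicate_unpaired_sequences(np_chains)
  return merge_chain_features(
      np_chains, pair_msa_sequences=True, max_templates=4)
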
def deduplicate_unpaired_sequences(
    np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
  """Removes unpaired sequences which duplicate a paired sequence."""

  feature_names = np_chains[0].keys()
  msa_features = MSA_FEATURES

  for chain in np_chains:
    sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
    keep_rows = []
    # Go through unpaired MSA seqs and remove any rows that correspond to the
    # sequences that are already present in the paired MSA.
    for row_num, seq in enumerate(chain['msa']):
      if tuple(seq) not in sequence_set:
        keep_rows.append(row_num)
    for feature_name in feature_names:
      if feature_name in msa_features:
        if keep_rows:
          chain[feature_name] = chain[feature_name][keep_rows]
        else:
          new_shape = list(chain[feature_name].shape)
          new_shape[0] = 0
          chain[feature_name] = np.zeros(new_shape,
                                         dtype=chain[feature_name].dtype)
    chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
  return np_chains