docker/alphafold/alphafold/data/msa_pairing.py @ 1:6c92e000d684 (draft)

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pairing logic for multimer data pipeline."""

import collections
import functools
import re
import string
from typing import Any, Dict, Iterable, List, Sequence

from alphafold.common import residue_constants
from alphafold.data import pipeline
import numpy as np
import pandas as pd
import scipy.linalg

ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
ALPHANUM_ACCESSION_ID_MAP = {
    chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
}  # A-Z,0-9
NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)}  # 0-9

MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
                  'msa_mask_all_seq': 1,
                  'deletion_matrix_all_seq': 0,
                  'deletion_matrix_int_all_seq': 0,
                  'msa': MSA_GAP_IDX,
                  'msa_mask': 1,
                  'deletion_matrix': 0,
                  'deletion_matrix_int': 0}

MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
                'all_atom_mask', 'seq_mask', 'between_segment_residues',
                'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
                'sym_id', 'entity_mask', 'deletion_mean',
                'prediction_atom_mask',
                'literature_positions', 'atom_indices_to_group_indices',
                'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
                     'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')


# Matches domain names like '1abc{1+2}A{0}': a PDB id, a bioassembly spec,
# a chain id and a transform index. Note: upstream wrote the bioassembly
# group as a character class ('[\d+(\+\d+)?]'), which matches only a single
# character; the intended pattern is digits optionally joined by '+'.
domain_name_pattern = re.compile(
    r'''^(?P<pdb>[a-z\d]{4})
    \{(?P<bioassembly>\d+(?:\+\d+)?)\}
    (?P<chain>[a-zA-Z\d]+)
    \{(?P<transform_index>\d+)\}$
    ''', re.VERBOSE)


def create_paired_features(
    chains: Iterable[pipeline.FeatureDict],
    prokaryotic: bool,
    ) -> List[pipeline.FeatureDict]:
  """Returns the original chains with paired NUM_SEQ features.

  Args:
    chains: A list of feature dictionaries for each chain.
    prokaryotic: Whether the target complex is from a prokaryotic organism.
      Used to determine the distance metric for pairing.

  Returns:
    A list of feature dictionaries with sequence features including only
    rows to be paired.
  """
  chains = list(chains)
  chain_keys = chains[0].keys()

  if len(chains) < 2:
    return chains
  else:
    updated_chains = []
    paired_chains_to_paired_row_indices = pair_sequences(
        chains, prokaryotic)
    paired_rows = reorder_paired_rows(
        paired_chains_to_paired_row_indices)

    for chain_num, chain in enumerate(chains):
      new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
      for feature_name in chain_keys:
        if feature_name.endswith('_all_seq'):
          feats_padded = pad_features(chain[feature_name], feature_name)
          new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
      new_chain['num_alignments_all_seq'] = np.asarray(
          len(paired_rows[:, chain_num]))
      updated_chains.append(new_chain)
    return updated_chains
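
# Illustrative sketch (not part of the upstream module): `paired_rows` has
# shape [num_pairings, num_chains]; column `chain_num` re-indexes that
# chain's padded `*_all_seq` features. Index -1 selects the padding row
# appended by `pad_features` below. E.g. for two chains:
#
#   paired_rows = np.array([[0, 0],    # the two query rows always pair
#                           [5, 2],    # rows 5 and 2 share a species
#                           [9, -1]])  # chain 1 has no partner: padding row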


def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
  """Add a 'padding' row at the end of the features list.

  The padding row will be selected as a 'paired' row in the case of partial
  alignment - for the chain that doesn't have paired alignment.

  Args:
    feature: The feature to be padded.
    feature_name: The name of the feature to be padded.

  Returns:
    The feature with an additional padding row.
  """
  assert feature.dtype != np.dtype(np.string_)
  if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
                      'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
    num_res = feature.shape[1]
    padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                     feature.dtype)
  elif feature_name in ('msa_uniprot_accession_identifiers_all_seq',
                        'msa_species_identifiers_all_seq'):
    padding = [b'']
  else:
    return feature
  feats_padded = np.concatenate([feature, padding], axis=0)
  return feats_padded
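
# Illustrative sketch (not part of the upstream module): an [N, num_res]
# 'msa_all_seq' feature becomes [N + 1, num_res] with an all-gap final row,
# so that index -1 (used for chains without a partner) selects gaps:
#
#   padded = pad_features(np.zeros([4, 7], np.int32), 'msa_all_seq')
#   assert padded.shape == (5, 7) and np.all(padded[-1] == MSA_GAP_IDX)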


def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame:
  """Makes dataframe with msa features needed for msa pairing."""
  chain_msa = chain_features['msa_all_seq']
  query_seq = chain_msa[0]
  per_seq_similarity = np.sum(
      query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
  # 21 is the gap index in residue_constants.restypes_with_x_and_gap,
  # i.e. MSA_GAP_IDX.
  per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
  msa_df = pd.DataFrame({
      'msa_species_identifiers':
          chain_features['msa_species_identifiers_all_seq'],
      'msa_uniprot_accession_identifiers':
          chain_features['msa_uniprot_accession_identifiers_all_seq'],
      'msa_row':
          np.arange(len(
              chain_features['msa_uniprot_accession_identifiers_all_seq'])),
      'msa_similarity': per_seq_similarity,
      'gap': per_seq_gap
  })
  return msa_df


def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
  """Creates mapping from species to msa dataframe of that species."""
  species_lookup = {}
  for species, species_df in msa_df.groupby('msa_species_identifiers'):
    species_lookup[species] = species_df
  return species_lookup


@functools.lru_cache(maxsize=65536)
def encode_accession(accession_id: str) -> int:
  """Map accession codes to the serial order in which they were assigned."""
  alpha = ALPHA_ACCESSION_ID_MAP  # A-Z
  alphanum = ALPHANUM_ACCESSION_ID_MAP  # A-Z,0-9
  num = NUM_ACCESSION_ID_MAP  # 0-9

  coding = 0

  # This is based on the uniprot accession id format
  # https://www.uniprot.org/help/accession_numbers
  if accession_id[0] in {'O', 'P', 'Q'}:
    bases = (alpha, num, alphanum, alphanum, alphanum, num)
  elif len(accession_id) == 6:
    bases = (alpha, num, alpha, alphanum, alphanum, num)
  elif len(accession_id) == 10:
    bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum,
             alphanum, num)
  else:
    # Guard against ids that fit none of the documented formats; previously
    # these fell through and raised an opaque UnboundLocalError on `bases`.
    raise ValueError(f'Unrecognized accession id format: {accession_id}')

  product = 1
  for place, base in zip(reversed(accession_id), reversed(bases)):
    coding += base[place] * product
    product *= len(base)

  return coding
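
# Illustrative sketch (not part of the upstream module): each position of an
# accession id is treated as a digit of a mixed-radix number, so ids that
# were assigned serially map to nearby integers:
#
#   encode_accession('P12346') - encode_accession('P12345')  # == 1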


def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
  return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode()))


def _find_all_accession_matches(accession_id_lists: List[List[bytes]],
                                diff_cutoff: int = 20
                                ) -> List[List[Any]]:
  """Finds accession id matches across the chains based on their difference."""
  all_accession_tuples = []
  current_tuple = []
  tokens_used_in_answer = set()

  def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
    return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))

  def _all_tokens_not_used_before() -> bool:
    return all((s not in tokens_used_in_answer for s in current_tuple))

  def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
    if level == len(accession_id_lists) - 1:
      if _all_tokens_not_used_before():
        all_accession_tuples.append(list(current_tuple))
        for s in current_tuple:
          tokens_used_in_answer.add(s)
      return

    if level == -1:
      new_list = accession_id_lists[level+1]
    else:
      new_list = [(_calc_id_diff(accession_id, s), s) for
                  s in accession_id_lists[level+1]]
      new_list = sorted(new_list)
      new_list = [s for d, s in new_list]

    for s in new_list:
      if (_matches_all_in_current_tuple(s, diff_cutoff) and
          s not in tokens_used_in_answer):
        current_tuple.append(s)
        dfs(level + 1, s)
        current_tuple.pop()
  dfs(-1, '')
  return all_accession_tuples
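
# Illustrative sketch (not part of the upstream module): the DFS builds one
# accession per chain such that all pairwise id differences stay below the
# cutoff, consuming each accession at most once:
#
#   _find_all_accession_matches([[b'P12345'], [b'P12346', b'Q99999']])
#   # -> [[b'P12345', b'P12346']]  (diff 1 < 20; b'Q99999' is too distant)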


def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
  matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
  return matched_df.iloc[0]


def _match_rows_by_genetic_distance(
    this_species_msa_dfs: List[pd.DataFrame],
    cutoff: int = 20) -> List[List[int]]:
  """Finds MSA sequence pairings across chains within a genetic distance cutoff.

  The genetic distance between two sequences is approximated by taking the
  difference in their UniProt accession ids.

  Args:
    this_species_msa_dfs: a list of dataframes containing MSA features for
      sequences for a specific species. If species is missing for a chain, the
      dataframe is set to None.
    cutoff: the genetic distance cutoff.

  Returns:
    A list of lists, each containing M indices corresponding to paired MSA rows,
    where M is the number of chains.
  """
  num_examples = len(this_species_msa_dfs)  # N

  accession_id_lists = []  # M
  match_index_to_chain_index = {}
  for chain_index, species_df in enumerate(this_species_msa_dfs):
    if species_df is not None:
      accession_id_lists.append(
          list(species_df.msa_uniprot_accession_identifiers.values))
      # Keep track of which of the this_species_msa_dfs are not None.
      match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index

  all_accession_id_matches = _find_all_accession_matches(
      accession_id_lists, cutoff)  # [k, M]

  all_paired_msa_rows = []  # [k, N]
  for accession_id_match in all_accession_id_matches:
    paired_msa_rows = []
    for match_index, accession_id in enumerate(accession_id_match):
      # Map back to chain index.
      chain_index = match_index_to_chain_index[match_index]
      seq_series = _accession_row(
          this_species_msa_dfs[chain_index], accession_id)

      if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
          seq_series.gap > SEQUENCE_GAP_CUTOFF):
        continue
      else:
        paired_msa_rows.append(seq_series.msa_row)
    # If a sequence is skipped based on sequence similarity to the respective
    # target sequence or a gap cutoff, the lengths of accession_id_match and
    # paired_msa_rows will be different. Skip this match.
    if len(paired_msa_rows) == len(accession_id_match):
      paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
      matched_chain_indices = list(match_index_to_chain_index.values())
      paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
      all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
  return all_paired_msa_rows
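
# Illustrative sketch (not part of the upstream module): for chains
# [df_A, None, df_B], a surviving match (b'P12345', b'P12346') is returned
# as [row_A, -1, row_B]; the -1 marks the chain with no sequence for this
# species and later selects its padding row.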


def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
                                       ) -> List[List[int]]:
  """Finds MSA sequence pairings across chains based on sequence similarity.

  Each chain's MSA sequences are first sorted by their sequence similarity to
  their respective target sequence. The sequences are then paired, starting
  from the sequences most similar to their target sequence.

  Args:
    this_species_msa_dfs: a list of dataframes containing MSA features for
      sequences for a specific species.

  Returns:
    A list of lists, each containing M indices corresponding to paired MSA rows,
    where M is the number of chains.
  """
  all_paired_msa_rows = []

  num_seqs = [len(species_df) for species_df in this_species_msa_dfs
              if species_df is not None]
  take_num_seqs = np.min(num_seqs)

  sort_by_similarity = (
      lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))

  for species_df in this_species_msa_dfs:
    if species_df is not None:
      species_df_sorted = sort_by_similarity(species_df)
      msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
    else:
      msa_rows = [-1] * take_num_seqs  # take the last 'padding' row
    all_paired_msa_rows.append(msa_rows)
  all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
  return all_paired_msa_rows
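
# Illustrative sketch (not part of the upstream module): rank-based pairing.
# If chain 0 has species rows [4, 7, 9] (sorted best-similarity first) and
# chain 1 has rows [3, 8], then take_num_seqs == 2 and the pairings are:
#
#   [[4, 3],   # best vs best
#    [7, 8]]   # second vs second; row 9 is left unpaired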


def pair_sequences(examples: List[pipeline.FeatureDict],
                   prokaryotic: bool) -> Dict[int, np.ndarray]:
  """Returns indices for paired MSA sequences across chains."""

  num_examples = len(examples)

  all_chain_species_dict = []
  common_species = set()
  for chain_features in examples:
    msa_df = _make_msa_df(chain_features)
    species_dict = _create_species_dict(msa_df)
    all_chain_species_dict.append(species_dict)
    common_species.update(set(species_dict))

  common_species = sorted(common_species)
  common_species.remove(b'')  # Remove target sequence species.

  all_paired_msa_rows = [np.zeros(len(examples), int)]
  all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
  all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

  for species in common_species:
    if not species:
      continue
    this_species_msa_dfs = []
    species_dfs_present = 0
    for species_dict in all_chain_species_dict:
      if species in species_dict:
        this_species_msa_dfs.append(species_dict[species])
        species_dfs_present += 1
      else:
        this_species_msa_dfs.append(None)

    # Skip species that are present in only one chain.
    if species_dfs_present <= 1:
      continue

    if np.any(
        np.array([len(species_df) for species_df in
                  this_species_msa_dfs if
                  isinstance(species_df, pd.DataFrame)]) > 600):
      continue

    # In prokaryotes (and some eukaryotes), interacting genes are often
    # co-located on the chromosome into operons. Because of that we can assume
    # that if two proteins' intergenic distance is less than a threshold, the
    # two proteins will form an interacting pair.
    # In most eukaryotes, a single protein's MSA can contain many paralogs, and
    # two genes may interact even if they are not close by genomic distance.
    # For eukaryotes, some methods therefore pair MSA sequences by sequence
    # similarity instead.
    # See Jinbo Xu's work:
    # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
    if prokaryotic:
      paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs)

      if not paired_msa_rows:
        continue
    else:
      paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
    all_paired_msa_rows.extend(paired_msa_rows)
    all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
  all_paired_msa_rows_dict = {
      num_examples: np.array(paired_msa_rows) for
      num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
  }
  return all_paired_msa_rows_dict
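
# Illustrative sketch (not part of the upstream module): the returned dict is
# keyed by the number of chains paired. For three chains it might look like:
#
#   {0: array([]), 1: array([]),
#    2: array([[3, 7, -1], ...]),             # species seen in 2 of 3 chains
#    3: array([[0, 0, 0], [5, 2, 9], ...])}   # all-chain pairings, incl. query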


def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
                        ) -> np.ndarray:
  """Creates a list of indices of paired MSA rows across chains.

  Args:
    all_paired_msa_rows_dict: a mapping from the number of paired chains to the
      paired indices.

  Returns:
    a list of lists, each containing indices of paired MSA rows across chains.
    The paired-index lists are ordered by:
      1) the number of chains in the paired alignment, i.e., all-chain pairings
         will come first.
      2) e-values
  """
  all_paired_msa_rows = []

  for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
    paired_rows = all_paired_msa_rows_dict[num_pairings]
    paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
    paired_rows_sort_index = np.argsort(paired_rows_product)
    all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

  return np.array(all_paired_msa_rows)
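
# Illustrative sketch (not part of the upstream module): within each pairing
# count, rows are ordered by |product of row indices|, a cheap proxy for hit
# quality since lower row indices correspond to better alignment hits:
#
#   reorder_paired_rows({2: np.array([[6, 2], [1, 3]])})
#   # -> array([[1, 3], [6, 2]])  because |1 * 3| < |6 * 2|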


def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
  """Like scipy.linalg.block_diag but with an optional padding value."""
  ones_arrs = [np.ones_like(x) for x in arrs]
  off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
  diag = scipy.linalg.block_diag(*arrs)
  diag += (off_diag_mask * pad_value).astype(diag.dtype)
  return diag
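
# Illustrative sketch (not part of the upstream module):
#
#   block_diag(np.ones((2, 2)), 2 * np.ones((1, 1)), pad_value=-1)
#   # array([[ 1.,  1., -1.],
#   #        [ 1.,  1., -1.],
#   #        [-1., -1.,  2.]])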


def _correct_post_merged_feats(
    np_example: pipeline.FeatureDict,
    np_chains_list: Sequence[pipeline.FeatureDict],
    pair_msa_sequences: bool) -> pipeline.FeatureDict:
  """Adds features that need to be computed/recomputed post merging."""

  np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
                                        dtype=np.int32)
  np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
                                            dtype=np.int32)

  if not pair_msa_sequences:
    # Generate a bias that is 1 for the first row of every block in the
    # block diagonal MSA - i.e. make sure the cluster stack always includes
    # the query sequences for each chain (since the first row is the query
    # sequence).
    cluster_bias_masks = []
    for chain in np_chains_list:
      mask = np.zeros(chain['msa'].shape[0])
      mask[0] = 1
      cluster_bias_masks.append(mask)
    np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

    # Initialize Bert mask with masked out off diagonals.
    msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
                 for x in np_chains_list]

    np_example['bert_mask'] = block_diag(
        *msa_masks, pad_value=0)
  else:
    np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
    np_example['cluster_bias_mask'][0] = 1

    # Initialize Bert mask with masked out off diagonals.
    msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
                 x in np_chains_list]
    msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
                         x in np_chains_list]

    msa_mask_block_diag = block_diag(
        *msa_masks, pad_value=0)
    msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
    np_example['bert_mask'] = np.concatenate(
        [msa_mask_all_seq, msa_mask_block_diag], axis=0)
  return np_example


def _pad_templates(chains: Sequence[pipeline.FeatureDict],
                   max_templates: int) -> Sequence[pipeline.FeatureDict]:
  """For each chain pad the number of templates to a fixed size.

  Args:
    chains: A list of protein chains.
    max_templates: Each chain will be padded to have this many templates.

  Returns:
    The list of chains, updated to have template features padded to
    max_templates.
  """
  for chain in chains:
    for k, v in chain.items():
      if k in TEMPLATE_FEATURES:
        padding = np.zeros_like(v.shape)
        padding[0] = max_templates - v.shape[0]
        padding = [(0, p) for p in padding]
        chain[k] = np.pad(v, padding, mode='constant')
  return chains
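
# Illustrative sketch (not part of the upstream module): only the template
# axis (axis 0) is padded; a [2, num_res] 'template_aatype' gains two all-zero
# template rows when max_templates == 4:
#
#   chain, = _pad_templates([{'template_aatype': np.ones((2, 5), np.int32)}],
#                           max_templates=4)
#   assert chain['template_aatype'].shape == (4, 5)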


def _merge_features_from_multiple_chains(
    chains: Sequence[pipeline.FeatureDict],
    pair_msa_sequences: bool) -> pipeline.FeatureDict:
  """Merge features from multiple chains.

  Args:
    chains: A list of feature dictionaries that we want to merge.
    pair_msa_sequences: Whether to concatenate MSA features along the
      num_res dimension (if True), or to block diagonalize them (if False).

  Returns:
    A feature dictionary for the merged example.
  """
  merged_example = {}
  for feature_name in chains[0]:
    feats = [x[feature_name] for x in chains]
    feature_name_split = feature_name.split('_all_seq')[0]
    if feature_name_split in MSA_FEATURES:
      if pair_msa_sequences or '_all_seq' in feature_name:
        merged_example[feature_name] = np.concatenate(feats, axis=1)
      else:
        merged_example[feature_name] = block_diag(
            *feats, pad_value=MSA_PAD_VALUES[feature_name])
    elif feature_name_split in SEQ_FEATURES:
      merged_example[feature_name] = np.concatenate(feats, axis=0)
    elif feature_name_split in TEMPLATE_FEATURES:
      merged_example[feature_name] = np.concatenate(feats, axis=1)
    elif feature_name_split in CHAIN_FEATURES:
      # Sum over a list rather than a generator; np.sum over a generator is
      # deprecated and silently falls back to the builtin sum.
      merged_example[feature_name] = np.sum([x for x in feats]).astype(np.int32)
    else:
      merged_example[feature_name] = feats[0]
  return merged_example
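
# Illustrative sketch (not part of the upstream module): for two chains with
# unpaired 'msa' blocks of shape [5, 10] and [7, 20], pair_msa_sequences=False
# produces a [12, 30] block-diagonal MSA padded with MSA_GAP_IDX, while the
# already-paired '*_all_seq' rows are concatenated along the residue axis.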


def _merge_homomers_dense_msa(
    chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
  """Merge all identical chains, making the resulting MSA dense.

  Args:
    chains: An iterable of features for each chain.

  Returns:
    A list of feature dictionaries. All features with the same entity_id
    will be merged - MSA features will be concatenated along the num_res
    dimension - making them dense.
  """
  entity_chains = collections.defaultdict(list)
  for chain in chains:
    entity_id = chain['entity_id'][0]
    entity_chains[entity_id].append(chain)

  grouped_chains = []
  for entity_id in sorted(entity_chains):
    chains = entity_chains[entity_id]
    grouped_chains.append(chains)
  chains = [
      _merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
      for chains in grouped_chains]
  return chains


def _concatenate_paired_and_unpaired_features(
    example: pipeline.FeatureDict) -> pipeline.FeatureDict:
  """Merges paired and block-diagonalised features."""
  features = MSA_FEATURES
  for feature_name in features:
    if feature_name in example:
      feat = example[feature_name]
      feat_all_seq = example[feature_name + '_all_seq']
      merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
      example[feature_name] = merged_feat
  example['num_alignments'] = np.array(example['msa'].shape[0],
                                       dtype=np.int32)
  return example


def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
                         pair_msa_sequences: bool,
                         max_templates: int) -> pipeline.FeatureDict:
  """Merges features for multiple chains to single FeatureDict.

  Args:
    np_chains_list: List of FeatureDicts for each chain.
    pair_msa_sequences: Whether to merge paired MSAs.
    max_templates: The maximum number of templates to include.

  Returns:
    Single FeatureDict for entire complex.
  """
  np_chains_list = _pad_templates(
      np_chains_list, max_templates=max_templates)
  np_chains_list = _merge_homomers_dense_msa(np_chains_list)
  # Unpaired MSA features will always be block-diagonalised; paired MSA
  # features will be concatenated.
  np_example = _merge_features_from_multiple_chains(
      np_chains_list, pair_msa_sequences=False)
  if pair_msa_sequences:
    np_example = _concatenate_paired_and_unpaired_features(np_example)
  np_example = _correct_post_merged_feats(
      np_example=np_example,
      np_chains_list=np_chains_list,
      pair_msa_sequences=pair_msa_sequences)

  return np_example
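
# Illustrative sketch (not part of the upstream module): a typical call order,
# mirroring the multimer feature pipeline (inputs hypothetical):
#
#   chains = create_paired_features(chains, prokaryotic=False)
#   chains = deduplicate_unpaired_sequences(chains)
#   example = merge_chain_features(chains, pair_msa_sequences=True,
#                                  max_templates=4)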


def deduplicate_unpaired_sequences(
    np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
  """Removes unpaired sequences which duplicate a paired sequence."""

  feature_names = np_chains[0].keys()
  msa_features = MSA_FEATURES

  for chain in np_chains:
    sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
    keep_rows = []
    # Go through unpaired MSA seqs and remove any rows that correspond to the
    # sequences that are already present in the paired MSA.
    for row_num, seq in enumerate(chain['msa']):
      if tuple(seq) not in sequence_set:
        keep_rows.append(row_num)
    for feature_name in feature_names:
      if feature_name in msa_features:
        if keep_rows:
          chain[feature_name] = chain[feature_name][keep_rows]
        else:
          new_shape = list(chain[feature_name].shape)
          new_shape[0] = 0
          chain[feature_name] = np.zeros(new_shape,
                                         dtype=chain[feature_name].dtype)
    chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
  return np_chains
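
# Illustrative sketch (not part of the upstream module): if
# chain['msa_all_seq'] contains the row (1, 2, 3) and chain['msa'] is
# [(1, 2, 3), (4, 5, 6)], only the second row is kept, so every feature in
# MSA_FEATURES is sliced to that row and 'num_alignments' becomes 1.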