Mercurial > repos > galaxy-australia > alphafold2
diff docker/alphafold/alphafold/data/pipeline_multimer.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/data/pipeline_multimer.py Tue Mar 01 02:53:05 2022 +0000 @@ -0,0 +1,288 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the features for the AlphaFold multimer model.""" + +import collections +import contextlib +import copy +import dataclasses +import json +import os +import tempfile +from typing import Mapping, MutableMapping, Sequence + +from absl import logging +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.data import feature_processing +from alphafold.data import msa_pairing +from alphafold.data import parsers +from alphafold.data import pipeline +from alphafold.data.tools import jackhmmer +import numpy as np + +# Internal import (7716). + + +@dataclasses.dataclass(frozen=True) +class _FastaChain: + sequence: str + description: str + + +def _make_chain_id_map(*, + sequences: Sequence[str], + descriptions: Sequence[str], + ) -> Mapping[str, _FastaChain]: + """Makes a mapping from PDB-format chain ID to sequence and description.""" + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError('Cannot process more chains than the PDB format supports. ' + f'Got {len(sequences)} chains.') + chain_id_map = {} + for chain_id, sequence, description in zip( + protein.PDB_CHAIN_IDS, sequences, descriptions): + chain_id_map[chain_id] = _FastaChain( + sequence=sequence, description=description) + return chain_id_map + + +@contextlib.contextmanager +def temp_fasta_file(fasta_str: str): + with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file: + fasta_file.write(fasta_str) + fasta_file.seek(0) + yield fasta_file.name + + +def convert_monomer_features( + monomer_features: pipeline.FeatureDict, + chain_id: str) -> pipeline.FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + converted = {} + converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + unnecessary_leading_dim_feats = { + 'sequence', 'domain_name', 'num_alignments', 'seq_length'} + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + elif feature_name == 'template_all_atom_masks': + feature_name = 'template_all_atom_mask' + converted[feature_name] = feature + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features( + all_chain_features: MutableMapping[str, pipeline.FeatureDict], + ) -> MutableMapping[str, pipeline.FeatureDict]: + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `<seq_id>_<sym_id>` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. + """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_id, chain_features in all_chain_features.items(): + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = {} + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + for sym_id, chain_features in enumerate(group_chain_features, start=1): + new_all_chain_features[ + f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_id += 1 + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'): + np_example[feat] = np.pad( + np_example[feat], ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),)) + return np_example + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, + monomer_data_pipeline: pipeline.DataPipeline, + jackhmmer_binary_path: str, + uniprot_database_path: str, + max_uniprot_hits: int = 50000, + use_precomputed_msas: bool = False): + """Initializes the data pipeline. + + Args: + monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs + the data pipeline for the monomer AlphaFold system. + jackhmmer_binary_path: Location of the jackhmmer binary. + uniprot_database_path: Location of the unclustered uniprot sequences, that + will be searched with jackhmmer and used for MSA pairing. + max_uniprot_hits: The maximum number of hits to return from uniprot. + use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold. + """ + self._monomer_data_pipeline = monomer_data_pipeline + self._uniprot_msa_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self._max_uniprot_hits = max_uniprot_hits + self.use_precomputed_msas = use_precomputed_msas + + def _process_single_chain( + self, + chain_id: str, + sequence: str, + description: str, + msa_output_dir: str, + is_homomer_or_monomer: bool) -> pipeline.FeatureDict: + """Runs the monomer pipeline on a single chain.""" + chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n' + chain_msa_output_dir = os.path.join(msa_output_dir, chain_id) + if not os.path.exists(chain_msa_output_dir): + os.makedirs(chain_msa_output_dir) + with temp_fasta_file(chain_fasta_str) as chain_fasta_path: + logging.info('Running monomer pipeline on chain %s: %s', + chain_id, description) + chain_features = self._monomer_data_pipeline.process( + input_fasta_path=chain_fasta_path, + msa_output_dir=chain_msa_output_dir) + + # We only construct the pairing features if there are 2 or more unique + # sequences. + if not is_homomer_or_monomer: + all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path, + chain_msa_output_dir) + chain_features.update(all_seq_msa_features) + return chain_features + + def _all_seq_msa_features(self, input_fasta_path, msa_output_dir): + """Get MSA features for unclustered uniprot, for pairing.""" + out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + result = pipeline.run_msa_tool( + self._uniprot_msa_runner, input_fasta_path, out_path, 'sto', + self.use_precomputed_msas) + msa = parsers.parse_stockholm(result['sto']) + msa = msa.truncate(max_seqs=self._max_uniprot_hits) + all_seq_features = pipeline.make_msa_features([msa]) + valid_feats = msa_pairing.MSA_FEATURES + ( + 'msa_uniprot_accession_identifiers', + 'msa_species_identifiers', + ) + feats = {f'{k}_all_seq': v for k, v in all_seq_features.items() + if k in valid_feats} + return feats + + def process(self, + input_fasta_path: str, + msa_output_dir: str, + is_prokaryote: bool = False) -> pipeline.FeatureDict: + """Runs alignment tools on the input sequences and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + + chain_id_map = _make_chain_id_map(sequences=input_seqs, + descriptions=input_descs) + chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain) + for chain_id, fasta_chain in chain_id_map.items()} + json.dump(chain_id_map_dict, f, indent=4, sort_keys=True) + + all_chain_features = {} + sequence_features = {} + is_homomer_or_monomer = len(set(input_seqs)) == 1 + for chain_id, fasta_chain in chain_id_map.items(): + if fasta_chain.sequence in sequence_features: + all_chain_features[chain_id] = copy.deepcopy( + sequence_features[fasta_chain.sequence]) + continue + chain_features = self._process_single_chain( + chain_id=chain_id, + sequence=fasta_chain.sequence, + description=fasta_chain.description, + msa_output_dir=msa_output_dir, + is_homomer_or_monomer=is_homomer_or_monomer) + + chain_features = convert_monomer_features(chain_features, + chain_id=chain_id) + all_chain_features[chain_id] = chain_features + sequence_features[fasta_chain.sequence] = chain_features + + all_chain_features = add_assembly_features(all_chain_features) + + np_example = feature_processing.pair_and_merge( + all_chain_features=all_chain_features, + is_prokaryote=is_prokaryote, + ) + + # Pad MSA to avoid zero-sized extra_msa. + np_example = pad_msa(np_example, 512) + + return np_example