Mercurial > repos > galaxy-australia > alphafold2
view docker/alphafold/alphafold/data/feature_processing.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
line wrap: on
line source
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Feature processing logic for multimer data pipeline."""

from typing import Iterable, MutableMapping, List

from alphafold.common import residue_constants
from alphafold.data import msa_pairing
from alphafold.data import pipeline
import numpy as np

# Only these keys survive `_filter_features` at the end of `process_final`.
REQUIRED_FEATURES = frozenset({
    'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
    'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
    'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
    'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
    'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask',
    'num_alignments', 'num_templates', 'queue_size', 'residue_index',
    'resolution', 'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
    'template_all_atom_mask', 'template_all_atom_positions'
})

# Defaults that `pair_and_merge` passes to cropping and merging.
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048


def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example.

    Args:
      chains: Per-chain feature dictionaries; each must have an 'entity_id'
        array.

    Returns:
      True if exactly one distinct positive entity id occurs across all
      chains, False otherwise.
    """
    # Note that an entity_id of 0 indicates padding.
    num_unique_chains = len(np.unique(np.concatenate(
        [np.unique(chain['entity_id'][chain['entity_id'] > 0])
         for chain in chains])))
    return num_unique_chains == 1


def pair_and_merge(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict],
    is_prokaryote: bool) -> pipeline.FeatureDict:
    """Runs processing on features to augment, pair and merge.

    The per-chain dictionaries in `all_chain_features` are modified in place
    by `process_unmerged_features` before pairing and merging.

    Args:
      all_chain_features: A MutableMap of dictionaries of features for each
        chain.
      is_prokaryote: Whether the target complex is from a prokaryotic or
        eukaryotic organism.

    Returns:
      A dictionary of features.
    """
    process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())

    # MSA pairing is only performed when more than one entity is present.
    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

    if pair_msa_sequences:
        np_chains_list = msa_pairing.create_paired_features(
            chains=np_chains_list, prokaryotic=is_prokaryote)
        np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
            np_chains_list)
    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES)
    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES)
    np_example = process_final(np_example)
    return np_example


def crop_chains(
    chains_list: List[pipeline.FeatureDict],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int) -> List[pipeline.FeatureDict]:
    """Crops the MSAs for a set of chains.

    Each chain dictionary is cropped in place (see `_crop_single_chain`) and
    the same dictionary objects are returned in a new list.

    Args:
      chains_list: A list of chains to be cropped.
      msa_crop_size: The total number of sequences to crop from the MSA.
      pair_msa_sequences: Whether we are operating in sequence-pairing mode.
      max_templates: The maximum templates to use per chain.

    Returns:
      The chains cropped.
    """
    # Apply the cropping.
    return [
        _crop_single_chain(
            chain,
            msa_crop_size=msa_crop_size,
            pair_msa_sequences=pair_msa_sequences,
            max_templates=max_templates)
        for chain in chains_list
    ]


def _crop_single_chain(chain: pipeline.FeatureDict,
                       msa_crop_size: int,
                       pair_msa_sequences: bool,
                       max_templates: int) -> pipeline.FeatureDict:
    """Crops msa sequences to `msa_crop_size`.

    Template features are cropped to at most `max_templates` entries along
    their leading axis. `chain` is modified in place and also returned.

    Args:
      chain: Feature dictionary of a single chain. Must contain
        'num_alignments'; in pairing mode it must also contain
        'num_alignments_all_seq' and 'msa_all_seq'.
      msa_crop_size: Upper bound on the number of MSA rows kept.
      pair_msa_sequences: Whether the '*_all_seq' (paired) features are in
        use.
      max_templates: The maximum templates to use per chain. A value of 0
        crops the template features to zero entries.

    Returns:
      The same `chain` dictionary with cropped MSA/template features and
      updated 'num_alignments', 'num_templates' (if the chain has templates)
      and 'num_alignments_all_seq' (in pairing mode).
    """
    msa_size = chain['num_alignments']

    if pair_msa_sequences:
        msa_size_all_seq = chain['num_alignments_all_seq']
        # Paired sequences get at most half of the crop budget.
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq,
                                           msa_crop_size // 2)

        # We reduce the number of un-paired sequences, by the number of times
        # a sequence from this chain's MSA is included in the paired MSA.
        # This keeps the MSA size for each chain roughly constant.
        msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(
            np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
        num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
                                          msa_crop_size_all_seq)

        # Restrict the unpaired crop size so that paired+unpaired sequences
        # do not exceed the requested crop size for each chain.
        max_msa_crop_size = np.maximum(
            msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    # The template crop size is computed whenever the chain carries
    # templates, including when max_templates == 0. Previously it was left
    # unbound in that case and the cropping loop below raised
    # UnboundLocalError.
    has_templates = 'template_aatype' in chain
    if has_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for k in chain:
        # Paired features share the base name plus an '_all_seq' suffix.
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            if has_templates:
                chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
    if has_templates:
        chain['num_templates'] = np.asarray(
            templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            msa_crop_size_all_seq, dtype=np.int32)
    return chain


def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
    """Final processing steps in data pipeline, after merging and pairing.

    Args:
      np_example: The merged feature dictionary; must contain 'msa' and
        'entity_id'.

    Returns:
      A new dictionary restricted to the keys in `REQUIRED_FEATURES`, with
      remapped 'msa' and added 'seq_mask' / 'msa_mask'.
    """
    np_example = _correct_msa_restypes(np_example)
    np_example = _make_seq_mask(np_example)
    np_example = _make_msa_mask(np_example)
    np_example = _filter_features(np_example)
    return np_example


def _correct_msa_restypes(np_example):
    """Correct MSA restype to have the same order as residue_constants.

    Replaces np_example['msa'] in place with an int32 array whose entries
    are looked up in `MAP_HHBLITS_AATYPE_TO_OUR_AATYPE`.
    """
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
    np_example['msa'] = np_example['msa'].astype(np.int32)
    return np_example


def _make_seq_mask(np_example):
    """Adds a float32 'seq_mask' that is 1 where entity_id > 0, else 0."""
    np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
    return np_example


def _make_msa_mask(np_example):
    """Mask features are all ones, but will later be zero-padded.

    Positions whose entity_id is not positive are zeroed in the mask.
    """
    np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)

    seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
    # A leading axis is added so the per-residue mask broadcasts over rows.
    np_example['msa_mask'] *= seq_mask[None]

    return np_example


def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
    """Filters features of example to only those requested.

    Returns a new dictionary; the input dictionary is not modified.
    """
    return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}


def process_unmerged_features(
    all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
    """Postprocessing stage for per-chain features before merging.

    Every chain dictionary is modified in place; nothing is returned.

    Args:
      all_chain_features: Mapping to per-chain feature dictionaries. Each
        must contain 'deletion_matrix_int', 'aatype' and 'entity_id'.
        'deletion_matrix_int' is removed from the dictionary, so a missing
        key (for example on a second call) raises KeyError.
    """
    num_chains = len(all_chain_features)
    for chain_features in all_chain_features.values():
        # Convert deletion matrices to float.
        chain_features['deletion_matrix'] = np.asarray(
            chain_features.pop('deletion_matrix_int'), dtype=np.float32)
        if 'deletion_matrix_int_all_seq' in chain_features:
            chain_features['deletion_matrix_all_seq'] = np.asarray(
                chain_features.pop('deletion_matrix_int_all_seq'),
                dtype=np.float32)

        chain_features['deletion_mean'] = np.mean(
            chain_features['deletion_matrix'], axis=0)

        # Add all_atom_mask and dummy all_atom_positions based on aatype.
        all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
            chain_features['aatype']]
        chain_features['all_atom_mask'] = all_atom_mask
        chain_features['all_atom_positions'] = np.zeros(
            list(all_atom_mask.shape) + [3])

        # Add assembly_num_chains.
        chain_features['assembly_num_chains'] = np.asarray(num_chains)

    # Add entity_mask.
    for chain_features in all_chain_features.values():
        chain_features['entity_mask'] = (
            chain_features['entity_id'] != 0).astype(np.int32)