Mercurial > repos > galaxy-australia > alphafold2
diff docker/alphafold/run_alphafold.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/run_alphafold.py Tue Mar 01 02:53:05 2022 +0000 @@ -0,0 +1,427 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Full AlphaFold protein structure prediction script.""" +import json +import os +import pathlib +import pickle +import random +import shutil +import sys +import time +from typing import Dict, Union, Optional + +from absl import app +from absl import flags +from absl import logging +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.data import pipeline +from alphafold.data import pipeline_multimer +from alphafold.data import templates +from alphafold.data.tools import hhsearch +from alphafold.data.tools import hmmsearch +from alphafold.model import config +from alphafold.model import model +from alphafold.relax import relax +import numpy as np + +from alphafold.model import data +# Internal import (7716). + +logging.set_verbosity(logging.INFO) + +flags.DEFINE_list( + 'fasta_paths', None, 'Paths to FASTA files, each containing a prediction ' + 'target that will be folded one after another. If a FASTA file contains ' + 'multiple sequences, then it will be folded as a multimer. Paths should be ' + 'separated by commas. All FASTA paths must have a unique basename as the ' + 'basename is used to name the output directories for each prediction.') +flags.DEFINE_list( + 'is_prokaryote_list', None, 'Optional for multimer system, not used by the ' + 'single chain system. This list should contain a boolean for each fasta ' + 'specifying true where the target complex is from a prokaryote, and false ' + 'where it is not, or where the origin is unknown. These values determine ' + 'the pairing method for the MSA.') + +flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') +flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' + 'store the results.') +flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'), + 'Path to the JackHMMER executable.') +flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'), + 'Path to the HHblits executable.') +flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'), + 'Path to the HHsearch executable.') +flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'), + 'Path to the hmmsearch executable.') +flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'), + 'Path to the hmmbuild executable.') +flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'), + 'Path to the Kalign executable.') +flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 ' + 'database for use by JackHMMER.') +flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify ' + 'database for use by JackHMMER.') +flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD ' + 'database for use by HHblits.') +flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small ' + 'version of BFD used with the "reduced_dbs" preset.') +flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 ' + 'database for use by HHblits.') +flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot ' + 'database for use by JackHMMer.') +flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 ' + 'database for use by HHsearch.') +flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB ' + 'seqres database for use by hmmsearch.') +flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with ' + 'template mmCIF structures, each named <pdb_id>.cif') +flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' + 'to consider. Important if folding historical test sets.') +flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a ' + 'mapping from obsolete PDB IDs to the PDB IDs of their ' + 'replacements.') +flags.DEFINE_enum('db_preset', 'full_dbs', + ['full_dbs', 'reduced_dbs'], + 'Choose preset MSA database configuration - ' + 'smaller genetic database config (reduced_dbs) or ' + 'full genetic database config (full_dbs)') +flags.DEFINE_enum('model_preset', 'monomer', + ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'], + 'Choose preset model configuration - the monomer model, ' + 'the monomer model with extra ensembling, monomer model with ' + 'pTM head, or multimer model') +flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' + 'to obtain a timing that excludes the compilation time, ' + 'which should be more indicative of the time required for ' + 'inferencing many proteins.') +flags.DEFINE_integer('random_seed', None, 'The random seed for the data ' + 'pipeline. By default, this is randomly generated. Note ' + 'that even if this is set, Alphafold may still not be ' + 'deterministic, because processes like GPU inference are ' + 'nondeterministic.') +flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that ' + 'have been written to disk. WARNING: This will not check ' + 'if the sequence, database or configuration have changed.') + +FLAGS = flags.FLAGS + +MAX_TEMPLATE_HITS = 20 +RELAX_MAX_ITERATIONS = 0 +RELAX_ENERGY_TOLERANCE = 2.39 +RELAX_STIFFNESS = 10.0 +RELAX_EXCLUDE_RESIDUES = [] +RELAX_MAX_OUTER_ITERATIONS = 3 + + +def _check_flag(flag_name: str, + other_flag_name: str, + should_be_set: bool): + if should_be_set != bool(FLAGS[flag_name].value): + verb = 'be' if should_be_set else 'not be' + raise ValueError(f'{flag_name} must {verb} set when running with ' + f'"--{other_flag_name}={FLAGS[other_flag_name].value}".') + + +def predict_structure( + fasta_path: str, + fasta_name: str, + output_dir_base: str, + data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline], + model_runners: Dict[str, model.RunModel], + amber_relaxer: relax.AmberRelaxation, + benchmark: bool, + random_seed: int, + is_prokaryote: Optional[bool] = None): + """Predicts structure using AlphaFold for the given sequence.""" + logging.info('Predicting %s', fasta_name) + timings = {} + output_dir = os.path.join(output_dir_base, fasta_name) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + msa_output_dir = os.path.join(output_dir, 'msas') + if not os.path.exists(msa_output_dir): + os.makedirs(msa_output_dir) + + # Get features. + t_0 = time.time() + if is_prokaryote is None: + feature_dict = data_pipeline.process( + input_fasta_path=fasta_path, + msa_output_dir=msa_output_dir) + else: + feature_dict = data_pipeline.process( + input_fasta_path=fasta_path, + msa_output_dir=msa_output_dir, + is_prokaryote=is_prokaryote) + timings['features'] = time.time() - t_0 + + # Write out features as a pickled dictionary. + features_output_path = os.path.join(output_dir, 'features.pkl') + with open(features_output_path, 'wb') as f: + pickle.dump(feature_dict, f, protocol=4) + + unrelaxed_pdbs = {} + relaxed_pdbs = {} + ranking_confidences = {} + + # Run the models. + num_models = len(model_runners) + for model_index, (model_name, model_runner) in enumerate( + model_runners.items()): + logging.info('Running model %s on %s', model_name, fasta_name) + t_0 = time.time() + model_random_seed = model_index + random_seed * num_models + processed_feature_dict = model_runner.process_features( + feature_dict, random_seed=model_random_seed) + timings[f'process_features_{model_name}'] = time.time() - t_0 + + t_0 = time.time() + prediction_result = model_runner.predict(processed_feature_dict, + random_seed=model_random_seed) + t_diff = time.time() - t_0 + timings[f'predict_and_compile_{model_name}'] = t_diff + logging.info( + 'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs', + model_name, fasta_name, t_diff) + + if benchmark: + t_0 = time.time() + model_runner.predict(processed_feature_dict, + random_seed=model_random_seed) + t_diff = time.time() - t_0 + timings[f'predict_benchmark_{model_name}'] = t_diff + logging.info( + 'Total JAX model %s on %s predict time (excludes compilation time): %.1fs', + model_name, fasta_name, t_diff) + + plddt = prediction_result['plddt'] + ranking_confidences[model_name] = prediction_result['ranking_confidence'] + + # Save the model outputs. + result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl') + with open(result_output_path, 'wb') as f: + pickle.dump(prediction_result, f, protocol=4) + + # Add the predicted LDDT in the b-factor column. + # Note that higher predicted LDDT value means higher model confidence. + plddt_b_factors = np.repeat( + plddt[:, None], residue_constants.atom_type_num, axis=-1) + unrelaxed_protein = protein.from_prediction( + features=processed_feature_dict, + result=prediction_result, + b_factors=plddt_b_factors, + remove_leading_feature_dimension=not model_runner.multimer_mode) + + unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein) + unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') + with open(unrelaxed_pdb_path, 'w') as f: + f.write(unrelaxed_pdbs[model_name]) + + if amber_relaxer: + # Relax the prediction. + t_0 = time.time() + relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) + timings[f'relax_{model_name}'] = time.time() - t_0 + + relaxed_pdbs[model_name] = relaxed_pdb_str + + # Save the relaxed PDB. + relaxed_output_path = os.path.join( + output_dir, f'relaxed_{model_name}.pdb') + with open(relaxed_output_path, 'w') as f: + f.write(relaxed_pdb_str) + + # Rank by model confidence and write out relaxed PDBs in rank order. + ranked_order = [] + for idx, (model_name, _) in enumerate( + sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)): + ranked_order.append(model_name) + ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb') + with open(ranked_output_path, 'w') as f: + if amber_relaxer: + f.write(relaxed_pdbs[model_name]) + else: + f.write(unrelaxed_pdbs[model_name]) + + ranking_output_path = os.path.join(output_dir, 'ranking_debug.json') + with open(ranking_output_path, 'w') as f: + label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts' + f.write(json.dumps( + {label: ranking_confidences, 'order': ranked_order}, indent=4)) + + logging.info('Final timings for %s: %s', fasta_name, timings) + + timings_output_path = os.path.join(output_dir, 'timings.json') + with open(timings_output_path, 'w') as f: + f.write(json.dumps(timings, indent=4)) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + for tool_name in ( + 'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'): + if not FLAGS[f'{tool_name}_binary_path'].value: + raise ValueError(f'Could not find path to the "{tool_name}" binary. Make ' + 'sure it is installed on your system.') + + use_small_bfd = FLAGS.db_preset == 'reduced_dbs' + _check_flag('small_bfd_database_path', 'db_preset', + should_be_set=use_small_bfd) + _check_flag('bfd_database_path', 'db_preset', + should_be_set=not use_small_bfd) + _check_flag('uniclust30_database_path', 'db_preset', + should_be_set=not use_small_bfd) + + run_multimer_system = 'multimer' in FLAGS.model_preset + _check_flag('pdb70_database_path', 'model_preset', + should_be_set=not run_multimer_system) + _check_flag('pdb_seqres_database_path', 'model_preset', + should_be_set=run_multimer_system) + _check_flag('uniprot_database_path', 'model_preset', + should_be_set=run_multimer_system) + + if FLAGS.model_preset == 'monomer_casp14': + num_ensemble = 8 + else: + num_ensemble = 1 + + # Check for duplicate FASTA file names. + fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths] + if len(fasta_names) != len(set(fasta_names)): + raise ValueError('All FASTA paths must have a unique basename.') + + # Check that is_prokaryote_list has same number of elements as fasta_paths, + # and convert to bool. + if FLAGS.is_prokaryote_list: + if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths): + raise ValueError('--is_prokaryote_list must either be omitted or match ' + 'length of --fasta_paths.') + is_prokaryote_list = [] + for s in FLAGS.is_prokaryote_list: + if s in ('true', 'false'): + is_prokaryote_list.append(s == 'true') + else: + raise ValueError('--is_prokaryote_list must contain comma separated ' + 'true or false values.') + else: # Default is_prokaryote to False. + is_prokaryote_list = [False] * len(fasta_names) + + if run_multimer_system: + template_searcher = hmmsearch.Hmmsearch( + binary_path=FLAGS.hmmsearch_binary_path, + hmmbuild_binary_path=FLAGS.hmmbuild_binary_path, + database_path=FLAGS.pdb_seqres_database_path) + template_featurizer = templates.HmmsearchHitFeaturizer( + mmcif_dir=FLAGS.template_mmcif_dir, + max_template_date=FLAGS.max_template_date, + max_hits=MAX_TEMPLATE_HITS, + kalign_binary_path=FLAGS.kalign_binary_path, + release_dates_path=None, + obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) + else: + template_searcher = hhsearch.HHSearch( + binary_path=FLAGS.hhsearch_binary_path, + databases=[FLAGS.pdb70_database_path]) + template_featurizer = templates.HhsearchHitFeaturizer( + mmcif_dir=FLAGS.template_mmcif_dir, + max_template_date=FLAGS.max_template_date, + max_hits=MAX_TEMPLATE_HITS, + kalign_binary_path=FLAGS.kalign_binary_path, + release_dates_path=None, + obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) + + monomer_data_pipeline = pipeline.DataPipeline( + jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, + hhblits_binary_path=FLAGS.hhblits_binary_path, + uniref90_database_path=FLAGS.uniref90_database_path, + mgnify_database_path=FLAGS.mgnify_database_path, + bfd_database_path=FLAGS.bfd_database_path, + uniclust30_database_path=FLAGS.uniclust30_database_path, + small_bfd_database_path=FLAGS.small_bfd_database_path, + template_searcher=template_searcher, + template_featurizer=template_featurizer, + use_small_bfd=use_small_bfd, + use_precomputed_msas=FLAGS.use_precomputed_msas) + + if run_multimer_system: + data_pipeline = pipeline_multimer.DataPipeline( + monomer_data_pipeline=monomer_data_pipeline, + jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, + uniprot_database_path=FLAGS.uniprot_database_path, + use_precomputed_msas=FLAGS.use_precomputed_msas) + else: + data_pipeline = monomer_data_pipeline + + model_runners = {} + model_names = config.MODEL_PRESETS[FLAGS.model_preset] + for model_name in model_names: + model_config = config.model_config(model_name) + if run_multimer_system: + model_config.model.num_ensemble_eval = num_ensemble + else: + model_config.data.eval.num_ensemble = num_ensemble + model_params = data.get_model_haiku_params( + model_name=model_name, data_dir=FLAGS.data_dir) + model_runner = model.RunModel(model_config, model_params) + model_runners[model_name] = model_runner + + logging.info('Have %d models: %s', len(model_runners), + list(model_runners.keys())) + + amber_relaxer = relax.AmberRelaxation( + max_iterations=RELAX_MAX_ITERATIONS, + tolerance=RELAX_ENERGY_TOLERANCE, + stiffness=RELAX_STIFFNESS, + exclude_residues=RELAX_EXCLUDE_RESIDUES, + max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS) + + random_seed = FLAGS.random_seed + if random_seed is None: + random_seed = random.randrange(sys.maxsize // len(model_names)) + logging.info('Using random seed %d for the data pipeline', random_seed) + + # Predict structure for each of the sequences. + for i, fasta_path in enumerate(FLAGS.fasta_paths): + is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None + fasta_name = fasta_names[i] + predict_structure( + fasta_path=fasta_path, + fasta_name=fasta_name, + output_dir_base=FLAGS.output_dir, + data_pipeline=data_pipeline, + model_runners=model_runners, + amber_relaxer=amber_relaxer, + benchmark=FLAGS.benchmark, + random_seed=random_seed, + is_prokaryote=is_prokaryote) + + +if __name__ == '__main__': + flags.mark_flags_as_required([ + 'fasta_paths', + 'output_dir', + 'data_dir', + 'uniref90_database_path', + 'mgnify_database_path', + 'template_mmcif_dir', + 'max_template_date', + 'obsolete_pdbs_path', + ]) + + app.run(main)