comparison docker/alphafold/run_alphafold.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
comparison
equal deleted inserted replaced
0:7ae9d78b06f5 1:6c92e000d684
1 # Copyright 2021 DeepMind Technologies Limited
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Full AlphaFold protein structure prediction script."""
16 import json
17 import os
18 import pathlib
19 import pickle
20 import random
21 import shutil
22 import sys
23 import time
24 from typing import Dict, Union, Optional
25
26 from absl import app
27 from absl import flags
28 from absl import logging
29 from alphafold.common import protein
30 from alphafold.common import residue_constants
31 from alphafold.data import pipeline
32 from alphafold.data import pipeline_multimer
33 from alphafold.data import templates
34 from alphafold.data.tools import hhsearch
35 from alphafold.data.tools import hmmsearch
36 from alphafold.model import config
37 from alphafold.model import model
38 from alphafold.relax import relax
39 import numpy as np
40
41 from alphafold.model import data
42 # Internal import (7716).
43
44 logging.set_verbosity(logging.INFO)
45
46 flags.DEFINE_list(
47 'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
48 'target that will be folded one after another. If a FASTA file contains '
49 'multiple sequences, then it will be folded as a multimer. Paths should be '
50 'separated by commas. All FASTA paths must have a unique basename as the '
51 'basename is used to name the output directories for each prediction.')
52 flags.DEFINE_list(
53 'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
54 'single chain system. This list should contain a boolean for each fasta '
55 'specifying true where the target complex is from a prokaryote, and false '
56 'where it is not, or where the origin is unknown. These values determine '
57 'the pairing method for the MSA.')
58
59 flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
60 flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
61 'store the results.')
62 flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'),
63 'Path to the JackHMMER executable.')
64 flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'),
65 'Path to the HHblits executable.')
66 flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'),
67 'Path to the HHsearch executable.')
68 flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'),
69 'Path to the hmmsearch executable.')
70 flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'),
71 'Path to the hmmbuild executable.')
72 flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'),
73 'Path to the Kalign executable.')
74 flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 '
75 'database for use by JackHMMER.')
76 flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify '
77 'database for use by JackHMMER.')
78 flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD '
79 'database for use by HHblits.')
80 flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small '
81 'version of BFD used with the "reduced_dbs" preset.')
82 flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 '
83 'database for use by HHblits.')
84 flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot '
85 'database for use by JackHMMer.')
86 flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 '
87 'database for use by HHsearch.')
88 flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB '
89 'seqres database for use by hmmsearch.')
90 flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with '
91 'template mmCIF structures, each named <pdb_id>.cif')
92 flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
93 'to consider. Important if folding historical test sets.')
94 flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
95 'mapping from obsolete PDB IDs to the PDB IDs of their '
96 'replacements.')
97 flags.DEFINE_enum('db_preset', 'full_dbs',
98 ['full_dbs', 'reduced_dbs'],
99 'Choose preset MSA database configuration - '
100 'smaller genetic database config (reduced_dbs) or '
101 'full genetic database config (full_dbs)')
102 flags.DEFINE_enum('model_preset', 'monomer',
103 ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
104 'Choose preset model configuration - the monomer model, '
105 'the monomer model with extra ensembling, monomer model with '
106 'pTM head, or multimer model')
107 flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
108 'to obtain a timing that excludes the compilation time, '
109 'which should be more indicative of the time required for '
110 'inferencing many proteins.')
111 flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
112 'pipeline. By default, this is randomly generated. Note '
113 'that even if this is set, Alphafold may still not be '
114 'deterministic, because processes like GPU inference are '
115 'nondeterministic.')
116 flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
117 'have been written to disk. WARNING: This will not check '
118 'if the sequence, database or configuration have changed.')
119
120 FLAGS = flags.FLAGS
121
122 MAX_TEMPLATE_HITS = 20
123 RELAX_MAX_ITERATIONS = 0
124 RELAX_ENERGY_TOLERANCE = 2.39
125 RELAX_STIFFNESS = 10.0
126 RELAX_EXCLUDE_RESIDUES = []
127 RELAX_MAX_OUTER_ITERATIONS = 3
128
129
def _check_flag(flag_name: str,
                other_flag_name: str,
                should_be_set: bool):
  """Ensures a flag's presence matches what another flag's value demands.

  Args:
    flag_name: Name of the flag whose presence is validated.
    other_flag_name: Name of the flag whose current value determines the
      requirement; its name and value are quoted in the error message.
    should_be_set: True if `flag_name` must hold a truthy value, False if it
      must be unset (falsy).

  Raises:
    ValueError: If the truthiness of `flag_name`'s value differs from
      `should_be_set`.
  """
  flag_is_set = bool(FLAGS[flag_name].value)
  if flag_is_set == should_be_set:
    return
  requirement = 'be' if should_be_set else 'not be'
  raise ValueError(
      f'{flag_name} must {requirement} set when running with '
      f'"--{other_flag_name}={FLAGS[other_flag_name].value}".')
137
138
def predict_structure(
    fasta_path: str,
    fasta_name: str,
    output_dir_base: str,
    data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline],
    model_runners: Dict[str, model.RunModel],
    amber_relaxer: relax.AmberRelaxation,
    benchmark: bool,
    random_seed: int,
    is_prokaryote: Optional[bool] = None):
  """Predicts structure using AlphaFold for the given sequence.

  Runs `data_pipeline` on `fasta_path`, then every model in `model_runners`,
  and writes the following into `<output_dir_base>/<fasta_name>/`:
    * msas/                  -- directory passed to the pipeline for MSAs.
    * features.pkl           -- the pickled feature dictionary.
    * result_<model>.pkl     -- pickled raw prediction result per model.
    * unrelaxed_<model>.pdb  -- unrelaxed structure, pLDDT in the B-factors.
    * relaxed_<model>.pdb    -- Amber-relaxed structure (only if
                                `amber_relaxer` is truthy).
    * ranked_<i>.pdb         -- structures in descending ranking confidence.
    * ranking_debug.json     -- per-model ranking confidence and the order.
    * timings.json           -- wall-clock seconds of each stage.

  Args:
    fasta_path: Path of the FASTA file holding the prediction target.
    fasta_name: Name of the target; used as the output subdirectory name.
    output_dir_base: Directory under which the per-target directory is made.
    data_pipeline: Pipeline whose `process()` produces the feature dict.
    model_runners: Mapping from model name to its runner, run in dict order.
    amber_relaxer: Relaxer applied to each unrelaxed prediction; falsy to
      skip relaxation (ranked PDBs are then the unrelaxed ones).
    benchmark: If True, run each model's prediction a second time to record
      a timing that excludes JAX compilation.
    random_seed: Base seed; model `i` of `n` uses `i + random_seed * n`.
    is_prokaryote: Forwarded to `data_pipeline.process()` only when not None.
  """
  logging.info('Predicting %s', fasta_name)
  timings = {}
  output_dir = os.path.join(output_dir_base, fasta_name)
  # exist_ok=True avoids the race of a separate os.path.exists() check.
  os.makedirs(output_dir, exist_ok=True)
  msa_output_dir = os.path.join(output_dir, 'msas')
  os.makedirs(msa_output_dir, exist_ok=True)

  # Get features. `is_prokaryote` is only passed when given, so pipelines
  # whose process() does not accept it can be used with the default None.
  t_0 = time.time()
  process_kwargs = {}
  if is_prokaryote is not None:
    process_kwargs['is_prokaryote'] = is_prokaryote
  feature_dict = data_pipeline.process(
      input_fasta_path=fasta_path,
      msa_output_dir=msa_output_dir,
      **process_kwargs)
  timings['features'] = time.time() - t_0

  # Write out features as a pickled dictionary.
  features_output_path = os.path.join(output_dir, 'features.pkl')
  with open(features_output_path, 'wb') as f:
    pickle.dump(feature_dict, f, protocol=4)

  unrelaxed_pdbs = {}
  relaxed_pdbs = {}
  ranking_confidences = {}
  # Key under which confidences are stored in ranking_debug.json. Updated from
  # each prediction result; the default keeps it defined when no model runs.
  ranking_label = 'plddts'

  # Run the models.
  num_models = len(model_runners)
  for model_index, (model_name, model_runner) in enumerate(
      model_runners.items()):
    logging.info('Running model %s on %s', model_name, fasta_name)
    t_0 = time.time()
    model_random_seed = model_index + random_seed * num_models
    processed_feature_dict = model_runner.process_features(
        feature_dict, random_seed=model_random_seed)
    timings[f'process_features_{model_name}'] = time.time() - t_0

    t_0 = time.time()
    prediction_result = model_runner.predict(processed_feature_dict,
                                             random_seed=model_random_seed)
    t_diff = time.time() - t_0
    timings[f'predict_and_compile_{model_name}'] = t_diff
    logging.info(
        'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs',
        model_name, fasta_name, t_diff)

    if benchmark:
      t_0 = time.time()
      model_runner.predict(processed_feature_dict,
                           random_seed=model_random_seed)
      t_diff = time.time() - t_0
      timings[f'predict_benchmark_{model_name}'] = t_diff
      logging.info(
          'Total JAX model %s on %s predict time (excludes compilation time): %.1fs',
          model_name, fasta_name, t_diff)

    plddt = prediction_result['plddt']
    # float() keeps the value JSON-serializable whatever NumPy scalar type the
    # model returns; for float64/float values the JSON text is unchanged.
    ranking_confidences[model_name] = float(
        prediction_result['ranking_confidence'])
    ranking_label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts'

    # Save the model outputs.
    result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
    with open(result_output_path, 'wb') as f:
      pickle.dump(prediction_result, f, protocol=4)

    # Add the predicted LDDT in the b-factor column.
    # Note that higher predicted LDDT value means higher model confidence.
    plddt_b_factors = np.repeat(
        plddt[:, None], residue_constants.atom_type_num, axis=-1)
    unrelaxed_protein = protein.from_prediction(
        features=processed_feature_dict,
        result=prediction_result,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=not model_runner.multimer_mode)

    unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
    with open(unrelaxed_pdb_path, 'w') as f:
      f.write(unrelaxed_pdbs[model_name])

    if amber_relaxer:
      # Relax the prediction.
      t_0 = time.time()
      relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
      timings[f'relax_{model_name}'] = time.time() - t_0

      relaxed_pdbs[model_name] = relaxed_pdb_str

      # Save the relaxed PDB.
      relaxed_output_path = os.path.join(
          output_dir, f'relaxed_{model_name}.pdb')
      with open(relaxed_output_path, 'w') as f:
        f.write(relaxed_pdb_str)

  # Rank by model confidence and write out the final PDBs (relaxed if a
  # relaxer was given, otherwise unrelaxed) in rank order.
  final_pdbs = relaxed_pdbs if amber_relaxer else unrelaxed_pdbs
  ranked_order = []
  for idx, (model_name, _) in enumerate(
      sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
    ranked_order.append(model_name)
    ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
    with open(ranked_output_path, 'w') as f:
      f.write(final_pdbs[model_name])

  ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
  with open(ranking_output_path, 'w') as f:
    f.write(json.dumps(
        {ranking_label: ranking_confidences, 'order': ranked_order}, indent=4))

  logging.info('Final timings for %s: %s', fasta_name, timings)

  timings_output_path = os.path.join(output_dir, 'timings.json')
  with open(timings_output_path, 'w') as f:
    f.write(json.dumps(timings, indent=4))
271
272
def main(argv):
  """Validates flags, builds the pipelines and models, and folds each FASTA.

  Args:
    argv: Positional command-line arguments left after absl flag parsing;
      anything beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If a tool binary cannot be found, the database flags do not
      match the chosen presets, two FASTA paths share a basename, or
      --is_prokaryote_list has the wrong length or a non-boolean entry.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  for tool_name in (
      'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'):
    if not FLAGS[f'{tool_name}_binary_path'].value:
      raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
                       'sure it is installed on your system.')

  # reduced_dbs uses only small BFD; full_dbs uses BFD plus Uniclust30.
  use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
  _check_flag('small_bfd_database_path', 'db_preset',
              should_be_set=use_small_bfd)
  _check_flag('bfd_database_path', 'db_preset',
              should_be_set=not use_small_bfd)
  _check_flag('uniclust30_database_path', 'db_preset',
              should_be_set=not use_small_bfd)

  # Monomer presets search templates in PDB70 (HHsearch); the multimer preset
  # searches PDB seqres (hmmsearch) and additionally needs UniProt.
  run_multimer_system = 'multimer' in FLAGS.model_preset
  _check_flag('pdb70_database_path', 'model_preset',
              should_be_set=not run_multimer_system)
  _check_flag('pdb_seqres_database_path', 'model_preset',
              should_be_set=run_multimer_system)
  _check_flag('uniprot_database_path', 'model_preset',
              should_be_set=run_multimer_system)

  num_ensemble = 8 if FLAGS.model_preset == 'monomer_casp14' else 1

  # Check for duplicate FASTA file names.
  fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
  if len(fasta_names) != len(set(fasta_names)):
    raise ValueError('All FASTA paths must have a unique basename.')

  # Check that is_prokaryote_list has same number of elements as fasta_paths,
  # and convert to bool. Entries are matched case-insensitively and ignoring
  # surrounding whitespace, so e.g. 'True' and ' false' are accepted too.
  if FLAGS.is_prokaryote_list:
    if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths):
      raise ValueError('--is_prokaryote_list must either be omitted or match '
                       'length of --fasta_paths.')
    is_prokaryote_list = []
    for raw_value in FLAGS.is_prokaryote_list:
      normalized = raw_value.strip().lower()
      if normalized not in ('true', 'false'):
        raise ValueError('--is_prokaryote_list must contain comma separated '
                         'true or false values.')
      is_prokaryote_list.append(normalized == 'true')
  else:  # Default is_prokaryote to False.
    is_prokaryote_list = [False] * len(fasta_names)

  if run_multimer_system:
    template_searcher = hmmsearch.Hmmsearch(
        binary_path=FLAGS.hmmsearch_binary_path,
        hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
        database_path=FLAGS.pdb_seqres_database_path)
    template_featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
  else:
    template_searcher = hhsearch.HHSearch(
        binary_path=FLAGS.hhsearch_binary_path,
        databases=[FLAGS.pdb70_database_path])
    template_featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)

  monomer_data_pipeline = pipeline.DataPipeline(
      jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
      hhblits_binary_path=FLAGS.hhblits_binary_path,
      uniref90_database_path=FLAGS.uniref90_database_path,
      mgnify_database_path=FLAGS.mgnify_database_path,
      bfd_database_path=FLAGS.bfd_database_path,
      uniclust30_database_path=FLAGS.uniclust30_database_path,
      small_bfd_database_path=FLAGS.small_bfd_database_path,
      template_searcher=template_searcher,
      template_featurizer=template_featurizer,
      use_small_bfd=use_small_bfd,
      use_precomputed_msas=FLAGS.use_precomputed_msas)

  # The multimer pipeline wraps the monomer one.
  if run_multimer_system:
    data_pipeline = pipeline_multimer.DataPipeline(
        monomer_data_pipeline=monomer_data_pipeline,
        jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
        uniprot_database_path=FLAGS.uniprot_database_path,
        use_precomputed_msas=FLAGS.use_precomputed_msas)
  else:
    data_pipeline = monomer_data_pipeline

  model_runners = {}
  model_names = config.MODEL_PRESETS[FLAGS.model_preset]
  for model_name in model_names:
    model_config = config.model_config(model_name)
    # The ensemble count lives at a different config path for each system.
    if run_multimer_system:
      model_config.model.num_ensemble_eval = num_ensemble
    else:
      model_config.data.eval.num_ensemble = num_ensemble
    model_params = data.get_model_haiku_params(
        model_name=model_name, data_dir=FLAGS.data_dir)
    model_runner = model.RunModel(model_config, model_params)
    model_runners[model_name] = model_runner

  logging.info('Have %d models: %s', len(model_runners),
               list(model_runners.keys()))

  amber_relaxer = relax.AmberRelaxation(
      max_iterations=RELAX_MAX_ITERATIONS,
      tolerance=RELAX_ENERGY_TOLERANCE,
      stiffness=RELAX_STIFFNESS,
      exclude_residues=RELAX_EXCLUDE_RESIDUES,
      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)

  random_seed = FLAGS.random_seed
  if random_seed is None:
    # Bounded so that predict_structure's per-model seed
    # (index + seed * num_models) stays below sys.maxsize.
    random_seed = random.randrange(sys.maxsize // len(model_names))
  logging.info('Using random seed %d for the data pipeline', random_seed)

  # Predict structure for each of the sequences. is_prokaryote_list has one
  # entry per FASTA path (validated or defaulted above).
  for fasta_path, fasta_name, is_prokaryote_value in zip(
      FLAGS.fasta_paths, fasta_names, is_prokaryote_list):
    predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=FLAGS.output_dir,
        data_pipeline=data_pipeline,
        model_runners=model_runners,
        amber_relaxer=amber_relaxer,
        benchmark=FLAGS.benchmark,
        random_seed=random_seed,
        is_prokaryote=is_prokaryote_value if run_multimer_system else None)
413
414
if __name__ == '__main__':
  # Each of these flags is defined with a None default, so absl must reject
  # a command line that omits any of them before main() runs.
  for required_flag in (
      'fasta_paths',
      'output_dir',
      'data_dir',
      'uniref90_database_path',
      'mgnify_database_path',
      'template_mmcif_dir',
      'max_template_date',
      'obsolete_pdbs_path',
  ):
    flags.mark_flag_as_required(required_flag)

  app.run(main)