Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/docker/run_docker.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Docker launch script for Alphafold docker image.""" | |
16 | |
17 import os | |
18 import pathlib | |
19 import signal | |
20 from typing import Tuple | |
21 | |
22 from absl import app | |
23 from absl import flags | |
24 from absl import logging | |
25 import docker | |
26 from docker import types | |
27 | |
28 | |
# Command-line flags for this launcher. All paths below are host paths;
# main() converts them into bind-mounted paths inside the container.
flags.DEFINE_bool(
    name='use_gpu',
    default=True,
    help='Enable NVIDIA runtime to run with GPUs.')
flags.DEFINE_string(
    name='gpu_devices',
    default='all',
    help='Comma separated list of devices to pass to NVIDIA_VISIBLE_DEVICES.')
flags.DEFINE_list(
    name='fasta_paths',
    default=None,
    help=('Paths to FASTA files, each containing a prediction '
          'target that will be folded one after another. If a FASTA file contains '
          'multiple sequences, then it will be folded as a multimer. Paths should be '
          'separated by commas. All FASTA paths must have a unique basename as the '
          'basename is used to name the output directories for each prediction.'))
flags.DEFINE_list(
    name='is_prokaryote_list',
    default=None,
    help=('Optional for multimer system, not used by the '
          'single chain system. This list should contain a boolean for each fasta '
          'specifying true where the target complex is from a prokaryote, and false '
          'where it is not, or where the origin is unknown. These values determine '
          'the pairing method for the MSA.'))
flags.DEFINE_string(
    name='output_dir',
    default='/tmp/alphafold',
    help='Path to a directory that will store the results.')
flags.DEFINE_string(
    name='data_dir',
    default=None,
    help=('Path to directory with supporting data: AlphaFold parameters and genetic '
          'and template databases. Set to the target of download_all_databases.sh.'))
flags.DEFINE_string(
    name='docker_image_name',
    default='alphafold',
    help='Name of the AlphaFold Docker image.')
flags.DEFINE_string(
    name='max_template_date',
    default=None,
    help=('Maximum template release date to consider (ISO-8601 format: YYYY-MM-DD). '
          'Important if folding historical test sets.'))
flags.DEFINE_enum(
    name='db_preset',
    default='full_dbs',
    enum_values=['full_dbs', 'reduced_dbs'],
    help=('Choose preset MSA database configuration - smaller genetic database '
          'config (reduced_dbs) or full genetic database config (full_dbs)'))
flags.DEFINE_enum(
    name='model_preset',
    default='monomer',
    enum_values=['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
    help=('Choose preset model configuration - the monomer model, the monomer model '
          'with extra ensembling, monomer model with pTM head, or multimer model'))
flags.DEFINE_boolean(
    name='benchmark',
    default=False,
    help=('Run multiple JAX model evaluations to obtain a timing that excludes the '
          'compilation time, which should be more indicative of the time required '
          'for inferencing many proteins.'))
flags.DEFINE_boolean(
    name='use_precomputed_msas',
    default=False,
    help=('Whether to read MSAs that have been written to disk. WARNING: This will '
          'not check if the sequence, database or configuration have changed.'))

FLAGS = flags.FLAGS

# Container-side directory under which every bind mount is placed.
_ROOT_MOUNT_DIRECTORY = '/mnt/'
81 | |
82 | |
def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]:
  """Builds a read-only bind mount exposing the directory that holds `path`.

  The parent directory of the absolute form of `path` is bound to
  `_ROOT_MOUNT_DIRECTORY/<mount_name>` inside the container.

  Args:
    mount_name: Name of the container-side directory under the mount root.
    path: Host path to a file or directory; only its parent is mounted.

  Returns:
    A `(mount, container_path)` pair, where `container_path` is where `path`
    is visible inside the container.
  """
  absolute = os.path.abspath(path)
  host_dir = os.path.dirname(absolute)
  container_dir = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name)
  logging.info('Mounting %s -> %s', host_dir, container_dir)
  bind = types.Mount(container_dir, host_dir, type='bind', read_only=True)
  container_path = os.path.join(container_dir, os.path.basename(absolute))
  return bind, container_path
90 | |
91 | |
def main(argv):
  """Runs AlphaFold inside its Docker image with all inputs bind-mounted.

  Builds bind mounts for every FASTA file, for each database required by the
  selected --model_preset/--db_preset, and for the output directory. It then
  starts the container detached and forwards its log stream to absl logging
  until the stream ends. SIGINT is redirected to kill the container.

  Args:
    argv: Positional arguments left after flag parsing; only the program name
      is accepted.

  Raises:
    app.UsageError: If extra positional arguments are supplied, or if
      --data_dir is the AlphaFold repository directory or lies inside it.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # You can individually override the following paths if you have placed the
  # data in locations other than the FLAGS.data_dir.

  # Path to the Uniref90 database for use by JackHMMER.
  uniref90_database_path = os.path.join(
      FLAGS.data_dir, 'uniref90', 'uniref90.fasta')

  # Path to the Uniprot database for use by JackHMMER.
  uniprot_database_path = os.path.join(
      FLAGS.data_dir, 'uniprot', 'uniprot.fasta')

  # Path to the MGnify database for use by JackHMMER.
  mgnify_database_path = os.path.join(
      FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa')

  # Path to the BFD database for use by HHblits.
  bfd_database_path = os.path.join(
      FLAGS.data_dir, 'bfd',
      'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt')

  # Path to the Small BFD database for use by JackHMMER.
  small_bfd_database_path = os.path.join(
      FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')

  # Path to the Uniclust30 database for use by HHblits.
  uniclust30_database_path = os.path.join(
      FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')

  # Path to the PDB70 database for use by HHsearch.
  pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70')

  # Path to the PDB seqres database for use by hmmsearch.
  pdb_seqres_database_path = os.path.join(
      FLAGS.data_dir, 'pdb_seqres', 'pdb_seqres.txt')

  # Path to a directory with template mmCIF structures, each named <pdb_id>.cif.
  template_mmcif_dir = os.path.join(FLAGS.data_dir, 'pdb_mmcif', 'mmcif_files')

  # Path to a file mapping obsolete PDB IDs to their replacements.
  obsolete_pdbs_path = os.path.join(FLAGS.data_dir, 'pdb_mmcif', 'obsolete.dat')

  # Resolve both sides before comparing. Otherwise a relative __file__
  # (e.g. `python docker/run_docker.py`) or a relative --data_dir makes the
  # containment check silently miss.
  alphafold_path = pathlib.Path(__file__).resolve().parent.parent
  data_dir_path = pathlib.Path(FLAGS.data_dir).resolve()
  if alphafold_path == data_dir_path or alphafold_path in data_dir_path.parents:
    raise app.UsageError(
        f'The download directory {FLAGS.data_dir} should not be a subdirectory '
        f'in the AlphaFold repository directory. If it is, the Docker build is '
        f'slow since the large databases are copied during the image creation.')

  mounts = []
  command_args = []

  # Mount each fasta path as a unique target directory.
  target_fasta_paths = []
  for i, fasta_path in enumerate(FLAGS.fasta_paths):
    mount, target_path = _create_mount(f'fasta_path_{i}', fasta_path)
    mounts.append(mount)
    target_fasta_paths.append(target_path)
  command_args.append(f'--fasta_paths={",".join(target_fasta_paths)}')

  database_paths = [
      ('uniref90_database_path', uniref90_database_path),
      ('mgnify_database_path', mgnify_database_path),
      ('data_dir', FLAGS.data_dir),
      ('template_mmcif_dir', template_mmcif_dir),
      ('obsolete_pdbs_path', obsolete_pdbs_path),
  ]

  # Multimer runs need UniProt and PDB seqres; monomer runs need PDB70.
  if FLAGS.model_preset == 'multimer':
    database_paths.append(('uniprot_database_path', uniprot_database_path))
    database_paths.append(('pdb_seqres_database_path',
                           pdb_seqres_database_path))
  else:
    database_paths.append(('pdb70_database_path', pdb70_database_path))

  # reduced_dbs swaps BFD + Uniclust30 for the small BFD database.
  if FLAGS.db_preset == 'reduced_dbs':
    database_paths.append(('small_bfd_database_path', small_bfd_database_path))
  else:
    database_paths.extend([
        ('uniclust30_database_path', uniclust30_database_path),
        ('bfd_database_path', bfd_database_path),
    ])
  for name, path in database_paths:
    if path:
      mount, target_path = _create_mount(name, path)
      mounts.append(mount)
      command_args.append(f'--{name}={target_path}')

  # Docker's Mount API rejects bind sources that are relative or do not
  # exist, so normalise --output_dir and create it before starting.
  output_source_path = os.path.abspath(FLAGS.output_dir)
  os.makedirs(output_source_path, exist_ok=True)
  output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output')
  mounts.append(
      types.Mount(output_target_path, output_source_path, type='bind'))

  command_args.extend([
      f'--output_dir={output_target_path}',
      f'--max_template_date={FLAGS.max_template_date}',
      f'--db_preset={FLAGS.db_preset}',
      f'--model_preset={FLAGS.model_preset}',
      f'--benchmark={FLAGS.benchmark}',
      f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
      '--logtostderr',
  ])

  if FLAGS.is_prokaryote_list:
    command_args.append(
        f'--is_prokaryote_list={",".join(FLAGS.is_prokaryote_list)}')

  client = docker.from_env()
  container = client.containers.run(
      image=FLAGS.docker_image_name,
      command=command_args,
      runtime='nvidia' if FLAGS.use_gpu else None,
      remove=True,
      detach=True,
      mounts=mounts,
      environment={
          'NVIDIA_VISIBLE_DEVICES': FLAGS.gpu_devices,
          # The following flags allow us to make predictions on proteins that
          # would typically be too long to fit into GPU memory.
          'TF_FORCE_UNIFIED_MEMORY': '1',
          'XLA_PYTHON_CLIENT_MEM_FRACTION': '4.0',
      })

  # Add signal handler to ensure CTRL+C also stops the running container.
  signal.signal(signal.SIGINT,
                lambda unused_sig, unused_frame: container.kill())

  # Log chunks are raw bytes. Replace undecodable bytes rather than crash the
  # launcher partway through a long run.
  for line in container.logs(stream=True):
    logging.info(line.strip().decode('utf-8', errors='replace'))
223 | |
224 | |
if __name__ == '__main__':
  # These flags default to None and main() cannot run without them, so let
  # absl reject the invocation up front if any is missing.
  required_flags = ('data_dir', 'fasta_paths', 'max_template_date')
  flags.mark_flags_as_required(required_flags)
  app.run(main)