Mercurial > repos > galaxy-australia > alphafold2
comparison view: outputs.py @ 14:d00e15139065 draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit d490defa32d9c318137d2d781243b392cb14110d-dirty
author | galaxy-australia |
---|---|
date | Tue, 28 Feb 2023 01:15:42 +0000 |
parents | |
children | a58f7eb0df2c |
comparison
equal
deleted
inserted
replaced
13:c0e71cb2bd1b | 14:d00e15139065 |
---|---|
"""Generate additional output files not produced by AlphaFold.

Currently this includes:
- model confidence scores
- per-residue confidence scores (pLDDTs - optional output)
- model_*.pkl files renamed with rank order

N.B. There have been issues with this script breaking between AlphaFold
versions due to minor changes in the output directory structure across minor
versions. It will likely need updating with future releases of AlphaFold.

This code is more complex than you might expect due to the output files
'moving around' considerably, depending on run parameters. You will see that
several output paths are determined dynamically.
"""
16 | |
17 import argparse | |
18 import json | |
19 import os | |
20 import pickle as pk | |
21 import shutil | |
22 from pathlib import Path | |
23 from typing import List | |
24 | |
# Output file names, all written under OUTPUT_DIR inside the workdir.
# 'model_pkl' is a template: {rank} is filled in per model.
OUTPUT_DIR = 'extra'
OUTPUTS = {
    'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
    'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
    'plddts': OUTPUT_DIR + '/plddts.tsv',
    'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
}

# Keys for accessing confidence data from JSON/pkl files.
# They change depending on whether the run was monomer or multimer.
PLDDT_KEY = {
    'monomer': 'plddts',
    'multimer': 'iptm+ptm',
}
40 | |
41 | |
class Settings:
    """Parse and store settings/config.

    All attributes are populated by ``parse_settings()``; defaults are set
    in ``__init__`` so the object is safe to inspect before parsing (the
    original left ``output_model_pkls`` and ``output_dir`` undefined until
    ``parse_settings()`` ran).
    """

    def __init__(self):
        self.workdir = None
        self.output_confidence_scores = True
        self.output_residue_scores = False
        self.output_model_pkls = False
        self.is_multimer = False
        self.output_dir = None

    def parse_settings(self) -> None:
        """Read command-line arguments and create the extra-output dir."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str
        )
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true"
        )
        parser.add_argument(
            "-m",
            "--multimer",
            help="parse output from AlphaFold multimer",
            action="store_true"
        )
        parser.add_argument(
            "--model-pkl",
            dest="model_pkl",
            help="rename model pkl outputs with rank order",
            action="store_true"
        )
        args = parser.parse_args()
        self.workdir = Path(args.workdir.rstrip('/'))
        self.output_residue_scores = args.plddts
        self.output_model_pkls = args.model_pkl
        self.is_multimer = args.multimer
        self.output_dir = self.workdir / OUTPUT_DIR
        os.makedirs(self.output_dir, exist_ok=True)
82 | |
83 | |
class ExecutionContext:
    """Collect file paths etc. within the AlphaFold output directory."""

    def __init__(self, settings: 'Settings'):
        self.settings = settings
        # Monomer and multimer runs store confidence data under
        # different JSON keys.
        if settings.is_multimer:
            self.plddt_key = PLDDT_KEY['multimer']
        else:
            self.plddt_key = PLDDT_KEY['monomer']

    def get_model_key(self, ix: int) -> str:
        """Return json key for model index.

        The key format changed between minor AlphaFold versions so this
        function determines the correct key by prefix-matching.

        Raises:
            KeyError: if no key matching ``model_{ix}_`` exists.
        """
        with open(self.ranking_debug) as f:
            data = json.load(f)
        model_keys = list(data[self.plddt_key].keys())
        for k in model_keys:
            if k.startswith(f"model_{ix}_"):
                return k
        # BUG FIX: the exception was previously *returned* rather than
        # raised, so callers silently received a KeyError instance.
        raise KeyError(
            f'Could not find key for index={ix} in'
            ' ranking_debug.json')

    @property
    def ranking_debug(self) -> Path:
        return self.settings.workdir / 'ranking_debug.json'

    @property
    def relax_metrics(self) -> Path:
        return self.settings.workdir / 'relax_metrics.json'

    @property
    def relax_metrics_ranked(self) -> Path:
        return self.settings.workdir / 'relax_metrics_ranked.json'

    @property
    def model_pkl_paths(self) -> List[Path]:
        # Sorted so iteration order is deterministic across runs.
        return sorted([
            self.settings.workdir / f
            for f in os.listdir(self.settings.workdir)
            if f.startswith('result_model_') and f.endswith('.pkl')
        ])
128 | |
129 | |
class ResultModelPrediction:
    """Load and manipulate data from result_model_*.pkl files."""

    def __init__(self, path: str, context: 'ExecutionContext'):
        self.context = context
        self.path = path
        # e.g. 'result_model_1_pred_0.pkl' -> 'model_1_pred_0'
        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
        # FIX: the file handle previously shadowed the *path* parameter
        # (``with open(path, 'rb') as path:``).
        with open(path, 'rb') as fh:
            self.data = pk.load(fh)

    @property
    def plddts(self) -> List[float]:
        """Return pLDDT scores for each residue."""
        return list(self.data['plddt'])
143 | |
144 | |
class ResultRanking:
    """Load and manipulate data from the ranking_debug.json file."""

    def __init__(self, context: ExecutionContext):
        self.path = context.ranking_debug
        self.context = context
        with open(self.path, 'r') as handle:
            self.data = json.load(handle)

    @property
    def order(self) -> List[str]:
        """Ordered list of model names, best-ranked model first."""
        return self.data['order']

    def get_plddt_for_rank(self, rank: int) -> List[float]:
        """Return the confidence score for the model at 1-indexed *rank*."""
        model_name = self.data['order'][rank - 1]
        return self.data[self.context.plddt_key][model_name]

    def get_rank_for_model(self, model_name: str) -> int:
        """Return 0-indexed rank for given model name.

        Model names are expressed in result_model_*.pkl file names.
        """
        ranked_names = self.data['order']
        return ranked_names.index(model_name)
169 | |
170 | |
def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
    """Write per-model confidence scores as a TSV, one row per rank.

    Generalized from a hard-coded five models to however many models
    appear in ranking_debug.json (the original raised IndexError when
    fewer than five models were present).
    """
    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
    with open(path, 'w') as f:
        for rank in range(1, len(ranking.order) + 1):
            score = ranking.get_plddt_for_rank(rank)
            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
178 | |
179 | |
def write_per_residue_scores(
        ranking: ResultRanking,
        context: ExecutionContext,
):
    """Write per-residue pLDDTs for each model.

    A row of pLDDT values is written for each model in tabular format,
    rows ordered by model rank. (Removed an unused ``enumerate`` index
    from the collection loop.)
    """
    model_plddts = {}
    for path in context.model_pkl_paths:
        model = ResultModelPrediction(path, context)
        rank = ranking.get_rank_for_model(model.name)
        model_plddts[rank] = model.plddts

    out_path = context.settings.workdir / OUTPUTS['plddts']
    with open(out_path, 'w') as f:
        for rank in sorted(model_plddts):
            row = [f'ranked_{rank}'] + [str(x) for x in model_plddts[rank]]
            f.write('\t'.join(row) + '\n')
201 | |
202 | |
def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
    """Rename model.pkl files so the rank order is implicit."""
    for src in context.model_pkl_paths:
        prediction = ResultModelPrediction(src, context)
        rank = ranking.get_rank_for_model(prediction.name)
        dest = (
            context.settings.workdir
            / OUTPUTS['model_pkl'].format(rank=rank)
        )
        shutil.copyfile(src, dest)
213 | |
214 | |
def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
    """Replace keys in relax_metrics.json with 0-indexed rank."""
    with open(context.relax_metrics) as f:
        metrics = json.load(f)
    # Iterate over a snapshot of the keys: the dict is mutated in-place.
    for model_name in list(metrics):
        rank = ranking.get_rank_for_model(model_name)
        metrics[f'ranked_{rank}'] = metrics.pop(model_name)
    out_path = context.settings.workdir / OUTPUTS['relax']
    with open(out_path, 'w') as f:
        json.dump(metrics, f)
225 | |
226 | |
def main():
    """Parse output files and generate additional output files."""
    settings = Settings()
    settings.parse_settings()
    context = ExecutionContext(settings)
    ranking = ResultRanking(context)

    # Always-produced outputs
    write_confidence_scores(ranking, context)
    rekey_relax_metrics(ranking, context)

    # Optional outputs, gated by command-line flags
    if settings.output_model_pkls:
        rename_model_pkls(ranking, context)
    if settings.output_residue_scores:
        write_per_residue_scores(ranking, context)
242 | |
243 | |
# Script entry point: only run when executed directly, not on import.
if __name__ == '__main__':
    main()