Mercurial > repos > galaxy-australia > alphafold2
comparison scripts/outputs.py @ 21:e7f1b552a695 draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 628c9fdcb77489063145a2307b6bb6a450416dd6-dirty
author | galaxy-australia |
---|---|
date | Tue, 29 Oct 2024 02:15:36 +0000 |
parents | 6ab1a261520a |
children | 3f188450ca4f |
comparison
equal
deleted
inserted
replaced
20:6ab1a261520a | 21:e7f1b552a695 |
---|---|
14 several output paths are determined dynamically. | 14 several output paths are determined dynamically. |
15 """ | 15 """ |
16 | 16 |
17 import argparse | 17 import argparse |
18 import json | 18 import json |
19 import numpy as np | |
19 import os | 20 import os |
20 import pickle as pk | 21 import pickle as pk |
21 import shutil | 22 import shutil |
22 from pathlib import Path | 23 from pathlib import Path |
23 from typing import List | 24 from typing import Dict, List |
24 | 25 |
25 from matplotlib import pyplot as plt | 26 from matplotlib import pyplot as plt |
26 | 27 |
27 # Output file paths | 28 # Output file paths |
28 OUTPUT_DIR = 'extra' | 29 OUTPUT_DIR = 'extra' |
31 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', | 32 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', |
32 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', | 33 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', |
33 'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv', | 34 'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv', |
34 'plddts': OUTPUT_DIR + '/plddts.tsv', | 35 'plddts': OUTPUT_DIR + '/plddts.tsv', |
35 'relax': OUTPUT_DIR + '/relax_metrics_ranked.json', | 36 'relax': OUTPUT_DIR + '/relax_metrics_ranked.json', |
36 } | 37 'msa': OUTPUT_DIR + '/msa_coverage.png', |
37 | |
38 # Keys for accessing confidence data from JSON/pkl files | |
39 # They change depending on whether the run was monomer or multimer | |
40 PLDDT_KEY = { | |
41 'monomer': 'plddts', | |
42 'multimer': 'iptm+ptm', | |
43 } | 38 } |
44 | 39 |
45 HTML_PATH = Path(__file__).parent / "alphafold.html" | 40 HTML_PATH = Path(__file__).parent / "alphafold.html" |
46 HTML_OUTPUT_FILENAME = 'alphafold.html' | 41 HTML_OUTPUT_FILENAME = 'alphafold.html' |
47 HTML_BUTTON_ATTR = 'class="btn" id="btn-ranked_{rank}"' | 42 HTML_BUTTON_ATTR = 'class="btn" id="btn-ranked_{rank}"' |
48 HTML_BUTTON_ATTR_DISABLED = ( | 43 HTML_BUTTON_ATTR_DISABLED = ( |
49 'class="btn disabled" id="btn-ranked_{rank}" disabled') | 44 'class="btn disabled" id="btn-ranked_{rank}" disabled') |
45 | |
46 | |
47 class PLDDT_KEY: | |
48 """Dict keys for accessing confidence data from JSON/pkl files." | |
49 Changes depending on which model PRESET was used. | |
50 """ | |
51 monomer = 'plddts' | |
52 multimer = 'iptm+ptm' | |
53 | |
54 | |
55 class PRESETS: | |
56 monomer = 'monomer' | |
57 monomer_ptm = 'monomer_ptm' | |
58 multimer = 'multimer' | |
50 | 59 |
51 | 60 |
52 class Settings: | 61 class Settings: |
53 """Parse and store settings/config.""" | 62 """Parse and store settings/config.""" |
54 def __init__(self): | 63 def __init__(self): |
61 def parse(self) -> None: | 70 def parse(self) -> None: |
62 parser = argparse.ArgumentParser() | 71 parser = argparse.ArgumentParser() |
63 parser.add_argument( | 72 parser.add_argument( |
64 "workdir", | 73 "workdir", |
65 help="alphafold output directory", | 74 help="alphafold output directory", |
66 type=str | 75 type=str, |
67 ) | 76 ) |
68 parser.add_argument( | 77 parser.add_argument( |
69 "-p", | 78 "-s", |
70 "--plddts", | 79 "--confidence-scores", |
71 help="output per-residue confidence scores (pLDDTs)", | 80 help="output per-residue confidence scores (pLDDTs)", |
72 action="store_true" | 81 action="store_true", |
73 ) | |
74 parser.add_argument( | |
75 "-m", | |
76 "--multimer", | |
77 help="parse output from AlphaFold multimer", | |
78 action="store_true" | |
79 ) | 82 ) |
80 parser.add_argument( | 83 parser.add_argument( |
81 "--pkl", | 84 "--pkl", |
82 help="rename model pkl outputs with rank order", | 85 help="rename model pkl outputs with rank order", |
83 action="store_true" | 86 action="store_true", |
84 ) | 87 ) |
85 parser.add_argument( | 88 parser.add_argument( |
86 "--pae", | 89 "--pae", |
87 help="extract PAE from pkl files to CSV format", | 90 help="extract PAE from pkl files to CSV format", |
88 action="store_true" | 91 action="store_true", |
89 ) | 92 ) |
90 parser.add_argument( | 93 parser.add_argument( |
91 "--plot", | 94 "--plot", |
92 help="Plot pLDDT and PAE for each model", | 95 help="Plot pLDDT and PAE for each model", |
93 action="store_true" | 96 action="store_true", |
97 ) | |
98 parser.add_argument( | |
99 "--plot-msa", | |
100 help="Plot multiple-sequence alignment coverage as a heatmap", | |
101 action="store_true", | |
94 ) | 102 ) |
95 args = parser.parse_args() | 103 args = parser.parse_args() |
96 self.workdir = Path(args.workdir.rstrip('/')) | 104 self.workdir = Path(args.workdir.rstrip('/')) |
97 self.output_residue_scores = args.plddts | 105 self.output_residue_scores = args.confidence_scores |
98 self.output_model_pkls = args.pkl | 106 self.output_model_pkls = args.pkl |
99 self.output_model_plots = args.plot | 107 self.output_model_plots = args.plot |
100 self.output_pae = args.pae | 108 self.output_pae = args.pae |
101 self.is_multimer = args.multimer | 109 self.plot_msa = args.plot_msa |
110 self.model_preset = self._sniff_model_preset() | |
102 self.output_dir = self.workdir / OUTPUT_DIR | 111 self.output_dir = self.workdir / OUTPUT_DIR |
103 os.makedirs(self.output_dir, exist_ok=True) | 112 os.makedirs(self.output_dir, exist_ok=True) |
113 | |
114 def _sniff_model_preset(self) -> bool: | |
115 """Check if the run was multimer or monomer.""" | |
116 with open(self.workdir / 'relax_metrics.json') as f: | |
117 if '_multimer_' in f.read(): | |
118 return PRESETS.multimer | |
119 if '_ptm_' in f.read(): | |
120 return PRESETS.monomer_ptm | |
121 return PRESETS.monomer | |
104 | 122 |
105 | 123 |
106 class ExecutionContext: | 124 class ExecutionContext: |
107 """Collect file paths etc.""" | 125 """Collect file paths etc.""" |
108 def __init__(self, settings: Settings): | 126 def __init__(self, settings: Settings): |
109 self.settings = settings | 127 self.settings = settings |
110 if settings.is_multimer: | 128 if settings.model_preset == PRESETS.multimer: |
111 self.plddt_key = PLDDT_KEY['multimer'] | 129 self.plddt_key = PLDDT_KEY.multimer |
112 else: | 130 else: |
113 self.plddt_key = PLDDT_KEY['monomer'] | 131 self.plddt_key = PLDDT_KEY.monomer |
114 | 132 |
115 def get_model_key(self, ix: int) -> str: | 133 def get_model_key(self, ix: int) -> str: |
116 """Return json key for model index. | 134 """Return json key for model index. |
117 | 135 |
118 The key format changed between minor AlphaFold versions so this | 136 The key format changed between minor AlphaFold versions so this |
190 return self.data['order'].index(model_name) | 208 return self.data['order'].index(model_name) |
191 | 209 |
192 | 210 |
193 def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext): | 211 def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext): |
194 """Write per-model confidence scores.""" | 212 """Write per-model confidence scores.""" |
195 path = context.settings.workdir / OUTPUTS['model_confidence_scores'] | 213 outfile = context.settings.workdir / OUTPUTS['model_confidence_scores'] |
196 with open(path, 'w') as f: | 214 scores: Dict[str, list] = {} |
197 for rank in range(1, len(context.model_pkl_paths) + 1): | 215 header = ['model', context.plddt_key] |
198 score = ranking.get_plddt_for_rank(rank) | 216 |
199 f.write(f'ranked_{rank - 1}\t{score:.2f}\n') | 217 for i, path in enumerate(context.model_pkl_paths): |
218 rank = int(path.name.split('model_')[-1][0]) | |
219 scores_ls = [ranking.get_plddt_for_rank(rank)] | |
220 with open(path, 'rb') as f: | |
221 data = pk.load(f) | |
222 if 'ptm' in data: | |
223 scores_ls.append(data['ptm']) | |
224 if i == 0: | |
225 header += ['ptm'] | |
226 if 'iptm' in data: | |
227 scores_ls.append(data['iptm']) | |
228 if i == 0: | |
229 header += ['iptm'] | |
230 scores[rank] = scores_ls | |
231 | |
232 with open(outfile, 'w') as f: | |
233 f.write('\t'.join(header) + '\n') | |
234 for rank, score_ls in scores.items(): | |
235 row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls] | |
236 f.write('\t'.join(row) + '\n') | |
200 | 237 |
201 | 238 |
202 def write_per_residue_scores( | 239 def write_per_residue_scores( |
203 ranking: ResultRanking, | 240 ranking: ResultRanking, |
204 context: ExecutionContext, | 241 context: ExecutionContext, |
302 plt.title('Predicted Aligned Error') | 339 plt.title('Predicted Aligned Error') |
303 plt.xlabel('Scored residue') | 340 plt.xlabel('Scored residue') |
304 plt.ylabel('Aligned residue') | 341 plt.ylabel('Aligned residue') |
305 | 342 |
306 plt.savefig(png_path) | 343 plt.savefig(png_path) |
344 plt.close() | |
345 | |
346 | |
347 def plot_msa(wdir: Path, dpi: int = 150): | |
348 """Plot MSA as a heatmap.""" | |
349 with open(wdir / 'features.pkl', 'rb') as f: | |
350 features = pk.load(f) | |
351 | |
352 msa = features.get('msa') | |
353 if msa is None: | |
354 print("Could not plot MSA coverage - 'msa' key not found in" | |
355 " features.pkl") | |
356 return | |
357 seqid = (np.array(msa[0] == msa).mean(-1)) | |
358 seqid_sort = seqid.argsort() | |
359 non_gaps = (msa != 21).astype(float) | |
360 non_gaps[non_gaps == 0] = np.nan | |
361 final = non_gaps[seqid_sort] * seqid[seqid_sort, None] | |
362 | |
363 plt.figure(figsize=(6, 4)) | |
364 # plt.subplot(111) | |
365 plt.title("Sequence coverage") | |
366 plt.imshow(final, | |
367 interpolation='nearest', aspect='auto', | |
368 cmap="rainbow_r", vmin=0, vmax=1, origin='lower') | |
369 plt.plot((msa != 21).sum(0), color='black') | |
370 plt.xlim(-0.5, msa.shape[1] - 0.5) | |
371 plt.ylim(-0.5, msa.shape[0] - 0.5) | |
372 plt.colorbar(label="Sequence identity to query", ) | |
373 plt.xlabel("Positions") | |
374 plt.ylabel("Sequences") | |
375 plt.tight_layout() | |
376 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) | |
377 plt.close() | |
307 | 378 |
308 | 379 |
309 def template_html(context: ExecutionContext): | 380 def template_html(context: ExecutionContext): |
310 """Template HTML file. | 381 """Template HTML file. |
311 | 382 |
339 if settings.output_pae: | 410 if settings.output_pae: |
340 # Only created by monomer_ptm and multimer models | 411 # Only created by monomer_ptm and multimer models |
341 extract_pae_to_csv(ranking, context) | 412 extract_pae_to_csv(ranking, context) |
342 if settings.output_residue_scores: | 413 if settings.output_residue_scores: |
343 write_per_residue_scores(ranking, context) | 414 write_per_residue_scores(ranking, context) |
415 if settings.plot_msa: | |
416 plot_msa(context.settings.workdir) | |
344 | 417 |
345 | 418 |
346 if __name__ == '__main__': | 419 if __name__ == '__main__': |
347 main() | 420 main() |