comparison scripts/outputs.py @ 21:e7f1b552a695 draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 628c9fdcb77489063145a2307b6bb6a450416dd6-dirty
author galaxy-australia
date Tue, 29 Oct 2024 02:15:36 +0000
parents 6ab1a261520a
children 3f188450ca4f
comparison
equal deleted inserted replaced
20:6ab1a261520a 21:e7f1b552a695
14 several output paths are determined dynamically. 14 several output paths are determined dynamically.
15 """ 15 """
16 16
17 import argparse 17 import argparse
18 import json 18 import json
19 import numpy as np
19 import os 20 import os
20 import pickle as pk 21 import pickle as pk
21 import shutil 22 import shutil
22 from pathlib import Path 23 from pathlib import Path
23 from typing import List 24 from typing import Dict, List
24 25
25 from matplotlib import pyplot as plt 26 from matplotlib import pyplot as plt
26 27
27 # Output file paths 28 # Output file paths
28 OUTPUT_DIR = 'extra' 29 OUTPUT_DIR = 'extra'
31 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', 32 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
32 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', 33 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
33 'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv', 34 'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
34 'plddts': OUTPUT_DIR + '/plddts.tsv', 35 'plddts': OUTPUT_DIR + '/plddts.tsv',
35 'relax': OUTPUT_DIR + '/relax_metrics_ranked.json', 36 'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
36 } 37 'msa': OUTPUT_DIR + '/msa_coverage.png',
37
38 # Keys for accessing confidence data from JSON/pkl files
39 # They change depending on whether the run was monomer or multimer
40 PLDDT_KEY = {
41 'monomer': 'plddts',
42 'multimer': 'iptm+ptm',
43 } 38 }
44 39
45 HTML_PATH = Path(__file__).parent / "alphafold.html" 40 HTML_PATH = Path(__file__).parent / "alphafold.html"
46 HTML_OUTPUT_FILENAME = 'alphafold.html' 41 HTML_OUTPUT_FILENAME = 'alphafold.html'
47 HTML_BUTTON_ATTR = 'class="btn" id="btn-ranked_{rank}"' 42 HTML_BUTTON_ATTR = 'class="btn" id="btn-ranked_{rank}"'
48 HTML_BUTTON_ATTR_DISABLED = ( 43 HTML_BUTTON_ATTR_DISABLED = (
49 'class="btn disabled" id="btn-ranked_{rank}" disabled') 44 'class="btn disabled" id="btn-ranked_{rank}" disabled')
45
46
47 class PLDDT_KEY:
48 """Dict keys for accessing confidence data from JSON/pkl files."
49 Changes depending on which model PRESET was used.
50 """
51 monomer = 'plddts'
52 multimer = 'iptm+ptm'
53
54
55 class PRESETS:
56 monomer = 'monomer'
57 monomer_ptm = 'monomer_ptm'
58 multimer = 'multimer'
50 59
51 60
52 class Settings: 61 class Settings:
53 """Parse and store settings/config.""" 62 """Parse and store settings/config."""
54 def __init__(self): 63 def __init__(self):
61 def parse(self) -> None: 70 def parse(self) -> None:
62 parser = argparse.ArgumentParser() 71 parser = argparse.ArgumentParser()
63 parser.add_argument( 72 parser.add_argument(
64 "workdir", 73 "workdir",
65 help="alphafold output directory", 74 help="alphafold output directory",
66 type=str 75 type=str,
67 ) 76 )
68 parser.add_argument( 77 parser.add_argument(
69 "-p", 78 "-s",
70 "--plddts", 79 "--confidence-scores",
71 help="output per-residue confidence scores (pLDDTs)", 80 help="output per-residue confidence scores (pLDDTs)",
72 action="store_true" 81 action="store_true",
73 )
74 parser.add_argument(
75 "-m",
76 "--multimer",
77 help="parse output from AlphaFold multimer",
78 action="store_true"
79 ) 82 )
80 parser.add_argument( 83 parser.add_argument(
81 "--pkl", 84 "--pkl",
82 help="rename model pkl outputs with rank order", 85 help="rename model pkl outputs with rank order",
83 action="store_true" 86 action="store_true",
84 ) 87 )
85 parser.add_argument( 88 parser.add_argument(
86 "--pae", 89 "--pae",
87 help="extract PAE from pkl files to CSV format", 90 help="extract PAE from pkl files to CSV format",
88 action="store_true" 91 action="store_true",
89 ) 92 )
90 parser.add_argument( 93 parser.add_argument(
91 "--plot", 94 "--plot",
92 help="Plot pLDDT and PAE for each model", 95 help="Plot pLDDT and PAE for each model",
93 action="store_true" 96 action="store_true",
97 )
98 parser.add_argument(
99 "--plot-msa",
100 help="Plot multiple-sequence alignment coverage as a heatmap",
101 action="store_true",
94 ) 102 )
95 args = parser.parse_args() 103 args = parser.parse_args()
96 self.workdir = Path(args.workdir.rstrip('/')) 104 self.workdir = Path(args.workdir.rstrip('/'))
97 self.output_residue_scores = args.plddts 105 self.output_residue_scores = args.confidence_scores
98 self.output_model_pkls = args.pkl 106 self.output_model_pkls = args.pkl
99 self.output_model_plots = args.plot 107 self.output_model_plots = args.plot
100 self.output_pae = args.pae 108 self.output_pae = args.pae
101 self.is_multimer = args.multimer 109 self.plot_msa = args.plot_msa
110 self.model_preset = self._sniff_model_preset()
102 self.output_dir = self.workdir / OUTPUT_DIR 111 self.output_dir = self.workdir / OUTPUT_DIR
103 os.makedirs(self.output_dir, exist_ok=True) 112 os.makedirs(self.output_dir, exist_ok=True)
113
114 def _sniff_model_preset(self) -> bool:
115 """Check if the run was multimer or monomer."""
116 with open(self.workdir / 'relax_metrics.json') as f:
117 if '_multimer_' in f.read():
118 return PRESETS.multimer
119 if '_ptm_' in f.read():
120 return PRESETS.monomer_ptm
121 return PRESETS.monomer
104 122
105 123
106 class ExecutionContext: 124 class ExecutionContext:
107 """Collect file paths etc.""" 125 """Collect file paths etc."""
108 def __init__(self, settings: Settings): 126 def __init__(self, settings: Settings):
109 self.settings = settings 127 self.settings = settings
110 if settings.is_multimer: 128 if settings.model_preset == PRESETS.multimer:
111 self.plddt_key = PLDDT_KEY['multimer'] 129 self.plddt_key = PLDDT_KEY.multimer
112 else: 130 else:
113 self.plddt_key = PLDDT_KEY['monomer'] 131 self.plddt_key = PLDDT_KEY.monomer
114 132
115 def get_model_key(self, ix: int) -> str: 133 def get_model_key(self, ix: int) -> str:
116 """Return json key for model index. 134 """Return json key for model index.
117 135
118 The key format changed between minor AlphaFold versions so this 136 The key format changed between minor AlphaFold versions so this
190 return self.data['order'].index(model_name) 208 return self.data['order'].index(model_name)
191 209
192 210
193 def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext): 211 def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
194 """Write per-model confidence scores.""" 212 """Write per-model confidence scores."""
195 path = context.settings.workdir / OUTPUTS['model_confidence_scores'] 213 outfile = context.settings.workdir / OUTPUTS['model_confidence_scores']
196 with open(path, 'w') as f: 214 scores: Dict[str, list] = {}
197 for rank in range(1, len(context.model_pkl_paths) + 1): 215 header = ['model', context.plddt_key]
198 score = ranking.get_plddt_for_rank(rank) 216
199 f.write(f'ranked_{rank - 1}\t{score:.2f}\n') 217 for i, path in enumerate(context.model_pkl_paths):
218 rank = int(path.name.split('model_')[-1][0])
219 scores_ls = [ranking.get_plddt_for_rank(rank)]
220 with open(path, 'rb') as f:
221 data = pk.load(f)
222 if 'ptm' in data:
223 scores_ls.append(data['ptm'])
224 if i == 0:
225 header += ['ptm']
226 if 'iptm' in data:
227 scores_ls.append(data['iptm'])
228 if i == 0:
229 header += ['iptm']
230 scores[rank] = scores_ls
231
232 with open(outfile, 'w') as f:
233 f.write('\t'.join(header) + '\n')
234 for rank, score_ls in scores.items():
235 row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls]
236 f.write('\t'.join(row) + '\n')
200 237
201 238
202 def write_per_residue_scores( 239 def write_per_residue_scores(
203 ranking: ResultRanking, 240 ranking: ResultRanking,
204 context: ExecutionContext, 241 context: ExecutionContext,
302 plt.title('Predicted Aligned Error') 339 plt.title('Predicted Aligned Error')
303 plt.xlabel('Scored residue') 340 plt.xlabel('Scored residue')
304 plt.ylabel('Aligned residue') 341 plt.ylabel('Aligned residue')
305 342
306 plt.savefig(png_path) 343 plt.savefig(png_path)
344 plt.close()
345
346
347 def plot_msa(wdir: Path, dpi: int = 150):
348 """Plot MSA as a heatmap."""
349 with open(wdir / 'features.pkl', 'rb') as f:
350 features = pk.load(f)
351
352 msa = features.get('msa')
353 if msa is None:
354 print("Could not plot MSA coverage - 'msa' key not found in"
355 " features.pkl")
356 return
357 seqid = (np.array(msa[0] == msa).mean(-1))
358 seqid_sort = seqid.argsort()
359 non_gaps = (msa != 21).astype(float)
360 non_gaps[non_gaps == 0] = np.nan
361 final = non_gaps[seqid_sort] * seqid[seqid_sort, None]
362
363 plt.figure(figsize=(6, 4))
364 # plt.subplot(111)
365 plt.title("Sequence coverage")
366 plt.imshow(final,
367 interpolation='nearest', aspect='auto',
368 cmap="rainbow_r", vmin=0, vmax=1, origin='lower')
369 plt.plot((msa != 21).sum(0), color='black')
370 plt.xlim(-0.5, msa.shape[1] - 0.5)
371 plt.ylim(-0.5, msa.shape[0] - 0.5)
372 plt.colorbar(label="Sequence identity to query", )
373 plt.xlabel("Positions")
374 plt.ylabel("Sequences")
375 plt.tight_layout()
376 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi)
377 plt.close()
307 378
308 379
309 def template_html(context: ExecutionContext): 380 def template_html(context: ExecutionContext):
310 """Template HTML file. 381 """Template HTML file.
311 382
339 if settings.output_pae: 410 if settings.output_pae:
340 # Only created by monomer_ptm and multimer models 411 # Only created by monomer_ptm and multimer models
341 extract_pae_to_csv(ranking, context) 412 extract_pae_to_csv(ranking, context)
342 if settings.output_residue_scores: 413 if settings.output_residue_scores:
343 write_per_residue_scores(ranking, context) 414 write_per_residue_scores(ranking, context)
415 if settings.plot_msa:
416 plot_msa(context.settings.workdir)
344 417
345 418
346 if __name__ == '__main__': 419 if __name__ == '__main__':
347 main() 420 main()