comparison scripts/outputs.py @ 24:31f648b7555a draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 44db277529c0e189149235cf60a627193a792fba
author galaxy-australia
date Sat, 05 Jul 2025 03:56:38 +0000
parents 2891385d6ace
children
comparison
equal deleted inserted replaced
23:2891385d6ace 24:31f648b7555a
18 import json 18 import json
19 import numpy as np 19 import numpy as np
20 import os 20 import os
21 import pickle as pk 21 import pickle as pk
22 import shutil 22 import shutil
23 import zipfile
24 from matplotlib import pyplot as plt
25 from pathlib import Path 23 from pathlib import Path
26 from typing import Dict, List 24 from typing import Dict, List
27 25
26 from matplotlib import pyplot as plt
27
28 # Output file paths
28 OUTPUT_DIR = 'extra' 29 OUTPUT_DIR = 'extra'
29 OUTPUTS = { 30 OUTPUTS = {
30 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', 31 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
31 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', 32 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
32 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', 33 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
61 """Parse and store settings/config.""" 62 """Parse and store settings/config."""
62 def __init__(self): 63 def __init__(self):
63 self.workdir = None 64 self.workdir = None
64 self.output_confidence_scores = True 65 self.output_confidence_scores = True
65 self.output_residue_scores = False 66 self.output_residue_scores = False
67 self.is_multimer = False
66 self.parse() 68 self.parse()
67 69
68 def parse(self) -> None: 70 def parse(self) -> None:
69 parser = argparse.ArgumentParser() 71 parser = argparse.ArgumentParser()
70 parser.add_argument( 72 parser.add_argument(
94 action="store_true", 96 action="store_true",
95 ) 97 )
96 parser.add_argument( 98 parser.add_argument(
97 "--plot-msa", 99 "--plot-msa",
98 help="Plot multiple-sequence alignment coverage as a heatmap", 100 help="Plot multiple-sequence alignment coverage as a heatmap",
99 action="store_true",
100 )
101 parser.add_argument(
102 "--msa",
103 help="Collect multiple-sequence alignments as ZIP archives",
104 action="store_true",
105 )
106 parser.add_argument(
107 "--msa_only",
108 help="Alphafold generated MSA files only - skip all other outputs",
109 action="store_true", 101 action="store_true",
110 ) 102 )
111 args = parser.parse_args() 103 args = parser.parse_args()
112 self.workdir = Path(args.workdir.rstrip('/')) 104 self.workdir = Path(args.workdir.rstrip('/'))
113 self.output_residue_scores = args.confidence_scores 105 self.output_residue_scores = args.confidence_scores
114 self.output_model_pkls = args.pkl 106 self.output_model_pkls = args.pkl
115 self.output_model_plots = args.plot 107 self.output_model_plots = args.plot
116 self.output_pae = args.pae 108 self.output_pae = args.pae
117 self.plot_msa = args.plot_msa 109 self.plot_msa = args.plot_msa
118 self.collect_msas = args.msa
119 self.model_preset = self._sniff_model_preset() 110 self.model_preset = self._sniff_model_preset()
120 self.is_multimer = self.model_preset == PRESETS.multimer
121 self.output_dir = self.workdir / OUTPUT_DIR 111 self.output_dir = self.workdir / OUTPUT_DIR
122 self.msa_only = args.msa_only
123 os.makedirs(self.output_dir, exist_ok=True) 112 os.makedirs(self.output_dir, exist_ok=True)
124 113
125 def _sniff_model_preset(self) -> bool: 114 def _sniff_model_preset(self) -> bool:
126 """Check if the run was multimer or monomer.""" 115 """Check if the run was multimer or monomer."""
127 for path in self.workdir.glob('*.pkl'): 116 for path in self.workdir.glob('*.pkl'):
129 if '_multimer_' in path.name: 118 if '_multimer_' in path.name:
130 return PRESETS.multimer 119 return PRESETS.multimer
131 if '_ptm_' in path.name: 120 if '_ptm_' in path.name:
132 return PRESETS.monomer_ptm 121 return PRESETS.monomer_ptm
133 return PRESETS.monomer 122 return PRESETS.monomer
134 return PRESETS.monomer
135 123
136 124
137 class ExecutionContext: 125 class ExecutionContext:
138 """Collect file paths etc.""" 126 """Collect file paths etc."""
139 def __init__(self, settings: Settings): 127 def __init__(self, settings: Settings):
140 self.settings = settings 128 self.settings = settings
141 if settings.is_multimer: 129 if settings.model_preset == PRESETS.multimer:
142 self.plddt_key = PLDDT_KEY.multimer 130 self.plddt_key = PLDDT_KEY.multimer
143 else: 131 else:
144 self.plddt_key = PLDDT_KEY.monomer 132 self.plddt_key = PLDDT_KEY.monomer
145 133
146 def get_model_key(self, ix: int) -> str: 134 def get_model_key(self, ix: int) -> str:
209 """Return ordered list of model indexes.""" 197 """Return ordered list of model indexes."""
210 return self.data['order'] 198 return self.data['order']
211 199
212 def get_plddt_for_rank(self, rank: int) -> List[float]: 200 def get_plddt_for_rank(self, rank: int) -> List[float]:
213 """Get pLDDT score for model instance.""" 201 """Get pLDDT score for model instance."""
214 return self.data[self.context.plddt_key][self.data['order'][rank - 1]] 202 return self.data[self.context.plddt_key][self.data['order'][rank]]
215 203
216 def get_rank_for_model(self, model_name: str) -> int: 204 def get_rank_for_model(self, model_name: str) -> int:
217 """Return 0-indexed rank for given model name. 205 """Return 0-indexed rank for given model name.
218 206
219 Model names are expressed in result_model_*.pkl file names. 207 Model names are expressed in result_model_*.pkl file names.
226 outfile = context.settings.workdir / OUTPUTS['model_confidence_scores'] 214 outfile = context.settings.workdir / OUTPUTS['model_confidence_scores']
227 scores: Dict[str, list] = {} 215 scores: Dict[str, list] = {}
228 header = ['model', context.plddt_key] 216 header = ['model', context.plddt_key]
229 217
230 for i, path in enumerate(context.model_pkl_paths): 218 for i, path in enumerate(context.model_pkl_paths):
231 rank = int(path.name.split('model_')[-1][0]) 219 model_name = 'model_' + path.stem.split('model_')[1]
220 rank = ranking.get_rank_for_model(model_name)
232 scores_ls = [ranking.get_plddt_for_rank(rank)] 221 scores_ls = [ranking.get_plddt_for_rank(rank)]
233 with open(path, 'rb') as f: 222 with open(path, 'rb') as f:
234 data = pk.load(f) 223 data = pk.load(f)
235 if 'ptm' in data: 224 if 'ptm' in data:
236 scores_ls.append(data['ptm']) 225 scores_ls.append(data['ptm'])
242 header += ['iptm'] 231 header += ['iptm']
243 scores[rank] = scores_ls 232 scores[rank] = scores_ls
244 233
245 with open(outfile, 'w') as f: 234 with open(outfile, 'w') as f:
246 f.write('\t'.join(header) + '\n') 235 f.write('\t'.join(header) + '\n')
247 for rank, score_ls in scores.items(): 236 for rank in sorted(scores):
248 row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls] 237 score_ls = scores[rank]
238 row = [f"ranked_{rank}"] + [str(x) for x in score_ls]
249 f.write('\t'.join(row) + '\n') 239 f.write('\t'.join(row) + '\n')
250 240
251 241
252 def write_per_residue_scores( 242 def write_per_residue_scores(
253 ranking: ResultRanking, 243 ranking: ResultRanking,
388 plt.tight_layout() 378 plt.tight_layout()
389 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) 379 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi)
390 plt.close() 380 plt.close()
391 381
392 382
393 def collect_msas(settings: Settings):
394 """Collect MSA files into ZIP archive(s)."""
395
396 def zip_dir(directory: Path, is_multimer: bool, name: str):
397 chain_id = directory.with_suffix('.zip').stem
398 msa_dir = settings.output_dir / 'msas'
399 msa_dir.mkdir(exist_ok=True)
400 zip_name = (
401 f"MSA-{chain_id}-{name}.zip"
402 if is_multimer
403 else f"MSA-{name}.zip")
404 zip_path = msa_dir / zip_name
405 with zipfile.ZipFile(zip_path, 'w') as z:
406 for path in directory.glob('*'):
407 z.write(path, path.name)
408
409 print("Collecting MSA archives...")
410 chain_names = get_input_sequence_ids(
411 settings.workdir.parent.parent / 'alphafold.fasta')
412 msa_dir = settings.workdir / 'msas'
413 is_multimer = (msa_dir / 'A').exists()
414 if is_multimer:
415 msa_dirs = sorted([
416 path for path in msa_dir.glob('*')
417 if path.is_dir()
418 ])
419 for i, path in enumerate(msa_dirs):
420 zip_dir(path, is_multimer, chain_names[i])
421 else:
422 zip_dir(msa_dir, is_multimer, chain_names[0])
423
424
425 def get_input_sequence_ids(fasta_file: Path) -> List[str]:
426 """Read headers from the input FASTA file.
427 Split them to get a sequence ID and truncate to 20 chars max.
428 """
429 headers = []
430 for line in fasta_file.read_text().split('\n'):
431 if line.startswith('>'):
432 seq_id = line[1:].split(' ')[0]
433 seq_id_trunc = seq_id[:20].strip()
434 if len(seq_id) > 20:
435 seq_id_trunc += '...'
436 headers.append(seq_id_trunc)
437 return headers
438
439
440 def template_html(context: ExecutionContext): 383 def template_html(context: ExecutionContext):
441 """Template HTML file. 384 """Template HTML file.
442 385
443 Remove buttons that are redundant with limited model outputs. 386 Remove buttons that are redundant with limited model outputs.
444 """ 387 """
454 397
455 398
456 def main(): 399 def main():
457 """Parse output files and generate additional output files.""" 400 """Parse output files and generate additional output files."""
458 settings = Settings() 401 settings = Settings()
459 if not settings.msa_only: 402 context = ExecutionContext(settings)
460 context = ExecutionContext(settings) 403 ranking = ResultRanking(context)
461 ranking = ResultRanking(context) 404 write_confidence_scores(ranking, context)
462 write_confidence_scores(ranking, context) 405 rekey_relax_metrics(ranking, context)
463 rekey_relax_metrics(ranking, context) 406 template_html(context)
464 template_html(context) 407
465 408 # Optional outputs
466 # Optional outputs 409 if settings.output_model_pkls:
467 if settings.output_model_pkls: 410 rename_model_pkls(ranking, context)
468 rename_model_pkls(ranking, context) 411 if settings.output_model_plots:
469 if settings.output_model_plots: 412 plddt_pae_plots(ranking, context)
470 plddt_pae_plots(ranking, context) 413 if settings.output_pae:
471 if settings.output_pae: 414 # Only created by monomer_ptm and multimer models
472 # Only created by monomer_ptm and multimer models 415 extract_pae_to_csv(ranking, context)
473 extract_pae_to_csv(ranking, context) 416 if settings.output_residue_scores:
474 if settings.output_residue_scores: 417 write_per_residue_scores(ranking, context)
475 write_per_residue_scores(ranking, context) 418 if settings.plot_msa:
476 if settings.plot_msa: 419 plot_msa(context.settings.workdir)
477 plot_msa(settings.workdir)
478 if settings.collect_msas or settings.msa_only:
479 collect_msas(settings)
480 420
481 421
482 if __name__ == '__main__': 422 if __name__ == '__main__':
483 main() 423 main()