comparison scripts/outputs.py @ 23:2891385d6ace draft default tip

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit b347c6ccc82b14fcbff360b3357050d1d43e3ef5-dirty
author galaxy-australia
date Wed, 16 Apr 2025 05:46:58 +0000
parents 3f188450ca4f
children
comparison
equal deleted inserted replaced
22:3f188450ca4f 23:2891385d6ace
18 import json 18 import json
19 import numpy as np 19 import numpy as np
20 import os 20 import os
21 import pickle as pk 21 import pickle as pk
22 import shutil 22 import shutil
23 import zipfile
24 from matplotlib import pyplot as plt
23 from pathlib import Path 25 from pathlib import Path
24 from typing import Dict, List 26 from typing import Dict, List
25 27
26 from matplotlib import pyplot as plt
27
28 # Output file paths
29 OUTPUT_DIR = 'extra' 28 OUTPUT_DIR = 'extra'
30 OUTPUTS = { 29 OUTPUTS = {
31 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', 30 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
32 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', 31 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
33 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', 32 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
62 """Parse and store settings/config.""" 61 """Parse and store settings/config."""
63 def __init__(self): 62 def __init__(self):
64 self.workdir = None 63 self.workdir = None
65 self.output_confidence_scores = True 64 self.output_confidence_scores = True
66 self.output_residue_scores = False 65 self.output_residue_scores = False
67 self.is_multimer = False
68 self.parse() 66 self.parse()
69 67
70 def parse(self) -> None: 68 def parse(self) -> None:
71 parser = argparse.ArgumentParser() 69 parser = argparse.ArgumentParser()
72 parser.add_argument( 70 parser.add_argument(
96 action="store_true", 94 action="store_true",
97 ) 95 )
98 parser.add_argument( 96 parser.add_argument(
99 "--plot-msa", 97 "--plot-msa",
100 help="Plot multiple-sequence alignment coverage as a heatmap", 98 help="Plot multiple-sequence alignment coverage as a heatmap",
99 action="store_true",
100 )
101 parser.add_argument(
102 "--msa",
103 help="Collect multiple-sequence alignments as ZIP archives",
104 action="store_true",
105 )
106 parser.add_argument(
107 "--msa_only",
108 help="Alphafold generated MSA files only - skip all other outputs",
101 action="store_true", 109 action="store_true",
102 ) 110 )
103 args = parser.parse_args() 111 args = parser.parse_args()
104 self.workdir = Path(args.workdir.rstrip('/')) 112 self.workdir = Path(args.workdir.rstrip('/'))
105 self.output_residue_scores = args.confidence_scores 113 self.output_residue_scores = args.confidence_scores
106 self.output_model_pkls = args.pkl 114 self.output_model_pkls = args.pkl
107 self.output_model_plots = args.plot 115 self.output_model_plots = args.plot
108 self.output_pae = args.pae 116 self.output_pae = args.pae
109 self.plot_msa = args.plot_msa 117 self.plot_msa = args.plot_msa
118 self.collect_msas = args.msa
110 self.model_preset = self._sniff_model_preset() 119 self.model_preset = self._sniff_model_preset()
120 self.is_multimer = self.model_preset == PRESETS.multimer
111 self.output_dir = self.workdir / OUTPUT_DIR 121 self.output_dir = self.workdir / OUTPUT_DIR
122 self.msa_only = args.msa_only
112 os.makedirs(self.output_dir, exist_ok=True) 123 os.makedirs(self.output_dir, exist_ok=True)
113 124
114 def _sniff_model_preset(self) -> bool: 125 def _sniff_model_preset(self) -> bool:
115 """Check if the run was multimer or monomer.""" 126 """Check if the run was multimer or monomer."""
116 for path in self.workdir.glob('*.pkl'): 127 for path in self.workdir.glob('*.pkl'):
118 if '_multimer_' in path.name: 129 if '_multimer_' in path.name:
119 return PRESETS.multimer 130 return PRESETS.multimer
120 if '_ptm_' in path.name: 131 if '_ptm_' in path.name:
121 return PRESETS.monomer_ptm 132 return PRESETS.monomer_ptm
122 return PRESETS.monomer 133 return PRESETS.monomer
134 return PRESETS.monomer
123 135
124 136
125 class ExecutionContext: 137 class ExecutionContext:
126 """Collect file paths etc.""" 138 """Collect file paths etc."""
127 def __init__(self, settings: Settings): 139 def __init__(self, settings: Settings):
128 self.settings = settings 140 self.settings = settings
129 if settings.model_preset == PRESETS.multimer: 141 if settings.is_multimer:
130 self.plddt_key = PLDDT_KEY.multimer 142 self.plddt_key = PLDDT_KEY.multimer
131 else: 143 else:
132 self.plddt_key = PLDDT_KEY.monomer 144 self.plddt_key = PLDDT_KEY.monomer
133 145
134 def get_model_key(self, ix: int) -> str: 146 def get_model_key(self, ix: int) -> str:
376 plt.tight_layout() 388 plt.tight_layout()
377 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) 389 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi)
378 plt.close() 390 plt.close()
379 391
380 392
def collect_msas(settings: Settings):
    """Collect MSA files into ZIP archive(s).

    Input is read from ``<workdir>/msas`` and archives are written to
    ``<output_dir>/msas`` (created if missing). The sequence IDs used in
    the archive names are read from ``alphafold.fasta``, two directories
    above the working directory.

    If ``<workdir>/msas/A`` exists the run is treated as multimer: every
    subdirectory of ``msas`` is archived separately as
    ``MSA-<dirname>-<seq_id>.zip``, pairing the sorted subdirectories with
    the FASTA headers by position. Otherwise the entries directly inside
    ``msas`` go into a single ``MSA-<seq_id>.zip`` named after the first
    FASTA header.

    Raises:
        IndexError: if the FASTA file has fewer headers than there are
            MSA sets to archive.
    """
    fasta_path = settings.workdir.parent.parent / 'alphafold.fasta'
    source_dir = settings.workdir / 'msas'
    archive_dir = settings.output_dir / 'msas'

    def safe_file_name(file_name: str) -> str:
        # Sequence IDs come from user-supplied FASTA headers. A path
        # separator inside one would turn the archive name into a nested
        # path whose parent directory does not exist.
        for separator in filter(None, (os.sep, os.altsep)):
            file_name = file_name.replace(separator, '_')
        return file_name

    def chain_name(index: int) -> str:
        # Same exception type as a plain list lookup, but say what is
        # actually missing.
        if index >= len(chain_names):
            raise IndexError(
                f"No FASTA header for MSA set {index + 1}: found"
                f" {len(chain_names)} sequence ID(s) in {fasta_path}")
        return chain_names[index]

    def zip_dir(directory: Path, is_multimer: bool, name: str):
        chain_id = directory.stem
        zip_name = (
            f"MSA-{chain_id}-{name}.zip"
            if is_multimer
            else f"MSA-{name}.zip")
        zip_path = archive_dir / safe_file_name(zip_name)
        with zipfile.ZipFile(zip_path, 'w') as z:
            # Sorted so that archive member order is reproducible
            for path in sorted(directory.glob('*')):
                z.write(path, path.name)

    print("Collecting MSA archives...")
    chain_names = get_input_sequence_ids(fasta_path)
    is_multimer = (source_dir / 'A').exists()
    archive_dir.mkdir(exist_ok=True)
    if is_multimer:
        chain_dirs = sorted(
            path for path in source_dir.glob('*')
            if path.is_dir()
        )
        # NOTE(review): directories are paired with FASTA headers purely
        # by position. If a chain can be left without its own directory
        # (presumably when sequences repeat - confirm), later names shift.
        for i, path in enumerate(chain_dirs):
            zip_dir(path, is_multimer, chain_name(i))
    else:
        zip_dir(source_dir, is_multimer, chain_name(0))
423
424
def get_input_sequence_ids(fasta_file: Path) -> List[str]:
    """Read sequence IDs from the header lines of a FASTA file.

    The ID is the first whitespace-delimited token after ``>``. IDs longer
    than 20 characters are cut to their first 20 characters and marked
    with a trailing ``...``, so a returned ID is at most 23 characters.
    A header without an ID yields an empty string, which keeps the result
    at one entry per header, in file order.
    """
    max_length = 20
    seq_ids = []
    for line in fasta_file.read_text().splitlines():
        if not line.startswith('>'):
            continue
        # split() without a separator treats tabs and runs of spaces
        # alike, so the description never leaks into the ID.
        tokens = line[1:].split(maxsplit=1)
        seq_id = tokens[0] if tokens else ''
        if len(seq_id) > max_length:
            seq_id = seq_id[:max_length] + '...'
        seq_ids.append(seq_id)
    return seq_ids
438
439
381 def template_html(context: ExecutionContext): 440 def template_html(context: ExecutionContext):
382 """Template HTML file. 441 """Template HTML file.
383 442
384 Remove buttons that are redundant with limited model outputs. 443 Remove buttons that are redundant with limited model outputs.
385 """ 444 """
395 454
396 455
397 def main(): 456 def main():
398 """Parse output files and generate additional output files.""" 457 """Parse output files and generate additional output files."""
399 settings = Settings() 458 settings = Settings()
400 context = ExecutionContext(settings) 459 if not settings.msa_only:
401 ranking = ResultRanking(context) 460 context = ExecutionContext(settings)
402 write_confidence_scores(ranking, context) 461 ranking = ResultRanking(context)
403 rekey_relax_metrics(ranking, context) 462 write_confidence_scores(ranking, context)
404 template_html(context) 463 rekey_relax_metrics(ranking, context)
405 464 template_html(context)
406 # Optional outputs 465
407 if settings.output_model_pkls: 466 # Optional outputs
408 rename_model_pkls(ranking, context) 467 if settings.output_model_pkls:
409 if settings.output_model_plots: 468 rename_model_pkls(ranking, context)
410 plddt_pae_plots(ranking, context) 469 if settings.output_model_plots:
411 if settings.output_pae: 470 plddt_pae_plots(ranking, context)
412 # Only created by monomer_ptm and multimer models 471 if settings.output_pae:
413 extract_pae_to_csv(ranking, context) 472 # Only created by monomer_ptm and multimer models
414 if settings.output_residue_scores: 473 extract_pae_to_csv(ranking, context)
415 write_per_residue_scores(ranking, context) 474 if settings.output_residue_scores:
416 if settings.plot_msa: 475 write_per_residue_scores(ranking, context)
417 plot_msa(context.settings.workdir) 476 if settings.plot_msa:
477 plot_msa(settings.workdir)
478 if settings.collect_msas or settings.msa_only:
479 collect_msas(settings)
418 480
419 481
420 if __name__ == '__main__': 482 if __name__ == '__main__':
421 main() 483 main()