Mercurial > repos > galaxy-australia > alphafold2
comparison scripts/outputs.py @ 23:2891385d6ace draft default tip
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit b347c6ccc82b14fcbff360b3357050d1d43e3ef5-dirty
author | galaxy-australia |
---|---|
date | Wed, 16 Apr 2025 05:46:58 +0000 |
parents | 3f188450ca4f |
children |
comparison
equal
deleted
inserted
replaced
22:3f188450ca4f | 23:2891385d6ace |
---|---|
18 import json | 18 import json |
19 import numpy as np | 19 import numpy as np |
20 import os | 20 import os |
21 import pickle as pk | 21 import pickle as pk |
22 import shutil | 22 import shutil |
23 import zipfile | |
24 from matplotlib import pyplot as plt | |
23 from pathlib import Path | 25 from pathlib import Path |
24 from typing import Dict, List | 26 from typing import Dict, List |
25 | 27 |
26 from matplotlib import pyplot as plt | |
27 | |
28 # Output file paths | |
29 OUTPUT_DIR = 'extra' | 28 OUTPUT_DIR = 'extra' |
30 OUTPUTS = { | 29 OUTPUTS = { |
31 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', | 30 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', |
32 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', | 31 'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv', |
33 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', | 32 'model_plot': OUTPUT_DIR + '/ranked_{rank}.png', |
62 """Parse and store settings/config.""" | 61 """Parse and store settings/config.""" |
63 def __init__(self): | 62 def __init__(self): |
64 self.workdir = None | 63 self.workdir = None |
65 self.output_confidence_scores = True | 64 self.output_confidence_scores = True |
66 self.output_residue_scores = False | 65 self.output_residue_scores = False |
67 self.is_multimer = False | |
68 self.parse() | 66 self.parse() |
69 | 67 |
70 def parse(self) -> None: | 68 def parse(self) -> None: |
71 parser = argparse.ArgumentParser() | 69 parser = argparse.ArgumentParser() |
72 parser.add_argument( | 70 parser.add_argument( |
96 action="store_true", | 94 action="store_true", |
97 ) | 95 ) |
98 parser.add_argument( | 96 parser.add_argument( |
99 "--plot-msa", | 97 "--plot-msa", |
100 help="Plot multiple-sequence alignment coverage as a heatmap", | 98 help="Plot multiple-sequence alignment coverage as a heatmap", |
99 action="store_true", | |
100 ) | |
101 parser.add_argument( | |
102 "--msa", | |
103 help="Collect multiple-sequence alignments as ZIP archives", | |
104 action="store_true", | |
105 ) | |
106 parser.add_argument( | |
107 "--msa_only", | |
108 help="Alphafold generated MSA files only - skip all other outputs", | |
101 action="store_true", | 109 action="store_true", |
102 ) | 110 ) |
103 args = parser.parse_args() | 111 args = parser.parse_args() |
104 self.workdir = Path(args.workdir.rstrip('/')) | 112 self.workdir = Path(args.workdir.rstrip('/')) |
105 self.output_residue_scores = args.confidence_scores | 113 self.output_residue_scores = args.confidence_scores |
106 self.output_model_pkls = args.pkl | 114 self.output_model_pkls = args.pkl |
107 self.output_model_plots = args.plot | 115 self.output_model_plots = args.plot |
108 self.output_pae = args.pae | 116 self.output_pae = args.pae |
109 self.plot_msa = args.plot_msa | 117 self.plot_msa = args.plot_msa |
118 self.collect_msas = args.msa | |
110 self.model_preset = self._sniff_model_preset() | 119 self.model_preset = self._sniff_model_preset() |
120 self.is_multimer = self.model_preset == PRESETS.multimer | |
111 self.output_dir = self.workdir / OUTPUT_DIR | 121 self.output_dir = self.workdir / OUTPUT_DIR |
122 self.msa_only = args.msa_only | |
112 os.makedirs(self.output_dir, exist_ok=True) | 123 os.makedirs(self.output_dir, exist_ok=True) |
113 | 124 |
114 def _sniff_model_preset(self) -> bool: | 125 def _sniff_model_preset(self) -> bool: |
115 """Check if the run was multimer or monomer.""" | 126 """Check if the run was multimer or monomer.""" |
116 for path in self.workdir.glob('*.pkl'): | 127 for path in self.workdir.glob('*.pkl'): |
118 if '_multimer_' in path.name: | 129 if '_multimer_' in path.name: |
119 return PRESETS.multimer | 130 return PRESETS.multimer |
120 if '_ptm_' in path.name: | 131 if '_ptm_' in path.name: |
121 return PRESETS.monomer_ptm | 132 return PRESETS.monomer_ptm |
122 return PRESETS.monomer | 133 return PRESETS.monomer |
134 return PRESETS.monomer | |
123 | 135 |
124 | 136 |
125 class ExecutionContext: | 137 class ExecutionContext: |
126 """Collect file paths etc.""" | 138 """Collect file paths etc.""" |
127 def __init__(self, settings: Settings): | 139 def __init__(self, settings: Settings): |
128 self.settings = settings | 140 self.settings = settings |
129 if settings.model_preset == PRESETS.multimer: | 141 if settings.is_multimer: |
130 self.plddt_key = PLDDT_KEY.multimer | 142 self.plddt_key = PLDDT_KEY.multimer |
131 else: | 143 else: |
132 self.plddt_key = PLDDT_KEY.monomer | 144 self.plddt_key = PLDDT_KEY.monomer |
133 | 145 |
134 def get_model_key(self, ix: int) -> str: | 146 def get_model_key(self, ix: int) -> str: |
376 plt.tight_layout() | 388 plt.tight_layout() |
377 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) | 389 plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) |
378 plt.close() | 390 plt.close() |
379 | 391 |
380 | 392 |
393 def collect_msas(settings: Settings): | |
394 """Collect MSA files into ZIP archive(s).""" | |
395 | |
396 def zip_dir(directory: Path, is_multimer: bool, name: str): | |
397 chain_id = directory.with_suffix('.zip').stem | |
398 msa_dir = settings.output_dir / 'msas' | |
399 msa_dir.mkdir(exist_ok=True) | |
400 zip_name = ( | |
401 f"MSA-{chain_id}-{name}.zip" | |
402 if is_multimer | |
403 else f"MSA-{name}.zip") | |
404 zip_path = msa_dir / zip_name | |
405 with zipfile.ZipFile(zip_path, 'w') as z: | |
406 for path in directory.glob('*'): | |
407 z.write(path, path.name) | |
408 | |
409 print("Collecting MSA archives...") | |
410 chain_names = get_input_sequence_ids( | |
411 settings.workdir.parent.parent / 'alphafold.fasta') | |
412 msa_dir = settings.workdir / 'msas' | |
413 is_multimer = (msa_dir / 'A').exists() | |
414 if is_multimer: | |
415 msa_dirs = sorted([ | |
416 path for path in msa_dir.glob('*') | |
417 if path.is_dir() | |
418 ]) | |
419 for i, path in enumerate(msa_dirs): | |
420 zip_dir(path, is_multimer, chain_names[i]) | |
421 else: | |
422 zip_dir(msa_dir, is_multimer, chain_names[0]) | |
423 | |
424 | |
425 def get_input_sequence_ids(fasta_file: Path) -> List[str]: | |
426 """Read headers from the input FASTA file. | |
427 Split them to get a sequence ID and truncate to 20 chars max. | |
428 """ | |
429 headers = [] | |
430 for line in fasta_file.read_text().split('\n'): | |
431 if line.startswith('>'): | |
432 seq_id = line[1:].split(' ')[0] | |
433 seq_id_trunc = seq_id[:20].strip() | |
434 if len(seq_id) > 20: | |
435 seq_id_trunc += '...' | |
436 headers.append(seq_id_trunc) | |
437 return headers | |
438 | |
439 | |
381 def template_html(context: ExecutionContext): | 440 def template_html(context: ExecutionContext): |
382 """Template HTML file. | 441 """Template HTML file. |
383 | 442 |
384 Remove buttons that are redundant with limited model outputs. | 443 Remove buttons that are redundant with limited model outputs. |
385 """ | 444 """ |
395 | 454 |
396 | 455 |
397 def main(): | 456 def main(): |
398 """Parse output files and generate additional output files.""" | 457 """Parse output files and generate additional output files.""" |
399 settings = Settings() | 458 settings = Settings() |
400 context = ExecutionContext(settings) | 459 if not settings.msa_only: |
401 ranking = ResultRanking(context) | 460 context = ExecutionContext(settings) |
402 write_confidence_scores(ranking, context) | 461 ranking = ResultRanking(context) |
403 rekey_relax_metrics(ranking, context) | 462 write_confidence_scores(ranking, context) |
404 template_html(context) | 463 rekey_relax_metrics(ranking, context) |
405 | 464 template_html(context) |
406 # Optional outputs | 465 |
407 if settings.output_model_pkls: | 466 # Optional outputs |
408 rename_model_pkls(ranking, context) | 467 if settings.output_model_pkls: |
409 if settings.output_model_plots: | 468 rename_model_pkls(ranking, context) |
410 plddt_pae_plots(ranking, context) | 469 if settings.output_model_plots: |
411 if settings.output_pae: | 470 plddt_pae_plots(ranking, context) |
412 # Only created by monomer_ptm and multimer models | 471 if settings.output_pae: |
413 extract_pae_to_csv(ranking, context) | 472 # Only created by monomer_ptm and multimer models |
414 if settings.output_residue_scores: | 473 extract_pae_to_csv(ranking, context) |
415 write_per_residue_scores(ranking, context) | 474 if settings.output_residue_scores: |
416 if settings.plot_msa: | 475 write_per_residue_scores(ranking, context) |
417 plot_msa(context.settings.workdir) | 476 if settings.plot_msa: |
477 plot_msa(settings.workdir) | |
478 if settings.collect_msas or settings.msa_only: | |
479 collect_msas(settings) | |
418 | 480 |
419 | 481 |
420 if __name__ == '__main__': | 482 if __name__ == '__main__': |
421 main() | 483 main() |