Mercurial > repos > galaxy-australia > alphafold2
diff scripts/outputs.py @ 23:2891385d6ace draft default tip
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit b347c6ccc82b14fcbff360b3357050d1d43e3ef5-dirty
author | galaxy-australia |
---|---|
date | Wed, 16 Apr 2025 05:46:58 +0000 |
parents | 3f188450ca4f |
children |
line wrap: on
line diff
--- a/scripts/outputs.py Wed Oct 30 21:46:34 2024 +0000 +++ b/scripts/outputs.py Wed Apr 16 05:46:58 2025 +0000 @@ -20,12 +20,11 @@ import os import pickle as pk import shutil +import zipfile +from matplotlib import pyplot as plt from pathlib import Path from typing import Dict, List -from matplotlib import pyplot as plt - -# Output file paths OUTPUT_DIR = 'extra' OUTPUTS = { 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', @@ -64,7 +63,6 @@ self.workdir = None self.output_confidence_scores = True self.output_residue_scores = False - self.is_multimer = False self.parse() def parse(self) -> None: @@ -100,6 +98,16 @@ help="Plot multiple-sequence alignment coverage as a heatmap", action="store_true", ) + parser.add_argument( + "--msa", + help="Collect multiple-sequence alignments as ZIP archives", + action="store_true", + ) + parser.add_argument( + "--msa_only", + help="Alphafold generated MSA files only - skip all other outputs", + action="store_true", + ) args = parser.parse_args() self.workdir = Path(args.workdir.rstrip('/')) self.output_residue_scores = args.confidence_scores @@ -107,8 +115,11 @@ self.output_model_plots = args.plot self.output_pae = args.pae self.plot_msa = args.plot_msa + self.collect_msas = args.msa self.model_preset = self._sniff_model_preset() + self.is_multimer = self.model_preset == PRESETS.multimer self.output_dir = self.workdir / OUTPUT_DIR + self.msa_only = args.msa_only os.makedirs(self.output_dir, exist_ok=True) def _sniff_model_preset(self) -> bool: @@ -120,13 +131,14 @@ if '_ptm_' in path.name: return PRESETS.monomer_ptm return PRESETS.monomer + return PRESETS.monomer class ExecutionContext: """Collect file paths etc.""" def __init__(self, settings: Settings): self.settings = settings - if settings.model_preset == PRESETS.multimer: + if settings.is_multimer: self.plddt_key = PLDDT_KEY.multimer else: self.plddt_key = PLDDT_KEY.monomer @@ -378,6 +390,53 @@ plt.close() +def collect_msas(settings: Settings): + """Collect MSA files into ZIP archive(s).""" + + def zip_dir(directory: Path, is_multimer: bool, name: str): + chain_id = directory.with_suffix('.zip').stem + msa_dir = settings.output_dir / 'msas' + msa_dir.mkdir(exist_ok=True) + zip_name = ( + f"MSA-{chain_id}-{name}.zip" + if is_multimer + else f"MSA-{name}.zip") + zip_path = msa_dir / zip_name + with zipfile.ZipFile(zip_path, 'w') as z: + for path in directory.glob('*'): + z.write(path, path.name) + + print("Collecting MSA archives...") + chain_names = get_input_sequence_ids( + settings.workdir.parent.parent / 'alphafold.fasta') + msa_dir = settings.workdir / 'msas' + is_multimer = (msa_dir / 'A').exists() + if is_multimer: + msa_dirs = sorted([ + path for path in msa_dir.glob('*') + if path.is_dir() + ]) + for i, path in enumerate(msa_dirs): + zip_dir(path, is_multimer, chain_names[i]) + else: + zip_dir(msa_dir, is_multimer, chain_names[0]) + + +def get_input_sequence_ids(fasta_file: Path) -> List[str]: + """Read headers from the input FASTA file. + Split them to get a sequence ID and truncate to 20 chars max. + """ + headers = [] + for line in fasta_file.read_text().split('\n'): + if line.startswith('>'): + seq_id = line[1:].split(' ')[0] + seq_id_trunc = seq_id[:20].strip() + if len(seq_id) > 20: + seq_id_trunc += '...' + headers.append(seq_id_trunc) + return headers + + def template_html(context: ExecutionContext): """Template HTML file. @@ -397,24 +456,27 @@ def main(): """Parse output files and generate additional output files.""" settings = Settings() - context = ExecutionContext(settings) - ranking = ResultRanking(context) - write_confidence_scores(ranking, context) - rekey_relax_metrics(ranking, context) - template_html(context) + if not settings.msa_only: + context = ExecutionContext(settings) + ranking = ResultRanking(context) + write_confidence_scores(ranking, context) + rekey_relax_metrics(ranking, context) + template_html(context) - # Optional outputs - if settings.output_model_pkls: - rename_model_pkls(ranking, context) - if settings.output_model_plots: - plddt_pae_plots(ranking, context) - if settings.output_pae: - # Only created by monomer_ptm and multimer models - extract_pae_to_csv(ranking, context) - if settings.output_residue_scores: - write_per_residue_scores(ranking, context) - if settings.plot_msa: - plot_msa(context.settings.workdir) + # Optional outputs + if settings.output_model_pkls: + rename_model_pkls(ranking, context) + if settings.output_model_plots: + plddt_pae_plots(ranking, context) + if settings.output_pae: + # Only created by monomer_ptm and multimer models + extract_pae_to_csv(ranking, context) + if settings.output_residue_scores: + write_per_residue_scores(ranking, context) + if settings.plot_msa: + plot_msa(settings.workdir) + if settings.collect_msas or settings.msa_only: + collect_msas(settings) if __name__ == '__main__':