diff scripts/outputs.py @ 23:2891385d6ace draft default tip

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit b347c6ccc82b14fcbff360b3357050d1d43e3ef5-dirty
author galaxy-australia
date Wed, 16 Apr 2025 05:46:58 +0000
parents 3f188450ca4f
children
line wrap: on
line diff
--- a/scripts/outputs.py	Wed Oct 30 21:46:34 2024 +0000
+++ b/scripts/outputs.py	Wed Apr 16 05:46:58 2025 +0000
@@ -20,12 +20,11 @@
 import os
 import pickle as pk
 import shutil
+import zipfile
+from matplotlib import pyplot as plt
 from pathlib import Path
 from typing import Dict, List
 
-from matplotlib import pyplot as plt
-
-# Output file paths
 OUTPUT_DIR = 'extra'
 OUTPUTS = {
     'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
@@ -64,7 +63,6 @@
         self.workdir = None
         self.output_confidence_scores = True
         self.output_residue_scores = False
-        self.is_multimer = False
         self.parse()
 
     def parse(self) -> None:
@@ -100,6 +98,16 @@
             help="Plot multiple-sequence alignment coverage as a heatmap",
             action="store_true",
         )
+        parser.add_argument(
+            "--msa",
+            help="Collect multiple-sequence alignments as ZIP archives",
+            action="store_true",
+        )
+        parser.add_argument(
+            "--msa_only",
+            help="Alphafold generated MSA files only - skip all other outputs",
+            action="store_true",
+        )
         args = parser.parse_args()
         self.workdir = Path(args.workdir.rstrip('/'))
         self.output_residue_scores = args.confidence_scores
@@ -107,8 +115,11 @@
         self.output_model_plots = args.plot
         self.output_pae = args.pae
         self.plot_msa = args.plot_msa
+        self.collect_msas = args.msa
         self.model_preset = self._sniff_model_preset()
+        self.is_multimer = self.model_preset == PRESETS.multimer
         self.output_dir = self.workdir / OUTPUT_DIR
+        self.msa_only = args.msa_only
         os.makedirs(self.output_dir, exist_ok=True)
 
     def _sniff_model_preset(self) -> bool:
@@ -120,13 +131,14 @@
                 if '_ptm_' in path.name:
                     return PRESETS.monomer_ptm
                 return PRESETS.monomer
+        return PRESETS.monomer
 
 
 class ExecutionContext:
     """Collect file paths etc."""
     def __init__(self, settings: Settings):
         self.settings = settings
-        if settings.model_preset == PRESETS.multimer:
+        if settings.is_multimer:
             self.plddt_key = PLDDT_KEY.multimer
         else:
             self.plddt_key = PLDDT_KEY.monomer
@@ -378,6 +390,53 @@
     plt.close()
 
 
+def collect_msas(settings: Settings):
+    """Collect MSA files into ZIP archive(s)."""
+
+    def zip_dir(directory: Path, is_multimer: bool, name: str):
+        chain_id = directory.with_suffix('.zip').stem
+        msa_dir = settings.output_dir / 'msas'
+        msa_dir.mkdir(exist_ok=True)
+        zip_name = (
+            f"MSA-{chain_id}-{name}.zip"
+            if is_multimer
+            else f"MSA-{name}.zip")
+        zip_path = msa_dir / zip_name
+        with zipfile.ZipFile(zip_path, 'w') as z:
+            for path in directory.glob('*'):
+                z.write(path, path.name)
+
+    print("Collecting MSA archives...")
+    chain_names = get_input_sequence_ids(
+        settings.workdir.parent.parent / 'alphafold.fasta')
+    msa_dir = settings.workdir / 'msas'
+    is_multimer = (msa_dir / 'A').exists()
+    if is_multimer:
+        msa_dirs = sorted([
+            path for path in msa_dir.glob('*')
+            if path.is_dir()
+        ])
+        for i, path in enumerate(msa_dirs):
+            zip_dir(path, is_multimer, chain_names[i])
+    else:
+        zip_dir(msa_dir, is_multimer, chain_names[0])
+
+
+def get_input_sequence_ids(fasta_file: Path) -> List[str]:
+    """Read headers from the input FASTA file.
+    Split them to get a sequence ID and truncate to 20 chars max.
+    """
+    headers = []
+    for line in fasta_file.read_text().split('\n'):
+        if line.startswith('>'):
+            seq_id = line[1:].split(' ')[0]
+            seq_id_trunc = seq_id[:20].strip()
+            if len(seq_id) > 20:
+                seq_id_trunc += '...'
+            headers.append(seq_id_trunc)
+    return headers
+
+
 def template_html(context: ExecutionContext):
     """Template HTML file.
 
@@ -397,24 +456,27 @@
 def main():
     """Parse output files and generate additional output files."""
     settings = Settings()
-    context = ExecutionContext(settings)
-    ranking = ResultRanking(context)
-    write_confidence_scores(ranking, context)
-    rekey_relax_metrics(ranking, context)
-    template_html(context)
+    if not settings.msa_only:
+        context = ExecutionContext(settings)
+        ranking = ResultRanking(context)
+        write_confidence_scores(ranking, context)
+        rekey_relax_metrics(ranking, context)
+        template_html(context)
 
-    # Optional outputs
-    if settings.output_model_pkls:
-        rename_model_pkls(ranking, context)
-    if settings.output_model_plots:
-        plddt_pae_plots(ranking, context)
-    if settings.output_pae:
-        # Only created by monomer_ptm and multimer models
-        extract_pae_to_csv(ranking, context)
-    if settings.output_residue_scores:
-        write_per_residue_scores(ranking, context)
-    if settings.plot_msa:
-        plot_msa(context.settings.workdir)
+        # Optional outputs
+        if settings.output_model_pkls:
+            rename_model_pkls(ranking, context)
+        if settings.output_model_plots:
+            plddt_pae_plots(ranking, context)
+        if settings.output_pae:
+            # Only created by monomer_ptm and multimer models
+            extract_pae_to_csv(ranking, context)
+        if settings.output_residue_scores:
+            write_per_residue_scores(ranking, context)
+        if settings.plot_msa:
+            plot_msa(settings.workdir)
+    if settings.collect_msas or settings.msa_only:
+        collect_msas(settings)
 
 
 if __name__ == '__main__':