diff scripts/outputs.py @ 24:31f648b7555a draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 44db277529c0e189149235cf60a627193a792fba
author galaxy-australia
date Sat, 05 Jul 2025 03:56:38 +0000
parents 2891385d6ace
children
line wrap: on
line diff
--- a/scripts/outputs.py	Wed Apr 16 05:46:58 2025 +0000
+++ b/scripts/outputs.py	Sat Jul 05 03:56:38 2025 +0000
@@ -20,11 +20,12 @@
 import os
 import pickle as pk
 import shutil
-import zipfile
-from matplotlib import pyplot as plt
 from pathlib import Path
 from typing import Dict, List
 
+from matplotlib import pyplot as plt
+
+# Output file paths
 OUTPUT_DIR = 'extra'
 OUTPUTS = {
     'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
@@ -63,6 +64,7 @@
         self.workdir = None
         self.output_confidence_scores = True
         self.output_residue_scores = False
+        self.is_multimer = False
         self.parse()
 
     def parse(self) -> None:
@@ -98,16 +100,6 @@
             help="Plot multiple-sequence alignment coverage as a heatmap",
             action="store_true",
         )
-        parser.add_argument(
-            "--msa",
-            help="Collect multiple-sequence alignments as ZIP archives",
-            action="store_true",
-        )
-        parser.add_argument(
-            "--msa_only",
-            help="Alphafold generated MSA files only - skip all other outputs",
-            action="store_true",
-        )
         args = parser.parse_args()
         self.workdir = Path(args.workdir.rstrip('/'))
         self.output_residue_scores = args.confidence_scores
@@ -115,11 +107,8 @@
         self.output_model_plots = args.plot
         self.output_pae = args.pae
         self.plot_msa = args.plot_msa
-        self.collect_msas = args.msa
         self.model_preset = self._sniff_model_preset()
-        self.is_multimer = self.model_preset == PRESETS.multimer
         self.output_dir = self.workdir / OUTPUT_DIR
-        self.msa_only = args.msa_only
         os.makedirs(self.output_dir, exist_ok=True)
 
     def _sniff_model_preset(self) -> bool:
@@ -131,14 +120,13 @@
                 if '_ptm_' in path.name:
                     return PRESETS.monomer_ptm
                 return PRESETS.monomer
-        return PRESETS.monomer
 
 
 class ExecutionContext:
     """Collect file paths etc."""
     def __init__(self, settings: Settings):
         self.settings = settings
-        if settings.is_multimer:
+        if settings.model_preset == PRESETS.multimer:
             self.plddt_key = PLDDT_KEY.multimer
         else:
             self.plddt_key = PLDDT_KEY.monomer
@@ -211,7 +199,7 @@
 
     def get_plddt_for_rank(self, rank: int) -> List[float]:
         """Get pLDDT score for model instance."""
-        return self.data[self.context.plddt_key][self.data['order'][rank - 1]]
+        return self.data[self.context.plddt_key][self.data['order'][rank]]
 
     def get_rank_for_model(self, model_name: str) -> int:
         """Return 0-indexed rank for given model name.
@@ -228,7 +216,8 @@
     header = ['model', context.plddt_key]
 
     for i, path in enumerate(context.model_pkl_paths):
-        rank = int(path.name.split('model_')[-1][0])
+        model_name = 'model_' + path.stem.split('model_')[1]
+        rank = ranking.get_rank_for_model(model_name)
         scores_ls = [ranking.get_plddt_for_rank(rank)]
         with open(path, 'rb') as f:
             data = pk.load(f)
@@ -244,8 +233,9 @@
 
     with open(outfile, 'w') as f:
         f.write('\t'.join(header) + '\n')
-        for rank, score_ls in scores.items():
-            row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls]
+        for rank in sorted(scores):
+            score_ls = scores[rank]
+            row = [f"ranked_{rank}"] + [str(x) for x in score_ls]
             f.write('\t'.join(row) + '\n')
 
 
@@ -390,53 +380,6 @@
     plt.close()
 
 
-def collect_msas(settings: Settings):
-    """Collect MSA files into ZIP archive(s)."""
-
-    def zip_dir(directory: Path, is_multimer: bool, name: str):
-        chain_id = directory.with_suffix('.zip').stem
-        msa_dir = settings.output_dir / 'msas'
-        msa_dir.mkdir(exist_ok=True)
-        zip_name = (
-            f"MSA-{chain_id}-{name}.zip"
-            if is_multimer
-            else f"MSA-{name}.zip")
-        zip_path = msa_dir / zip_name
-        with zipfile.ZipFile(zip_path, 'w') as z:
-            for path in directory.glob('*'):
-                z.write(path, path.name)
-
-    print("Collecting MSA archives...")
-    chain_names = get_input_sequence_ids(
-        settings.workdir.parent.parent / 'alphafold.fasta')
-    msa_dir = settings.workdir / 'msas'
-    is_multimer = (msa_dir / 'A').exists()
-    if is_multimer:
-        msa_dirs = sorted([
-            path for path in msa_dir.glob('*')
-            if path.is_dir()
-        ])
-        for i, path in enumerate(msa_dirs):
-            zip_dir(path, is_multimer, chain_names[i])
-    else:
-        zip_dir(msa_dir, is_multimer, chain_names[0])
-
-
-def get_input_sequence_ids(fasta_file: Path) -> List[str]:
-    """Read headers from the input FASTA file.
-    Split them to get a sequence ID and truncate to 20 chars max.
-    """
-    headers = []
-    for line in fasta_file.read_text().split('\n'):
-        if line.startswith('>'):
-            seq_id = line[1:].split(' ')[0]
-            seq_id_trunc = seq_id[:20].strip()
-            if len(seq_id) > 20:
-                seq_id_trunc += '...'
-            headers.append(seq_id_trunc)
-    return headers
-
-
 def template_html(context: ExecutionContext):
     """Template HTML file.
 
@@ -456,27 +399,24 @@
 def main():
     """Parse output files and generate additional output files."""
     settings = Settings()
-    if not settings.msa_only:
-        context = ExecutionContext(settings)
-        ranking = ResultRanking(context)
-        write_confidence_scores(ranking, context)
-        rekey_relax_metrics(ranking, context)
-        template_html(context)
+    context = ExecutionContext(settings)
+    ranking = ResultRanking(context)
+    write_confidence_scores(ranking, context)
+    rekey_relax_metrics(ranking, context)
+    template_html(context)
 
-        # Optional outputs
-        if settings.output_model_pkls:
-            rename_model_pkls(ranking, context)
-        if settings.output_model_plots:
-            plddt_pae_plots(ranking, context)
-        if settings.output_pae:
-            # Only created by monomer_ptm and multimer models
-            extract_pae_to_csv(ranking, context)
-        if settings.output_residue_scores:
-            write_per_residue_scores(ranking, context)
-        if settings.plot_msa:
-            plot_msa(settings.workdir)
-    if settings.collect_msas or settings.msa_only:
-        collect_msas(settings)
+    # Optional outputs
+    if settings.output_model_pkls:
+        rename_model_pkls(ranking, context)
+    if settings.output_model_plots:
+        plddt_pae_plots(ranking, context)
+    if settings.output_pae:
+        # Only created by monomer_ptm and multimer models
+        extract_pae_to_csv(ranking, context)
+    if settings.output_residue_scores:
+        write_per_residue_scores(ranking, context)
+    if settings.plot_msa:
+        plot_msa(context.settings.workdir)
 
 
 if __name__ == '__main__':