diff gen_extra_outputs.py @ 9:3bd420ec162d draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 7726c3cba165bdc8fc6366ec0ce6596e55657468
author galaxy-australia
date Tue, 13 Sep 2022 22:04:12 +0000
parents 7ae9d78b06f5
children c0e71cb2bd1b
line wrap: on
line diff
--- a/gen_extra_outputs.py	Fri Aug 19 00:29:16 2022 +0000
+++ b/gen_extra_outputs.py	Tue Sep 13 22:04:12 2022 +0000
@@ -1,10 +1,17 @@
-
+"""Generate additional output files not produced by AlphaFold."""
 
 import json
 import pickle
 import argparse
 from typing import Any, Dict, List
 
+# Keys for accessing confidence data from JSON/pkl files
+# They change depending on whether the run was monomer or multimer
+CONTEXT_KEY = {
+    'monomer': 'plddts',
+    'multimer': 'iptm+ptm',
+}
+
 
 class Settings:
     """parses then keeps track of program settings"""
@@ -12,23 +19,31 @@
         self.workdir = None
         self.output_confidence_scores = True
         self.output_residue_scores = False
+        self.is_multimer = False
 
     def parse_settings(self) -> None:
         parser = argparse.ArgumentParser()
         parser.add_argument(
-            "workdir", 
-            help="alphafold output directory", 
+            "workdir",
+            help="alphafold output directory",
             type=str
-        )   
+        )
         parser.add_argument(
             "-p",
             "--plddts",
-            help="output per-residue confidence scores (pLDDTs)", 
+            help="output per-residue confidence scores (pLDDTs)",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--multimer",
+            help="parse output from AlphaFold multimer",
             action="store_true"
         )
         args = parser.parse_args()
         self.workdir = args.workdir.rstrip('/')
         self.output_residue_scores = args.plddts
+        self.is_multimer = False
+        self.is_multimer = args.multimer
 
 
 class ExecutionContext:
@@ -42,8 +57,13 @@
 
     @property
     def model_pkls(self) -> List[str]:
-        return [f'{self.settings.workdir}/result_model_{i}.pkl'
-                for i in range(1, 6)]
+        ext = '.pkl'
+        if self.settings.is_multimer:
+            ext = '_multimer.pkl'
+        return [
+            f'{self.settings.workdir}/result_model_{i}{ext}'
+            for i in range(1, 6)
+        ]
 
     @property
     def model_conf_score_output(self) -> str:
@@ -56,18 +76,28 @@
 
 class FileLoader:
     """loads file data for use by other classes"""
+
     def __init__(self, context: ExecutionContext):
         self.context = context
 
+    @property
+    def confidence_key(self) -> str:
+        """Return the correct key for confidence data."""
+        if self.context.settings.is_multimer:
+            return CONTEXT_KEY['multimer']
+        return CONTEXT_KEY['monomer']
+
     def get_model_mapping(self) -> Dict[str, int]:
         data = self.load_ranking_debug()
-        return {name: int(rank) + 1 
+        return {name: int(rank) + 1
                 for (rank, name) in enumerate(data['order'])}
 
     def get_conf_scores(self) -> Dict[str, float]:
         data = self.load_ranking_debug()
-        return {name: float(f'{score:.2f}') 
-                for name, score in data['plddts'].items()}
+        return {
+            name: float(f'{score:.2f}')
+            for name, score in data[self.confidence_key].items()
+        }
 
     def load_ranking_debug(self) -> Dict[str, Any]:
         with open(self.context.ranking_debug, 'r') as fp:
@@ -76,11 +106,14 @@
     def get_model_plddts(self) -> Dict[str, List[float]]:
         plddts: Dict[str, List[float]] = {}
         model_pkls = self.context.model_pkls
-        for i in range(5):
+        for i in range(len(model_pkls)):
             pklfile = model_pkls[i]
             with open(pklfile, 'rb') as fp:
                 data = pickle.load(fp)
-                plddts[f'model_{i+1}'] = [float(f'{x:.2f}') for x in data['plddt']]
+            plddts[f'model_{i+1}'] = [
+                float(f'{x:.2f}')
+                for x in data['plddt']
+            ]
         return plddts
 
 
@@ -94,13 +127,13 @@
         scores = self.loader.get_conf_scores()
         ranked = list(scores.items())
         ranked.sort(key=lambda x: x[1], reverse=True)
-        return {f'model_{mapping[name]}': score 
+        return {f'model_{mapping[name]}': score
                 for name, score in ranked}
 
     def gen_residue_scores(self) -> Dict[str, List[float]]:
         mapping = self.loader.get_model_mapping()
         model_plddts = self.loader.get_model_plddts()
-        return {f'model_{mapping[name]}': plddts 
+        return {f'model_{mapping[name]}': plddts
                 for name, plddts in model_plddts.items()}
 
 
@@ -114,7 +147,7 @@
         with open(outfile, 'w') as fp:
             for model, score in data.items():
                 fp.write(f'{model}\t{score}\n')
-    
+
     def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
         outfile = self.context.plddt_output
         model_plddts = list(data.items())
@@ -133,23 +166,20 @@
     settings.parse_settings()
     context = ExecutionContext(settings)
     loader = FileLoader(context)
-    
+
     # generate & write outputs
     generator = OutputGenerator(loader)
     writer = OutputWriter(context)
-    
+
     # confidence scores
     conf_scores = generator.gen_conf_scores()
     writer.write_conf_scores(conf_scores)
-    
+
     # per-residue plddts
     if settings.output_residue_scores:
         residue_scores = generator.gen_residue_scores()
         writer.write_residue_scores(residue_scores)
 
-    
+
 if __name__ == '__main__':
     main()
-
-
-