comparison gen_extra_outputs.py @ 0:7ae9d78b06f5 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit 7b79778448363aa8c9b14604337e81009e461bd2-dirty"
author galaxy-australia
date Fri, 28 Jan 2022 04:56:29 +0000
parents
children 3bd420ec162d
comparison
equal deleted inserted replaced
-1:000000000000 0:7ae9d78b06f5
1
2
import argparse
import json
import pickle
from typing import Any, Dict, List, Optional
7
8
class Settings:
    """Parse and hold the command-line settings for this script."""

    def __init__(self):
        # AlphaFold output directory; populated by parse_settings().
        self.workdir = None
        # Not read anywhere in this script: confidence scores are always written.
        self.output_confidence_scores = True
        # Whether to also write per-residue pLDDTs (set by -p/--plddts).
        self.output_residue_scores = False

    def parse_settings(self, argv: Optional[List[str]] = None) -> None:
        """Parse command-line arguments into this object's attributes.

        Args:
            argv: Arguments to parse, excluding the program name. The default
                ``None`` makes argparse read ``sys.argv[1:]``, which matches the
                previous behaviour; passing a list allows programmatic use.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str,
        )
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true",
        )
        args = parser.parse_args(argv)
        # Drop trailing slashes so paths built as f'{workdir}/<name>' have no '//'.
        self.workdir = args.workdir.rstrip('/')
        self.output_residue_scores = args.plddts
32
33
class ExecutionContext:
    """Derive input/output file paths from the program settings.

    Every path is ``<settings.workdir>/<file name>``; the working directory is
    read from the settings each time a path is requested.
    """

    def __init__(self, settings: Settings):
        self.settings = settings

    def _in_workdir(self, filename: str) -> str:
        """Return ``filename`` prefixed with the working directory and a '/'."""
        return f'{self.settings.workdir}/{filename}'

    @property
    def ranking_debug(self) -> str:
        """Path to ranking_debug.json inside the working directory."""
        return self._in_workdir('ranking_debug.json')

    @property
    def model_pkls(self) -> List[str]:
        """Paths to result_model_1.pkl .. result_model_5.pkl, in numeric order."""
        return [self._in_workdir(f'result_model_{n}.pkl') for n in range(1, 6)]

    @property
    def model_conf_score_output(self) -> str:
        """Destination path for the per-model confidence score table."""
        return self._in_workdir('model_confidence_scores.tsv')

    @property
    def plddt_output(self) -> str:
        """Destination path for the per-residue pLDDT table."""
        return self._in_workdir('plddts.tsv')
55
56
class FileLoader:
    """Load AlphaFold output files for use by the other classes.

    Nothing is cached: each method re-reads the file(s) it needs from disk.
    """

    def __init__(self, context: 'ExecutionContext'):
        self.context = context

    @staticmethod
    def _round_2dp(value) -> float:
        """Round ``value`` to two decimal places via string formatting."""
        return float(f'{value:.2f}')

    def get_model_mapping(self) -> Dict[str, int]:
        """Map each model name to its 1-based position in ranking_debug ``order``."""
        data = self.load_ranking_debug()
        return {name: rank for rank, name in enumerate(data['order'], start=1)}

    def get_conf_scores(self) -> Dict[str, float]:
        """Map each model name to its ranking_debug ``plddts`` score, rounded to 2 dp."""
        data = self.load_ranking_debug()
        return {name: self._round_2dp(score)
                for name, score in data['plddts'].items()}

    def load_ranking_debug(self) -> Dict[str, Any]:
        """Read and return the parsed contents of ranking_debug.json."""
        with open(self.context.ranking_debug, 'r', encoding='utf-8') as fp:
            return json.load(fp)

    def get_model_plddts(self) -> Dict[str, List[float]]:
        """Load per-residue scores from every model pickle in ``context.model_pkls``.

        Returns:
            ``{'model_<n>': [score, ...]}`` where ``n`` is the 1-based position
            of the pickle in ``context.model_pkls``; each score is rounded to
            2 dp. Iterating the path list directly (rather than a hard-coded
            ``range(5)``) keeps this in step with however many paths the
            context provides.
        """
        plddts: Dict[str, List[float]] = {}
        for number, pklfile in enumerate(self.context.model_pkls, start=1):
            # NOTE(review): pickle.load can execute arbitrary code; only run
            # this on pickles produced by a trusted AlphaFold run.
            with open(pklfile, 'rb') as fp:
                data = pickle.load(fp)
            plddts[f'model_{number}'] = [self._round_2dp(x) for x in data['plddt']]
        return plddts
85
86
class OutputGenerator:
    """Build the output tables, re-keying every model as ``model_<rank>``.

    ``<rank>`` is the 1-based rank reported by the loader's model mapping.
    """

    def __init__(self, loader: FileLoader):
        self.loader = loader

    def gen_conf_scores(self):
        """Return ``{'model_<rank>': score}`` ordered from highest to lowest score.

        Raises KeyError if a scored model is absent from the rank mapping.
        """
        rank_of = self.loader.get_model_mapping()
        score_of = self.loader.get_conf_scores()
        # sorted() is stable, so tied scores keep their original relative order.
        by_score = sorted(score_of.items(), key=lambda pair: pair[1], reverse=True)
        return dict(
            (f'model_{rank_of[model]}', score) for model, score in by_score
        )

    def gen_residue_scores(self) -> Dict[str, List[float]]:
        """Return per-residue scores keyed ``model_<rank>``, in the loader's order.

        Raises KeyError if a loaded model is absent from the rank mapping.
        """
        rank_of = self.loader.get_model_mapping()
        residue_scores = {}
        for model, scores in self.loader.get_model_plddts().items():
            residue_scores[f'model_{rank_of[model]}'] = scores
        return residue_scores
105
106
class OutputWriter:
    """Write generated score tables as tab-separated files."""

    def __init__(self, context: 'ExecutionContext'):
        self.context = context

    @staticmethod
    def _model_sort_key(model: str) -> tuple:
        """Sort key ordering ``model_<n>`` names by the integer ``n``.

        Plain string sorting puts ``model_10`` before ``model_2``; for names
        without a numeric ``_<n>`` suffix the whole name is used instead.
        """
        stem, sep, number = model.rpartition('_')
        if sep and number.isdecimal():
            return (stem, int(number), model)
        return (model, -1, model)

    def write_conf_scores(self, data: Dict[str, float]) -> None:
        """Write one ``<model>\\t<score>`` line per entry, in ``data``'s order."""
        outfile = self.context.model_conf_score_output
        with open(outfile, 'w') as fp:
            fp.writelines(f'{model}\t{score}\n' for model, score in data.items())

    def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
        """Write one ``<model>\\t<s1>,<s2>,...`` line per model.

        Models are ordered by their numeric suffix, which matches plain string
        order for up to nine models but stays correct beyond that.
        """
        outfile = self.context.plddt_output
        with open(outfile, 'w') as fp:
            for model in sorted(data, key=self._model_sort_key):
                plddt_str = ','.join(str(x) for x in data[model])
                fp.write(f'{model}\t{plddt_str}\n')
128
129
def main():
    """Script entry point.

    Always writes the per-model confidence score table; the per-residue pLDDT
    table is written only when the -p/--plddts flag was given.
    """
    settings = Settings()
    settings.parse_settings()
    context = ExecutionContext(settings)

    generator = OutputGenerator(FileLoader(context))
    writer = OutputWriter(context)

    writer.write_conf_scores(generator.gen_conf_scores())

    # Per-residue pLDDTs are opt-in.
    if not settings.output_residue_scores:
        return
    writer.write_residue_scores(generator.gen_residue_scores())


if __name__ == '__main__':
    main()
153
154
155