comparison: gen_extra_outputs.py @ 0:7ae9d78b06f5 (draft)
repository: galaxy-australia / alphafold2
commit:     "planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit 7b79778448363aa8c9b14604337e81009e461bd2-dirty"
author:     galaxy-australia
date:       Fri, 28 Jan 2022 04:56:29 +0000
parents:    (none)
children:   3bd420ec162d
comparing -1:000000000000 (null revision) to 0:7ae9d78b06f5
import json
import pickle
import argparse
from typing import Any, Dict, List

class Settings:
    """parses then keeps track of program settings"""
    def __init__(self):
        self.workdir = None
        self.output_confidence_scores = True
        self.output_residue_scores = False

    def parse_settings(self) -> None:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str
        )
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true"
        )
        args = parser.parse_args()
        self.workdir = args.workdir.rstrip('/')
        self.output_residue_scores = args.plddts
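
# Example invocation of this script, based on the argparse definition above
# (the directory path is illustrative only):
#     python gen_extra_outputs.py /path/to/alphafold_output --plddts
# The positional argument is the AlphaFold output directory; the optional
# -p/--plddts flag additionally requests per-residue pLDDT output.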

class ExecutionContext:
    """uses program settings to get paths to files etc"""
    def __init__(self, settings: Settings):
        self.settings = settings

    @property
    def ranking_debug(self) -> str:
        return f'{self.settings.workdir}/ranking_debug.json'

    @property
    def model_pkls(self) -> List[str]:
        return [f'{self.settings.workdir}/result_model_{i}.pkl'
                for i in range(1, 6)]

    @property
    def model_conf_score_output(self) -> str:
        return f'{self.settings.workdir}/model_confidence_scores.tsv'

    @property
    def plddt_output(self) -> str:
        return f'{self.settings.workdir}/plddts.tsv'
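
# The properties above imply the following workdir layout (inputs produced by
# AlphaFold, outputs written by this script):
#     ranking_debug.json                           input:  model ranking + confidence scores
#     result_model_1.pkl ... result_model_5.pkl    input:  per-model pickled results
#     model_confidence_scores.tsv                  output: per-model confidence scores
#     plddts.tsv                                   output: per-residue pLDDTs (only with --plddts)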

class FileLoader:
    """loads file data for use by other classes"""
    def __init__(self, context: ExecutionContext):
        self.context = context

    def get_model_mapping(self) -> Dict[str, int]:
        data = self.load_ranking_debug()
        return {name: int(rank) + 1
                for (rank, name) in enumerate(data['order'])}

    def get_conf_scores(self) -> Dict[str, float]:
        data = self.load_ranking_debug()
        return {name: float(f'{score:.2f}')
                for name, score in data['plddts'].items()}

    def load_ranking_debug(self) -> Dict[str, Any]:
        with open(self.context.ranking_debug, 'r') as fp:
            return json.load(fp)

    def get_model_plddts(self) -> Dict[str, List[float]]:
        plddts: Dict[str, List[float]] = {}
        model_pkls = self.context.model_pkls
        for i in range(5):
            pklfile = model_pkls[i]
            with open(pklfile, 'rb') as fp:
                data = pickle.load(fp)
                plddts[f'model_{i+1}'] = [float(f'{x:.2f}') for x in data['plddt']]
        return plddts
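
# For reference, the loader expects ranking_debug.json to carry an 'order'
# list and a 'plddts' mapping; a minimal illustrative example (values are
# made up) would be:
#     {
#       "plddts": {"model_1": 87.4, "model_2": 91.2, ...},
#       "order": ["model_2", "model_1", ...]
#     }
# Each result_model_<i>.pkl is expected to unpickle to a dict containing a
# 'plddt' array of per-residue scores.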

class OutputGenerator:
    """generates the output data we are interested in creating"""
    def __init__(self, loader: FileLoader):
        self.loader = loader

    def gen_conf_scores(self):
        mapping = self.loader.get_model_mapping()
        scores = self.loader.get_conf_scores()
        ranked = list(scores.items())
        ranked.sort(key=lambda x: x[1], reverse=True)
        return {f'model_{mapping[name]}': score
                for name, score in ranked}

    def gen_residue_scores(self) -> Dict[str, List[float]]:
        mapping = self.loader.get_model_mapping()
        model_plddts = self.loader.get_model_plddts()
        return {f'model_{mapping[name]}': plddts
                for name, plddts in model_plddts.items()}
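
# Note: both generators rename models by rank, so 'model_1' in the outputs
# refers to the top-ranked model in ranking_debug.json['order'], which is
# not necessarily the model stored in result_model_1.pkl.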

class OutputWriter:
    """writes generated data to files"""
    def __init__(self, context: ExecutionContext):
        self.context = context

    def write_conf_scores(self, data: Dict[str, float]) -> None:
        outfile = self.context.model_conf_score_output
        with open(outfile, 'w') as fp:
            for model, score in data.items():
                fp.write(f'{model}\t{score}\n')

    def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
        outfile = self.context.plddt_output
        model_plddts = list(data.items())
        model_plddts.sort()

        with open(outfile, 'w') as fp:
            for model, plddts in model_plddts:
                plddt_str_list = [str(x) for x in plddts]
                plddt_str = ','.join(plddt_str_list)
                fp.write(f'{model}\t{plddt_str}\n')
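
# Illustrative output shape (values are made up; columns are tab-separated).
# model_confidence_scores.tsv, one model per line, ranked best first:
#     model_1    92.10
#     model_2    88.37
# plddts.tsv, one comma-separated list of per-residue pLDDTs per model:
#     model_1    93.2,91.8,90.5,...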

def main():
    # setup
    settings = Settings()
    settings.parse_settings()
    context = ExecutionContext(settings)
    loader = FileLoader(context)

    # generate & write outputs
    generator = OutputGenerator(loader)
    writer = OutputWriter(context)

    # confidence scores
    conf_scores = generator.gen_conf_scores()
    writer.write_conf_scores(conf_scores)

    # per-residue plddts
    if settings.output_residue_scores:
        residue_scores = generator.gen_residue_scores()
        writer.write_residue_scores(residue_scores)


if __name__ == '__main__':
    main()
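
# In summary: the script always writes model_confidence_scores.tsv into the
# given workdir, and additionally writes plddts.tsv when --plddts is passed.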