Mercurial > repos > galaxy-australia > alphafold2
view gen_extra_outputs.py @ 10:072c324f20fc draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit c8f9b460b5c5f2ef0344719d38f6be9d1a3da573
author | galaxy-australia |
---|---|
date | Fri, 16 Sep 2022 02:19:37 +0000 |
parents | 3bd420ec162d |
children | c0e71cb2bd1b |
line wrap: on
line source
"""Generate additional output files not produced by AlphaFold.

Reads ``ranking_debug.json`` (and, optionally, the per-model result pickles)
from an AlphaFold output directory and writes two tab-separated files into
that same directory:

* ``model_confidence_scores.tsv`` - one confidence score per ranked model.
* ``plddts.tsv`` - per-residue pLDDT scores per ranked model (only when
  ``--plddts`` is given).
"""

import argparse
import json
import pickle
from typing import Any, Dict, List, Optional, Sequence

# Keys for accessing confidence data from JSON/pkl files
# They change depending on whether the run was monomer or multimer
CONTEXT_KEY = {
    'monomer': 'plddts',
    'multimer': 'iptm+ptm',
}


class Settings:
    """Parse and hold the program settings."""

    def __init__(self) -> None:
        # AlphaFold output directory, without trailing slashes.
        self.workdir: Optional[str] = None
        self.output_confidence_scores = True
        # Whether to also write per-residue pLDDT scores.
        self.output_residue_scores = False
        # Whether the output directory comes from a multimer run.
        self.is_multimer = False

    def parse_settings(self, argv: Optional[Sequence[str]] = None) -> None:
        """Populate the settings from command-line arguments.

        Args:
            argv: Argument list to parse. When ``None`` (the default),
                argparse falls back to ``sys.argv[1:]``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir",
            help="alphafold output directory",
            type=str
        )
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)",
            action="store_true"
        )
        parser.add_argument(
            "--multimer",
            help="parse output from AlphaFold multimer",
            action="store_true"
        )
        args = parser.parse_args(argv)
        # Paths are built as f'{workdir}/<file>', so drop trailing slashes.
        self.workdir = args.workdir.rstrip('/')
        self.output_residue_scores = args.plddts
        self.is_multimer = args.multimer


class ExecutionContext:
    """Derive input/output file paths from the program settings."""

    def __init__(self, settings: Settings):
        self.settings = settings

    @property
    def ranking_debug(self) -> str:
        """Path of the ``ranking_debug.json`` file in the work directory."""
        return f'{self.settings.workdir}/ranking_debug.json'

    @property
    def model_names(self) -> List[str]:
        """Names of the five models whose result pickles are read.

        Monomer models are named ``model_N``; multimer models carry a
        ``_multimer`` suffix, mirroring the result pickle file names.
        """
        suffix = '_multimer' if self.settings.is_multimer else ''
        return [f'model_{i}{suffix}' for i in range(1, 6)]

    @property
    def model_pkls(self) -> List[str]:
        """Paths of the per-model result pickles, in model-number order."""
        return [
            f'{self.settings.workdir}/result_{name}.pkl'
            for name in self.model_names
        ]

    @property
    def model_conf_score_output(self) -> str:
        """Path of the model confidence score TSV to be written."""
        return f'{self.settings.workdir}/model_confidence_scores.tsv'

    @property
    def plddt_output(self) -> str:
        """Path of the per-residue pLDDT TSV to be written."""
        return f'{self.settings.workdir}/plddts.tsv'


class FileLoader:
    """Load file data for use by other classes."""

    def __init__(self, context: ExecutionContext):
        self.context = context

    @property
    def confidence_key(self) -> str:
        """Return the correct key for confidence data."""
        if self.context.settings.is_multimer:
            return CONTEXT_KEY['multimer']
        return CONTEXT_KEY['monomer']

    def get_model_mapping(self) -> Dict[str, int]:
        """Map each model name to its 1-based position in ``order``."""
        data = self.load_ranking_debug()
        return {
            name: rank
            for rank, name in enumerate(data['order'], start=1)
        }

    def get_conf_scores(self) -> Dict[str, float]:
        """Return each model's confidence score, to two decimal places."""
        data = self.load_ranking_debug()
        return {
            name: float(f'{score:.2f}')
            for name, score in data[self.confidence_key].items()
        }

    def load_ranking_debug(self) -> Dict[str, Any]:
        """Parse and return the content of ``ranking_debug.json``."""
        with open(self.context.ranking_debug, 'r') as fp:
            return json.load(fp)

    def get_model_plddts(self) -> Dict[str, List[float]]:
        """Return per-residue pLDDT scores keyed by model name.

        Keys are the names from ``ExecutionContext.model_names`` so that, for
        multimer runs too, they line up with the names used as keys in
        ``ranking_debug.json`` (presumably identical to the ``result_*.pkl``
        stems - verify against the AlphaFold version in use).
        """
        plddts: Dict[str, List[float]] = {}
        names = self.context.model_names
        for name, pklfile in zip(names, self.context.model_pkls):
            # NOTE(security): unpickling can execute arbitrary code; only
            # point this tool at AlphaFold output directories you trust.
            with open(pklfile, 'rb') as fp:
                data = pickle.load(fp)
            plddts[name] = [float(f'{x:.2f}') for x in data['plddt']]
        return plddts


class OutputGenerator:
    """Generate the output data we are interested in creating."""

    def __init__(self, loader: FileLoader):
        self.loader = loader

    def gen_conf_scores(self) -> Dict[str, float]:
        """Return confidence scores keyed by ``model_<rank>``.

        Entries are ordered from highest to lowest score.
        """
        mapping = self.loader.get_model_mapping()
        scores = self.loader.get_conf_scores()
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return {f'model_{mapping[name]}': score for name, score in ranked}

    def gen_residue_scores(self) -> Dict[str, List[float]]:
        """Return per-residue pLDDT scores keyed by ``model_<rank>``.

        Raises:
            KeyError: If a model read from the pickles is not listed in the
                ``order`` entry of ``ranking_debug.json``.
        """
        mapping = self.loader.get_model_mapping()
        model_plddts = self.loader.get_model_plddts()
        return {
            f'model_{mapping[name]}': plddts
            for name, plddts in model_plddts.items()
        }


class OutputWriter:
    """Write generated data to files."""

    def __init__(self, context: ExecutionContext):
        self.context = context

    def write_conf_scores(self, data: Dict[str, float]) -> None:
        """Write one ``<model>\\t<score>`` line per model, in dict order."""
        outfile = self.context.model_conf_score_output
        with open(outfile, 'w') as fp:
            fp.writelines(
                f'{model}\t{score}\n' for model, score in data.items()
            )

    def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
        """Write one ``<model>\\t<comma-separated pLDDTs>`` line per model.

        Lines are sorted by model key.
        """
        outfile = self.context.plddt_output
        with open(outfile, 'w') as fp:
            for model, plddts in sorted(data.items()):
                plddt_str = ','.join(str(x) for x in plddts)
                fp.write(f'{model}\t{plddt_str}\n')


def main():
    """Parse the command line, then generate and write the extra outputs."""
    # setup
    settings = Settings()
    settings.parse_settings()
    context = ExecutionContext(settings)
    loader = FileLoader(context)

    # generate & write outputs
    generator = OutputGenerator(loader)
    writer = OutputWriter(context)

    # confidence scores
    conf_scores = generator.gen_conf_scores()
    writer.write_conf_scores(conf_scores)

    # per-residue plddts
    if settings.output_residue_scores:
        residue_scores = generator.gen_residue_scores()
        writer.write_residue_scores(residue_scores)


if __name__ == '__main__':
    main()