view gen_extra_outputs.py @ 8:ca90d17ff51b draft

"planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 03537aada92b5fff565ff48dd47c81462c5df47e"
author galaxy-australia
date Fri, 19 Aug 2022 00:29:16 +0000
parents 7ae9d78b06f5
children 3bd420ec162d
line wrap: on
line source



import json
import pickle
import argparse
from typing import Any, Dict, List


class Settings:
    """parses then keeps track of program settings"""
    def __init__(self):
        self.workdir = None
        self.output_confidence_scores = True
        self.output_residue_scores = False

    def parse_settings(self) -> None:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "workdir", 
            help="alphafold output directory", 
            type=str
        )   
        parser.add_argument(
            "-p",
            "--plddts",
            help="output per-residue confidence scores (pLDDTs)", 
            action="store_true"
        )
        args = parser.parse_args()
        self.workdir = args.workdir.rstrip('/')
        self.output_residue_scores = args.plddts


class ExecutionContext:
    """uses program settings to get paths to files etc"""
    def __init__(self, settings: Settings):
        self.settings = settings

    @property
    def ranking_debug(self) -> str:
        return f'{self.settings.workdir}/ranking_debug.json'

    @property
    def model_pkls(self) -> List[str]:
        return [f'{self.settings.workdir}/result_model_{i}.pkl'
                for i in range(1, 6)]

    @property
    def model_conf_score_output(self) -> str:
        return f'{self.settings.workdir}/model_confidence_scores.tsv'

    @property
    def plddt_output(self) -> str:
        return f'{self.settings.workdir}/plddts.tsv'


class FileLoader:
    """loads file data for use by other classes"""
    def __init__(self, context: ExecutionContext):
        self.context = context

    def get_model_mapping(self) -> Dict[str, int]:
        data = self.load_ranking_debug()
        return {name: int(rank) + 1 
                for (rank, name) in enumerate(data['order'])}

    def get_conf_scores(self) -> Dict[str, float]:
        data = self.load_ranking_debug()
        return {name: float(f'{score:.2f}') 
                for name, score in data['plddts'].items()}

    def load_ranking_debug(self) -> Dict[str, Any]:
        with open(self.context.ranking_debug, 'r') as fp:
            return json.load(fp)

    def get_model_plddts(self) -> Dict[str, List[float]]:
        plddts: Dict[str, List[float]] = {}
        model_pkls = self.context.model_pkls
        for i in range(5):
            pklfile = model_pkls[i]
            with open(pklfile, 'rb') as fp:
                data = pickle.load(fp)
                plddts[f'model_{i+1}'] = [float(f'{x:.2f}') for x in data['plddt']]
        return plddts


class OutputGenerator:
    """generates the output data we are interested in creating"""
    def __init__(self, loader: FileLoader):
        self.loader = loader

    def gen_conf_scores(self):
        mapping = self.loader.get_model_mapping()
        scores = self.loader.get_conf_scores()
        ranked = list(scores.items())
        ranked.sort(key=lambda x: x[1], reverse=True)
        return {f'model_{mapping[name]}': score 
                for name, score in ranked}

    def gen_residue_scores(self) -> Dict[str, List[float]]:
        mapping = self.loader.get_model_mapping()
        model_plddts = self.loader.get_model_plddts()
        return {f'model_{mapping[name]}': plddts 
                for name, plddts in model_plddts.items()}


class OutputWriter:
    """writes generated data to files"""
    def __init__(self, context: ExecutionContext):
        self.context = context

    def write_conf_scores(self, data: Dict[str, float]) -> None:
        outfile = self.context.model_conf_score_output
        with open(outfile, 'w') as fp:
            for model, score in data.items():
                fp.write(f'{model}\t{score}\n')
    
    def write_residue_scores(self, data: Dict[str, List[float]]) -> None:
        outfile = self.context.plddt_output
        model_plddts = list(data.items())
        model_plddts.sort()

        with open(outfile, 'w') as fp:
            for model, plddts in model_plddts:
                plddt_str_list = [str(x) for x in plddts]
                plddt_str = ','.join(plddt_str_list)
                fp.write(f'{model}\t{plddt_str}\n')


def main():
    # setup
    settings = Settings()
    settings.parse_settings()
    context = ExecutionContext(settings)
    loader = FileLoader(context)
    
    # generate & write outputs
    generator = OutputGenerator(loader)
    writer = OutputWriter(context)
    
    # confidence scores
    conf_scores = generator.gen_conf_scores()
    writer.write_conf_scores(conf_scores)
    
    # per-residue plddts
    if settings.output_residue_scores:
        residue_scores = generator.gen_residue_scores()
        writer.write_residue_scores(residue_scores)

    
if __name__ == '__main__':
    main()