Mercurial > repos > galaxy-australia > alphafold2
diff outputs.py @ 14:d00e15139065 draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit d490defa32d9c318137d2d781243b392cb14110d-dirty
author | galaxy-australia |
---|---|
date | Tue, 28 Feb 2023 01:15:42 +0000 |
parents | |
children | a58f7eb0df2c |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/outputs.py Tue Feb 28 01:15:42 2023 +0000 @@ -0,0 +1,245 @@ +"""Generate additional output files not produced by AlphaFold. + +Currently this is includes: +- model confidence scores +- per-residue confidence scores (pLDDTs - optional output) +- model_*.pkl files renamed with rank order + +N.B. There have been issues with this script breaking between AlphaFold +versions due to minor changes in the output directory structure across minor +versions. It will likely need updating with future releases of AlphaFold. + +This code is more complex than you might expect due to the output files +'moving around' considerably, depending on run parameters. You will see that +several output paths are determined dynamically. +""" + +import argparse +import json +import os +import pickle as pk +import shutil +from pathlib import Path +from typing import List + +# Output file names +OUTPUT_DIR = 'extra' +OUTPUTS = { + 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', + 'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv', + 'plddts': OUTPUT_DIR + '/plddts.tsv', + 'relax': OUTPUT_DIR + '/relax_metrics_ranked.json', +} + +# Keys for accessing confidence data from JSON/pkl files +# They change depending on whether the run was monomer or multimer +PLDDT_KEY = { + 'monomer': 'plddts', + 'multimer': 'iptm+ptm', +} + + +class Settings: + """Parse and store settings/config.""" + def __init__(self): + self.workdir = None + self.output_confidence_scores = True + self.output_residue_scores = False + self.is_multimer = False + + def parse_settings(self) -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "workdir", + help="alphafold output directory", + type=str + ) + parser.add_argument( + "-p", + "--plddts", + help="output per-residue confidence scores (pLDDTs)", + action="store_true" + ) + parser.add_argument( + "-m", + "--multimer", + help="parse output from AlphaFold multimer", + action="store_true" + ) + parser.add_argument( + "--model-pkl", + dest="model_pkl", + help="rename model pkl outputs with rank order", + action="store_true" + ) + args = parser.parse_args() + self.workdir = Path(args.workdir.rstrip('/')) + self.output_residue_scores = args.plddts + self.output_model_pkls = args.model_pkl + self.is_multimer = args.multimer + self.output_dir = self.workdir / OUTPUT_DIR + os.makedirs(self.output_dir, exist_ok=True) + + +class ExecutionContext: + """Collect file paths etc.""" + def __init__(self, settings: Settings): + self.settings = settings + if settings.is_multimer: + self.plddt_key = PLDDT_KEY['multimer'] + else: + self.plddt_key = PLDDT_KEY['monomer'] + + def get_model_key(self, ix: int) -> str: + """Return json key for model index. + + The key format changed between minor AlphaFold versions so this + function determines the correct key. + """ + with open(self.ranking_debug) as f: + data = json.load(f) + model_keys = list(data[self.plddt_key].keys()) + for k in model_keys: + if k.startswith(f"model_{ix}_"): + return k + return KeyError( + f'Could not find key for index={ix} in' + ' ranking_debug.json') + + @property + def ranking_debug(self) -> str: + return self.settings.workdir / 'ranking_debug.json' + + @property + def relax_metrics(self) -> str: + return self.settings.workdir / 'relax_metrics.json' + + @property + def relax_metrics_ranked(self) -> str: + return self.settings.workdir / 'relax_metrics_ranked.json' + + @property + def model_pkl_paths(self) -> List[str]: + return sorted([ + self.settings.workdir / f + for f in os.listdir(self.settings.workdir) + if f.startswith('result_model_') and f.endswith('.pkl') + ]) + + +class ResultModelPrediction: + """Load and manipulate data from result_model_*.pkl files.""" + def __init__(self, path: str, context: ExecutionContext): + self.context = context + self.path = path + self.name = os.path.basename(path).replace('result_', '').split('.')[0] + with open(path, 'rb') as path: + self.data = pk.load(path) + + @property + def plddts(self) -> List[float]: + """Return pLDDT scores for each residue.""" + return list(self.data['plddt']) + + +class ResultRanking: + """Load and manipulate data from ranking_debug.json file.""" + + def __init__(self, context: ExecutionContext): + self.path = context.ranking_debug + self.context = context + with open(self.path, 'r') as f: + self.data = json.load(f) + + @property + def order(self) -> List[str]: + """Return ordered list of model indexes.""" + return self.data['order'] + + def get_plddt_for_rank(self, rank: int) -> List[float]: + """Get pLDDT score for model instance.""" + return self.data[self.context.plddt_key][self.data['order'][rank - 1]] + + def get_rank_for_model(self, model_name: str) -> int: + """Return 0-indexed rank for given model name. + + Model names are expressed in result_model_*.pkl file names. + """ + return self.data['order'].index(model_name) + + +def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext): + """Write per-model confidence scores.""" + path = context.settings.workdir / OUTPUTS['model_confidence_scores'] + with open(path, 'w') as f: + for rank in range(1, 6): + score = ranking.get_plddt_for_rank(rank) + f.write(f'ranked_{rank - 1}\t{score:.2f}\n') + + +def write_per_residue_scores( + ranking: ResultRanking, + context: ExecutionContext, +): + """Write per-residue plddts for each model. + + A row of plddt values is written for each model in tabular format. + """ + model_plddts = {} + for i, path in enumerate(context.model_pkl_paths): + model = ResultModelPrediction(path, context) + rank = ranking.get_rank_for_model(model.name) + model_plddts[rank] = model.plddts + + path = context.settings.workdir / OUTPUTS['plddts'] + with open(path, 'w') as f: + for i in sorted(list(model_plddts.keys())): + row = [f'ranked_{i}'] + [ + str(x) for x in model_plddts[i] + ] + f.write('\t'.join(row) + '\n') + + +def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext): + """Rename model.pkl files so the rank order is implicit.""" + for path in context.model_pkl_paths: + model = ResultModelPrediction(path, context) + rank = ranking.get_rank_for_model(model.name) + new_path = ( + context.settings.workdir + / OUTPUTS['model_pkl'].format(rank=rank) + ) + shutil.copyfile(path, new_path) + + +def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext): + """Replace keys in relax_metrics.json with 0-indexed rank.""" + with open(context.relax_metrics) as f: + data = json.load(f) + for k in list(data.keys()): + rank = ranking.get_rank_for_model(k) + data[f'ranked_{rank}'] = data.pop(k) + new_path = context.settings.workdir / OUTPUTS['relax'] + with open(new_path, 'w') as f: + json.dump(data, f) + + +def main(): + """Parse output files and generate additional output files.""" + settings = Settings() + settings.parse_settings() + context = ExecutionContext(settings) + ranking = ResultRanking(context) + write_confidence_scores(ranking, context) + rekey_relax_metrics(ranking, context) + + # Optional outputs + if settings.output_model_pkls: + rename_model_pkls(ranking, context) + + if settings.output_residue_scores: + write_per_residue_scores(ranking, context) + + +if __name__ == '__main__': + main()