Mercurial > repos > galaxy-australia > alphafold2
diff scripts/outputs.py @ 21:e7f1b552a695 draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 628c9fdcb77489063145a2307b6bb6a450416dd6-dirty
author | galaxy-australia |
---|---|
date | Tue, 29 Oct 2024 02:15:36 +0000 |
parents | 6ab1a261520a |
children | 3f188450ca4f |
line wrap: on
line diff
--- a/scripts/outputs.py Sun Jul 28 20:09:55 2024 +0000 +++ b/scripts/outputs.py Tue Oct 29 02:15:36 2024 +0000 @@ -16,11 +16,12 @@ import argparse import json +import numpy as np import os import pickle as pk import shutil from pathlib import Path -from typing import List +from typing import Dict, List from matplotlib import pyplot as plt @@ -33,13 +34,7 @@ 'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv', 'plddts': OUTPUT_DIR + '/plddts.tsv', 'relax': OUTPUT_DIR + '/relax_metrics_ranked.json', -} - -# Keys for accessing confidence data from JSON/pkl files -# They change depending on whether the run was monomer or multimer -PLDDT_KEY = { - 'monomer': 'plddts', - 'multimer': 'iptm+ptm', + 'msa': OUTPUT_DIR + '/msa_coverage.png', } HTML_PATH = Path(__file__).parent / "alphafold.html" @@ -49,6 +44,20 @@ 'class="btn disabled" id="btn-ranked_{rank}" disabled') +class PLDDT_KEY: + """Dict keys for accessing confidence data from JSON/pkl files." + Changes depending on which model PRESET was used. + """ + monomer = 'plddts' + multimer = 'iptm+ptm' + + +class PRESETS: + monomer = 'monomer' + monomer_ptm = 'monomer_ptm' + multimer = 'multimer' + + class Settings: """Parse and store settings/config.""" def __init__(self): @@ -63,54 +72,63 @@ parser.add_argument( "workdir", help="alphafold output directory", - type=str + type=str, ) parser.add_argument( - "-p", - "--plddts", + "-s", + "--confidence-scores", help="output per-residue confidence scores (pLDDTs)", - action="store_true" - ) - parser.add_argument( - "-m", - "--multimer", - help="parse output from AlphaFold multimer", - action="store_true" + action="store_true", ) parser.add_argument( "--pkl", help="rename model pkl outputs with rank order", - action="store_true" + action="store_true", ) parser.add_argument( "--pae", help="extract PAE from pkl files to CSV format", - action="store_true" + action="store_true", ) parser.add_argument( "--plot", help="Plot pLDDT and PAE for each model", - action="store_true" + action="store_true", + ) + parser.add_argument( + "--plot-msa", + help="Plot multiple-sequence alignment coverage as a heatmap", + action="store_true", ) args = parser.parse_args() self.workdir = Path(args.workdir.rstrip('/')) - self.output_residue_scores = args.plddts + self.output_residue_scores = args.confidence_scores self.output_model_pkls = args.pkl self.output_model_plots = args.plot self.output_pae = args.pae - self.is_multimer = args.multimer + self.plot_msa = args.plot_msa + self.model_preset = self._sniff_model_preset() self.output_dir = self.workdir / OUTPUT_DIR os.makedirs(self.output_dir, exist_ok=True) + def _sniff_model_preset(self) -> bool: + """Check if the run was multimer or monomer.""" + with open(self.workdir / 'relax_metrics.json') as f: + if '_multimer_' in f.read(): + return PRESETS.multimer + if '_ptm_' in f.read(): + return PRESETS.monomer_ptm + return PRESETS.monomer + class ExecutionContext: """Collect file paths etc.""" def __init__(self, settings: Settings): self.settings = settings - if settings.is_multimer: - self.plddt_key = PLDDT_KEY['multimer'] + if settings.model_preset == PRESETS.multimer: + self.plddt_key = PLDDT_KEY.multimer else: - self.plddt_key = PLDDT_KEY['monomer'] + self.plddt_key = PLDDT_KEY.monomer def get_model_key(self, ix: int) -> str: """Return json key for model index. @@ -192,11 +210,30 @@ def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext): """Write per-model confidence scores.""" - path = context.settings.workdir / OUTPUTS['model_confidence_scores'] - with open(path, 'w') as f: - for rank in range(1, len(context.model_pkl_paths) + 1): - score = ranking.get_plddt_for_rank(rank) - f.write(f'ranked_{rank - 1}\t{score:.2f}\n') + outfile = context.settings.workdir / OUTPUTS['model_confidence_scores'] + scores: Dict[str, list] = {} + header = ['model', context.plddt_key] + + for i, path in enumerate(context.model_pkl_paths): + rank = int(path.name.split('model_')[-1][0]) + scores_ls = [ranking.get_plddt_for_rank(rank)] + with open(path, 'rb') as f: + data = pk.load(f) + if 'ptm' in data: + scores_ls.append(data['ptm']) + if i == 0: + header += ['ptm'] + if 'iptm' in data: + scores_ls.append(data['iptm']) + if i == 0: + header += ['iptm'] + scores[rank] = scores_ls + + with open(outfile, 'w') as f: + f.write('\t'.join(header) + '\n') + for rank, score_ls in scores.items(): + row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls] + f.write('\t'.join(row) + '\n') def write_per_residue_scores( @@ -304,6 +341,40 @@ plt.ylabel('Aligned residue') plt.savefig(png_path) + plt.close() + + +def plot_msa(wdir: Path, dpi: int = 150): + """Plot MSA as a heatmap.""" + with open(wdir / 'features.pkl', 'rb') as f: + features = pk.load(f) + + msa = features.get('msa') + if msa is None: + print("Could not plot MSA coverage - 'msa' key not found in" + " features.pkl") + return + seqid = (np.array(msa[0] == msa).mean(-1)) + seqid_sort = seqid.argsort() + non_gaps = (msa != 21).astype(float) + non_gaps[non_gaps == 0] = np.nan + final = non_gaps[seqid_sort] * seqid[seqid_sort, None] + + plt.figure(figsize=(6, 4)) + # plt.subplot(111) + plt.title("Sequence coverage") + plt.imshow(final, + interpolation='nearest', aspect='auto', + cmap="rainbow_r", vmin=0, vmax=1, origin='lower') + plt.plot((msa != 21).sum(0), color='black') + plt.xlim(-0.5, msa.shape[1] - 0.5) + plt.ylim(-0.5, msa.shape[0] - 0.5) + plt.colorbar(label="Sequence identity to query", ) + plt.xlabel("Positions") + plt.ylabel("Sequences") + plt.tight_layout() + plt.savefig(wdir / OUTPUTS['msa'], dpi=dpi) + plt.close() def template_html(context: ExecutionContext): @@ -341,6 +412,8 @@ extract_pae_to_csv(ranking, context) if settings.output_residue_scores: write_per_residue_scores(ranking, context) + if settings.plot_msa: + plot_msa(context.settings.workdir) if __name__ == '__main__':