Mercurial > repos > galaxy-australia > alphafold2
diff scripts/outputs.py @ 24:31f648b7555a draft
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit 44db277529c0e189149235cf60a627193a792fba
author | galaxy-australia |
---|---|
date | Sat, 05 Jul 2025 03:56:38 +0000 |
parents | 2891385d6ace |
children |
line wrap: on
line diff
--- a/scripts/outputs.py Wed Apr 16 05:46:58 2025 +0000 +++ b/scripts/outputs.py Sat Jul 05 03:56:38 2025 +0000 @@ -20,11 +20,12 @@ import os import pickle as pk import shutil -import zipfile -from matplotlib import pyplot as plt from pathlib import Path from typing import Dict, List +from matplotlib import pyplot as plt + +# Output file paths OUTPUT_DIR = 'extra' OUTPUTS = { 'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl', @@ -63,6 +64,7 @@ self.workdir = None self.output_confidence_scores = True self.output_residue_scores = False + self.is_multimer = False self.parse() def parse(self) -> None: @@ -98,16 +100,6 @@ help="Plot multiple-sequence alignment coverage as a heatmap", action="store_true", ) - parser.add_argument( - "--msa", - help="Collect multiple-sequence alignments as ZIP archives", - action="store_true", - ) - parser.add_argument( - "--msa_only", - help="Alphafold generated MSA files only - skip all other outputs", - action="store_true", - ) args = parser.parse_args() self.workdir = Path(args.workdir.rstrip('/')) self.output_residue_scores = args.confidence_scores @@ -115,11 +107,8 @@ self.output_model_plots = args.plot self.output_pae = args.pae self.plot_msa = args.plot_msa - self.collect_msas = args.msa self.model_preset = self._sniff_model_preset() - self.is_multimer = self.model_preset == PRESETS.multimer self.output_dir = self.workdir / OUTPUT_DIR - self.msa_only = args.msa_only os.makedirs(self.output_dir, exist_ok=True) def _sniff_model_preset(self) -> bool: @@ -131,14 +120,13 @@ if '_ptm_' in path.name: return PRESETS.monomer_ptm return PRESETS.monomer - return PRESETS.monomer class ExecutionContext: """Collect file paths etc.""" def __init__(self, settings: Settings): self.settings = settings - if settings.is_multimer: + if settings.model_preset == PRESETS.multimer: self.plddt_key = PLDDT_KEY.multimer else: self.plddt_key = PLDDT_KEY.monomer @@ -211,7 +199,7 @@ def get_plddt_for_rank(self, rank: int) -> List[float]: """Get pLDDT score for model instance.""" - return self.data[self.context.plddt_key][self.data['order'][rank - 1]] + return self.data[self.context.plddt_key][self.data['order'][rank]] def get_rank_for_model(self, model_name: str) -> int: """Return 0-indexed rank for given model name. @@ -228,7 +216,8 @@ header = ['model', context.plddt_key] for i, path in enumerate(context.model_pkl_paths): - rank = int(path.name.split('model_')[-1][0]) + model_name = 'model_' + path.stem.split('model_')[1] + rank = ranking.get_rank_for_model(model_name) scores_ls = [ranking.get_plddt_for_rank(rank)] with open(path, 'rb') as f: data = pk.load(f) @@ -244,8 +233,9 @@ with open(outfile, 'w') as f: f.write('\t'.join(header) + '\n') - for rank, score_ls in scores.items(): - row = [f"ranked_{rank - 1}"] + [str(x) for x in score_ls] + for rank in sorted(scores): + score_ls = scores[rank] + row = [f"ranked_{rank}"] + [str(x) for x in score_ls] f.write('\t'.join(row) + '\n') @@ -390,53 +380,6 @@ plt.close() -def collect_msas(settings: Settings): - """Collect MSA files into ZIP archive(s).""" - - def zip_dir(directory: Path, is_multimer: bool, name: str): - chain_id = directory.with_suffix('.zip').stem - msa_dir = settings.output_dir / 'msas' - msa_dir.mkdir(exist_ok=True) - zip_name = ( - f"MSA-{chain_id}-{name}.zip" - if is_multimer - else f"MSA-{name}.zip") - zip_path = msa_dir / zip_name - with zipfile.ZipFile(zip_path, 'w') as z: - for path in directory.glob('*'): - z.write(path, path.name) - - print("Collecting MSA archives...") - chain_names = get_input_sequence_ids( - settings.workdir.parent.parent / 'alphafold.fasta') - msa_dir = settings.workdir / 'msas' - is_multimer = (msa_dir / 'A').exists() - if is_multimer: - msa_dirs = sorted([ - path for path in msa_dir.glob('*') - if path.is_dir() - ]) - for i, path in enumerate(msa_dirs): - zip_dir(path, is_multimer, chain_names[i]) - else: - zip_dir(msa_dir, is_multimer, chain_names[0]) - - -def get_input_sequence_ids(fasta_file: Path) -> List[str]: - """Read headers from the input FASTA file. - Split them to get a sequence ID and truncate to 20 chars max. - """ - headers = [] - for line in fasta_file.read_text().split('\n'): - if line.startswith('>'): - seq_id = line[1:].split(' ')[0] - seq_id_trunc = seq_id[:20].strip() - if len(seq_id) > 20: - seq_id_trunc += '...' - headers.append(seq_id_trunc) - return headers - - def template_html(context: ExecutionContext): """Template HTML file. @@ -456,27 +399,24 @@ def main(): """Parse output files and generate additional output files.""" settings = Settings() - if not settings.msa_only: - context = ExecutionContext(settings) - ranking = ResultRanking(context) - write_confidence_scores(ranking, context) - rekey_relax_metrics(ranking, context) - template_html(context) + context = ExecutionContext(settings) + ranking = ResultRanking(context) + write_confidence_scores(ranking, context) + rekey_relax_metrics(ranking, context) + template_html(context) - # Optional outputs - if settings.output_model_pkls: - rename_model_pkls(ranking, context) - if settings.output_model_plots: - plddt_pae_plots(ranking, context) - if settings.output_pae: - # Only created by monomer_ptm and multimer models - extract_pae_to_csv(ranking, context) - if settings.output_residue_scores: - write_per_residue_scores(ranking, context) - if settings.plot_msa: - plot_msa(settings.workdir) - if settings.collect_msas or settings.msa_only: - collect_msas(settings) + # Optional outputs + if settings.output_model_pkls: + rename_model_pkls(ranking, context) + if settings.output_model_plots: + plddt_pae_plots(ranking, context) + if settings.output_pae: + # Only created by monomer_ptm and multimer models + extract_pae_to_csv(ranking, context) + if settings.output_residue_scores: + write_per_residue_scores(ranking, context) + if settings.plot_msa: + plot_msa(context.settings.workdir) if __name__ == '__main__':