Mercurial repository: galaxy-australia / alphafold2
changeset 16:f9eb041c518c (draft)
planemo upload for repository https://github.com/usegalaxy-au/tools-au commit ee77734f1800350fa2a6ef28b2b8eade304a456f-dirty
author    galaxy-australia
date      Mon, 03 Apr 2023 01:00:42 +0000
parents   a58f7eb0df2c
children  5b85006245f3
files     README.rst alphafold.xml outputs.py scripts/outputs.py scripts/validate_fasta.py validate_fasta.py
diffstat  6 files changed, 608 insertions(+), 601 deletions(-)
--- a/README.rst Fri Mar 10 02:48:07 2023 +0000
+++ b/README.rst Mon Apr 03 01:00:42 2023 +0000
@@ -75,18 +75,20 @@
 ~~~~~~~~~~~~~~
 
 Alphafold needs reference data to run. The wrapper expects this data to
-be present at ``/data/alphafold_databases``. A custom path will be read from
-the ``ALPHAFOLD_DB`` environment variable, if set.
+be present at ``$ALPHAFOLD_DB/TOOL_MINOR_VERSION``, where ``ALPHAFOLD_DB``
+is a custom path read from the ``ALPHAFOLD_DB`` environment variable
+(defaulting to ``/data``) and ``TOOL_MINOR_VERSION`` is the AlphaFold
+version, e.g. ``2.3.1``.
 
 To download the AlphaFold reference DBs:
 
 ::
 
     # Set your AlphaFold DB path
-    ALPHAFOLD_DB=/data/alphafold_databases
+    ALPHAFOLD_DB=/data/alphafold_databases/2.3.1
 
     # Set your target AlphaFold version
-    ALPHAFOLD_VERSION=  # e.g. 2.1.2
+    ALPHAFOLD_VERSION=2.3.1
 
     # Download repo
     wget https://github.com/deepmind/alphafold/releases/tag/v${ALPHAFOLD_VERSION}.tar.gz
@@ -110,7 +112,7 @@
 
     # NOTE: this structure will change between minor AlphaFold versions
     # The tree shown below was updated for v2.3.1
 
-    data/alphafold_databases
+    data/alphafold_databases/2.3.1/
     ├── bfd
     │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffdata
    │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffindex
@@ -176,7 +178,7 @@
 
 ::
 
-    data/alphafold_databases
+    data/alphafold_databases/2.3.1/
     ├── small_bfd
     │   └── bfd-first_non_consensus_sequences.fasta
@@ -193,7 +195,7 @@
 If you wish to continue hosting prior versions of the tool, you must maintain
 the reference DBs for each version. The ``ALPHAFOLD_DB`` environment variable
 must then be set respectively for each tool version in your job conf (on Galaxy
-AU this is currently `configured with TPV<https://github.com/usegalaxy-au/infrastructure/blob/master/files/galaxy/dynamic_job_rules/production/total_perspective_vortex/tools.yml#L1515-L1554>`_).
+AU this is currently `configured with TPV <https://github.com/usegalaxy-au/infrastructure/blob/master/files/galaxy/dynamic_job_rules/production/total_perspective_vortex/tools.yml#L1515-L1554>`_).
 
 To minimize redundancy between DB versions, we have symlinked the database
 components that are unchanging between versions. In ``v2.1.2 -> v2.3.1`` the BFD
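
The runtime lookup may be easier to see as a shell sketch. This is illustrative
only: the ``2.3`` path component is the value of the wrapper's
``TOOL_MINOR_VERSION`` token (see alphafold.xml below), and the concrete
directories are hypothetical:

::

    # The wrapper appends the tool's minor version to $ALPHAFOLD_DB
    # (default /data), so with the v2.3.x tool it reads:
    DATA_DIR="${ALPHAFOLD_DB:-/data}/2.3"
    ls "$DATA_DIR"    # expected: bfd/ mgnify/ pdb_mmcif/ uniref90/ ...

    # Unchanged components can be shared between DB versions via symlinks,
    # as described above for v2.1.2 -> v2.3.1 (example only)
    ln -s ../2.1/bfd "$DATA_DIR/bfd"
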
--- a/alphafold.xml Fri Mar 10 02:48:07 2023 +0000
+++ b/alphafold.xml Mon Apr 03 01:00:42 2023 +0000
@@ -2,7 +2,8 @@
     <description> - AI-guided 3D structural prediction of proteins</description>
     <macros>
         <token name="@TOOL_VERSION@">2.3.1</token>
-        <token name="@VERSION_SUFFIX@">1</token>
+        <token name="@TOOL_MINOR_VERSION@">2.3</token>
+        <token name="@VERSION_SUFFIX@">2</token>
         <import>macro_output.xml</import>
         <import>macro_test_output.xml</import>
     </macros>
@@ -24,8 +25,11 @@
 ## in planemo's gx_venv_n/bin/activate script. AlphaFold outputs will be copied
 ## from the test-data directory instead of running the tool.
 
-## $ALPHAFOLD_DB variable should point to the location of the AlphaFold
-## databases - defaults to /data
+## $ALPHAFOLD_DB variable should point to the location containing the versioned
+## AlphaFold databases - defaults to /data.
+## That is, the directory should contain a subdirectory (or symlink) named
+## after the value of the TOOL_MINOR_VERSION token, which holds the AlphaFold
+## reference data for the corresponding version.
 
 ## Read FASTA input -----------------------------------------------------------
 #if $fasta_or_text.input_mode == 'history':
@@ -34,7 +38,7 @@
     echo '$fasta_or_text.fasta_text' > input.fasta
 #end if
 
-&& python3 '$__tool_directory__/validate_fasta.py' input.fasta
+&& python3 '$__tool_directory__/scripts/validate_fasta.py' input.fasta
 --min_length \${ALPHAFOLD_AA_LENGTH_MIN:-0}
 --max_length \${ALPHAFOLD_AA_LENGTH_MAX:-0}
 #if $model_preset == 'multimer':
@@ -51,26 +55,26 @@
 #if os.environ.get('PLANEMO_TESTING'):
     ## Run in testing mode (mocks a successful AlphaFold run by copying outputs)
     && echo "Creating dummy outputs for model_preset=$model_preset..."
-    && bash '$__tool_directory__/mock_alphafold.sh' $model_preset
+    && bash '$__tool_directory__/scripts/mock_alphafold.sh' $model_preset
 #else:
     ## Run AlphaFold
     && python /app/alphafold/run_alphafold.py
     --fasta_paths alphafold.fasta
     --output_dir output
-    --data_dir \${ALPHAFOLD_DB:-/data}
+    --data_dir \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/
     --model_preset=$model_preset
 
     ## Set reference database paths
-    --uniref90_database_path \${ALPHAFOLD_DB:-/data}/uniref90/uniref90.fasta
-    --mgnify_database_path \${ALPHAFOLD_DB:-/data}/mgnify/mgy_clusters_2022_05.fa
-    --template_mmcif_dir \${ALPHAFOLD_DB:-/data}/pdb_mmcif/mmcif_files
-    --obsolete_pdbs_path \${ALPHAFOLD_DB:-/data}/pdb_mmcif/obsolete.dat
+    --uniref90_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/uniref90/uniref90.fasta
+    --mgnify_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/mgnify/mgy_clusters_2022_05.fa
+    --template_mmcif_dir \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb_mmcif/mmcif_files
+    --obsolete_pdbs_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb_mmcif/obsolete.dat
     #if $dbs == 'full':
-    --bfd_database_path \${ALPHAFOLD_DB:-/data}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
-    --uniref30_database_path \${ALPHAFOLD_DB:-/data}/uniref30/UniRef30_2021_03
+    --bfd_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
+    --uniref30_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/uniref30/UniRef30_2021_03
     #else
     --db_preset=reduced_dbs
-    --small_bfd_database_path \${ALPHAFOLD_DB:-/data}/small_bfd/bfd-first_non_consensus_sequences.fasta
+    --small_bfd_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/small_bfd/bfd-first_non_consensus_sequences.fasta
     #end if
 
     #if $max_template_date:
@@ -82,16 +86,16 @@
     --use_gpu_relax=\${ALPHAFOLD_USE_GPU:-True}  ## introduced in v2.1.2
 
     #if $model_preset == 'multimer':
-    --pdb_seqres_database_path=\${ALPHAFOLD_DB:-/data}/pdb_seqres/pdb_seqres.txt
-    --uniprot_database_path=\${ALPHAFOLD_DB:-/data}/uniprot/uniprot.fasta
+    --pdb_seqres_database_path=\${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb_seqres/pdb_seqres.txt
+    --uniprot_database_path=\${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/uniprot/uniprot.fasta
     --num_multimer_predictions_per_model=1  ## introduced in v2.2.0
     #else
-    --pdb70_database_path \${ALPHAFOLD_DB:-/data}/pdb70/pdb70
+    --pdb70_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb70/pdb70
     #end if
 #end if
 
 ## Generate additional outputs ------------------------------------------------
-&& python3 '$__tool_directory__/outputs.py' output/alphafold
+&& python3 '$__tool_directory__/scripts/outputs.py' output/alphafold
 $outputs.plddts
 $outputs.model_pkls
 $outputs.pae_csv
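
For illustration, with ``ALPHAFOLD_DB`` unset the tokens above expand to the
default paths; the relevant part of the generated command then reads roughly as
follows (a sketch for the monomer preset with full DBs; preset-specific and
remaining flags omitted):

::

    python /app/alphafold/run_alphafold.py \
        --fasta_paths alphafold.fasta \
        --output_dir output \
        --data_dir /data/2.3/ \
        --model_preset=monomer \
        --uniref90_database_path /data/2.3/uniref90/uniref90.fasta \
        --mgnify_database_path /data/2.3/mgnify/mgy_clusters_2022_05.fa \
        --template_mmcif_dir /data/2.3/pdb_mmcif/mmcif_files \
        --obsolete_pdbs_path /data/2.3/pdb_mmcif/obsolete.dat
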
--- a/outputs.py Fri Mar 10 02:48:07 2023 +0000
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,323 +0,0 @@
(outputs.py deleted: the file moved to scripts/outputs.py. The 323 removed lines
duplicate the version of scripts/outputs.py added below, apart from the
matplotlib import, which previously sat among the standard-library imports.)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/scripts/outputs.py Mon Apr 03 01:00:42 2023 +0000
@@ -0,0 +1,324 @@
+"""Generate additional output files not produced by AlphaFold.
+
+Currently this includes:
+- model confidence scores
+- per-residue confidence scores (pLDDTs - optional output)
+- model_*.pkl files renamed with rank order
+
+N.B. There have been issues with this script breaking between AlphaFold
+versions due to minor changes in the output directory structure across minor
+versions. It will likely need updating with future releases of AlphaFold.
+
+This code is more complex than you might expect due to the output files
+'moving around' considerably, depending on run parameters. You will see that
+several output paths are determined dynamically.
+"""
+
+import argparse
+import json
+import os
+import pickle as pk
+import shutil
+from pathlib import Path
+from typing import List
+
+from matplotlib import pyplot as plt
+
+# Output file paths
+OUTPUT_DIR = 'extra'
+OUTPUTS = {
+    'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
+    'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
+    'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
+    'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
+    'plddts': OUTPUT_DIR + '/plddts.tsv',
+    'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
+}
+
+# Keys for accessing confidence data from JSON/pkl files
+# They change depending on whether the run was monomer or multimer
+PLDDT_KEY = {
+    'monomer': 'plddts',
+    'multimer': 'iptm+ptm',
+}
+
+
+class Settings:
+    """Parse and store settings/config."""
+    def __init__(self):
+        self.workdir = None
+        self.output_confidence_scores = True
+        self.output_residue_scores = False
+        self.is_multimer = False
+        self.parse()
+
+    def parse(self) -> None:
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "workdir",
+            help="alphafold output directory",
+            type=str
+        )
+        parser.add_argument(
+            "-p",
+            "--plddts",
+            help="output per-residue confidence scores (pLDDTs)",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-m",
+            "--multimer",
+            help="parse output from AlphaFold multimer",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--pkl",
+            help="rename model pkl outputs with rank order",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--pae",
+            help="extract PAE from pkl files to CSV format",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--plot",
+            help="Plot pLDDT and PAE for each model",
+            action="store_true"
+        )
+        args = parser.parse_args()
+        self.workdir = Path(args.workdir.rstrip('/'))
+        self.output_residue_scores = args.plddts
+        self.output_model_pkls = args.pkl
+        self.output_model_plots = args.plot
+        self.output_pae = args.pae
+        self.is_multimer = args.multimer
+        self.output_dir = self.workdir / OUTPUT_DIR
+        os.makedirs(self.output_dir, exist_ok=True)
+
+
+class ExecutionContext:
+    """Collect file paths etc."""
+    def __init__(self, settings: Settings):
+        self.settings = settings
+        if settings.is_multimer:
+            self.plddt_key = PLDDT_KEY['multimer']
+        else:
+            self.plddt_key = PLDDT_KEY['monomer']
+
+    def get_model_key(self, ix: int) -> str:
+        """Return json key for model index.
+
+        The key format changed between minor AlphaFold versions so this
+        function determines the correct key.
+        """
+        with open(self.ranking_debug) as f:
+            data = json.load(f)
+        model_keys = list(data[self.plddt_key].keys())
+        for k in model_keys:
+            if k.startswith(f"model_{ix}_"):
+                return k
+        raise KeyError(
+            f'Could not find key for index={ix} in'
+            ' ranking_debug.json')
+
+    @property
+    def ranking_debug(self) -> str:
+        return self.settings.workdir / 'ranking_debug.json'
+
+    @property
+    def relax_metrics(self) -> str:
+        return self.settings.workdir / 'relax_metrics.json'
+
+    @property
+    def relax_metrics_ranked(self) -> str:
+        return self.settings.workdir / 'relax_metrics_ranked.json'
+
+    @property
+    def model_pkl_paths(self) -> List[str]:
+        return sorted([
+            self.settings.workdir / f
+            for f in os.listdir(self.settings.workdir)
+            if f.startswith('result_model_') and f.endswith('.pkl')
+        ])
+
+
+class ResultModelPrediction:
+    """Load and manipulate data from result_model_*.pkl files."""
+    def __init__(self, path: str, context: ExecutionContext):
+        self.context = context
+        self.path = path
+        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
+        with open(path, 'rb') as f:
+            self.data = pk.load(f)
+
+    @property
+    def plddts(self) -> List[float]:
+        """Return pLDDT scores for each residue."""
+        return list(self.data['plddt'])
+
+
+class ResultRanking:
+    """Load and manipulate data from ranking_debug.json file."""
+
+    def __init__(self, context: ExecutionContext):
+        self.path = context.ranking_debug
+        self.context = context
+        with open(self.path, 'r') as f:
+            self.data = json.load(f)
+
+    @property
+    def order(self) -> List[str]:
+        """Return ordered list of model indexes."""
+        return self.data['order']
+
+    def get_plddt_for_rank(self, rank: int) -> float:
+        """Get pLDDT score for model instance."""
+        return self.data[self.context.plddt_key][self.data['order'][rank - 1]]
+
+    def get_rank_for_model(self, model_name: str) -> int:
+        """Return 0-indexed rank for given model name.
+
+        Model names are expressed in result_model_*.pkl file names.
+        """
+        return self.data['order'].index(model_name)
+
+
+def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
+    """Write per-model confidence scores."""
+    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
+    with open(path, 'w') as f:
+        for rank in range(1, 6):
+            score = ranking.get_plddt_for_rank(rank)
+            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
+
+
+def write_per_residue_scores(
+    ranking: ResultRanking,
+    context: ExecutionContext,
+):
+    """Write per-residue plddts for each model.
+
+    A row of plddt values is written for each model in tabular format.
+    """
+    model_plddts = {}
+    for i, path in enumerate(context.model_pkl_paths):
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        model_plddts[rank] = model.plddts
+
+    path = context.settings.workdir / OUTPUTS['plddts']
+    with open(path, 'w') as f:
+        for i in sorted(list(model_plddts.keys())):
+            row = [f'ranked_{i}'] + [
+                str(x) for x in model_plddts[i]
+            ]
+            f.write('\t'.join(row) + '\n')
+
+
+def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
+    """Rename model.pkl files so the rank order is implicit."""
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        new_path = (
+            context.settings.workdir
+            / OUTPUTS['model_pkl'].format(rank=rank)
+        )
+        shutil.copyfile(path, new_path)
+
+
+def extract_pae_to_csv(ranking: ResultRanking, context: ExecutionContext):
+    """Extract predicted alignment error matrix from pickle files.
+
+    Creates a CSV file for each of five ranked models.
+    """
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        with open(path, 'rb') as f:
+            data = pk.load(f)
+        if 'predicted_aligned_error' not in data:
+            print("Skipping PAE output"
+                  f" - not found in {path}."
+                  " Running with model_preset=monomer?")
+            return
+        pae = data['predicted_aligned_error']
+        out_path = (
+            context.settings.workdir
+            / OUTPUTS['model_pae'].format(rank=rank)
+        )
+        with open(out_path, 'w') as f:
+            for row in pae:
+                f.write(','.join([str(x) for x in row]) + '\n')
+
+
+def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
+    """Replace keys in relax_metrics.json with 0-indexed rank."""
+    with open(context.relax_metrics) as f:
+        data = json.load(f)
+    for k in list(data.keys()):
+        rank = ranking.get_rank_for_model(k)
+        data[f'ranked_{rank}'] = data.pop(k)
+    new_path = context.settings.workdir / OUTPUTS['relax']
+    with open(new_path, 'w') as f:
+        json.dump(data, f)
+
+
+def plddt_pae_plots(ranking: ResultRanking, context: ExecutionContext):
+    """Generate a pLDDT + PAE plot for each model."""
+    for path in context.model_pkl_paths:
+        num_plots = 2
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        png_path = (
+            context.settings.workdir
+            / OUTPUTS['model_plot'].format(rank=rank)
+        )
+        plddts = model.data['plddt']
+        if 'predicted_aligned_error' in model.data:
+            pae = model.data['predicted_aligned_error']
+            max_pae = model.data['max_predicted_aligned_error']
+        else:
+            num_plots = 1
+
+        plt.figure(figsize=[8 * num_plots, 6])
+        plt.subplot(1, num_plots, 1)
+        plt.plot(plddts)
+        plt.title('Predicted LDDT')
+        plt.xlabel('Residue')
+        plt.ylabel('pLDDT')
+
+        if num_plots == 2:
+            plt.subplot(1, 2, 2)
+            plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
+            plt.colorbar(fraction=0.046, pad=0.04)
+            plt.title('Predicted Aligned Error')
+            plt.xlabel('Scored residue')
+            plt.ylabel('Aligned residue')
+
+        plt.savefig(png_path)
+
+
+def main():
+    """Parse output files and generate additional output files."""
+    settings = Settings()
+    context = ExecutionContext(settings)
+    ranking = ResultRanking(context)
+    write_confidence_scores(ranking, context)
+    rekey_relax_metrics(ranking, context)
+
+    # Optional outputs
+    if settings.output_model_pkls:
+        rename_model_pkls(ranking, context)
+    if settings.output_model_plots:
+        plddt_pae_plots(ranking, context)
+    if settings.output_pae:
+        # Only created by monomer_ptm and multimer models
+        extract_pae_to_csv(ranking, context)
+    if settings.output_residue_scores:
+        write_per_residue_scores(ranking, context)
+
+
+if __name__ == '__main__':
+    main()
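
As a usage sketch, the relocated script can also be run by hand against a
finished run directory (the flags below are the ones defined by its argparse
parser; the ``output/alphafold`` directory name follows the wrapper's layout):

::

    # Generate all optional outputs for a multimer run
    python3 scripts/outputs.py output/alphafold --multimer --plddts --pkl --pae --plot

    # Results land in the 'extra' subdirectory, e.g.:
    ls output/alphafold/extra
    # model_confidence_scores.tsv  plddts.tsv  relax_metrics_ranked.json
    # ranked_0.pkl  pae_ranked_0.csv  ranked_0.png  ...
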
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/scripts/validate_fasta.py Mon Apr 03 01:00:42 2023 +0000
@@ -0,0 +1,254 @@
+"""Validate input FASTA sequence."""
+
+import argparse
+import re
+import sys
+from typing import List
+
+MULTIMER_MAX_SEQUENCE_COUNT = 10
+
+
+class Fasta:
+    def __init__(self, header_str: str, seq_str: str):
+        self.header = header_str
+        self.aa_seq = seq_str
+
+
+class FastaLoader:
+    def __init__(self, fasta_path: str):
+        """Initialize from FASTA file."""
+        self.fastas = []
+        self.load(fasta_path)
+
+    def load(self, fasta_path: str):
+        """Load bare or FASTA formatted sequence."""
+        with open(fasta_path, 'r') as f:
+            self.content = f.read()
+
+        if "__cn__" in self.content:
+            # Pasted content with escaped characters
+            self.newline = '__cn__'
+            self.read_caret = '__gt__'
+        else:
+            # Uploaded file with normal content
+            self.newline = '\n'
+            self.read_caret = '>'
+
+        self.lines = self.content.split(self.newline)
+
+        if not self.lines[0].startswith(self.read_caret):
+            # Fasta is headless, load as single sequence
+            self.update_fastas(
+                '', ''.join(self.lines)
+            )
+
+        else:
+            header = None
+            sequence = None
+            for line in self.lines:
+                if line.startswith(self.read_caret):
+                    if header:
+                        self.update_fastas(header, sequence)
+                    header = '>' + self.strip_header(line)
+                    sequence = ''
+                else:
+                    sequence += line.strip('\n ')
+            self.update_fastas(header, sequence)
+
+    def strip_header(self, line):
+        """Strip characters escaped with underscores from pasted text."""
+        return re.sub(r'\_\_.{2}\_\_', '', line).strip('>')
+
+    def update_fastas(self, header: str, sequence: str):
+        # if we have a sequence
+        if sequence:
+            # create a generic header if none exists
+            if not header:
+                fasta_count = len(self.fastas)
+                header = f'>sequence_{fasta_count}'
+
+            # Create new Fasta
+            self.fastas.append(Fasta(header, sequence))
+
+
+class FastaValidator:
+    def __init__(
+            self,
+            min_length=None,
+            max_length=None,
+            multiple=False):
+        self.multiple = multiple
+        self.min_length = min_length
+        self.max_length = max_length
+        self.iupac_characters = {
+            'A', 'B', 'C', 'D', 'E', 'F', 'G',
+            'H', 'I', 'K', 'L', 'M', 'N', 'P',
+            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
+            'Y', 'Z', '-'
+        }
+
+    def validate(self, fasta_list: List[Fasta]):
+        """Perform FASTA validation."""
+        self.fasta_list = fasta_list
+        self.validate_num_seqs()
+        self.validate_length()
+        self.validate_alphabet()
+        # not checking for 'X' residues at the moment.
+        # alphafold can throw an error if it doesn't like it.
+        # self.validate_x()
+        return self.fasta_list
+
+    def validate_num_seqs(self) -> None:
+        """Validate the number of input sequences."""
+        fasta_count = len(self.fasta_list)
+
+        if self.multiple:
+            if fasta_count < 2:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n'
+                    'Multimer mode requires multiple input sequences.'
+                    f' Only {fasta_count} sequences were detected in'
+                    ' the provided file.')
+
+            elif fasta_count > MULTIMER_MAX_SEQUENCE_COUNT:
+                sys.stderr.write(
+                    f'WARNING: detected {fasta_count} sequences but the'
+                    f' maximum allowed is {MULTIMER_MAX_SEQUENCE_COUNT}'
+                    ' sequences. The last'
+                    f' {fasta_count - MULTIMER_MAX_SEQUENCE_COUNT} sequence(s)'
+                    ' have been discarded.\n')
+                self.fasta_list = self.fasta_list[:MULTIMER_MAX_SEQUENCE_COUNT]
+        else:
+            if fasta_count > 1:
+                sys.stderr.write(
+                    'WARNING: More than 1 sequence detected.'
+                    ' Using first FASTA sequence as input.\n')
+                self.fasta_list = self.fasta_list[:1]
+
+            elif len(self.fasta_list) == 0:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n'
+                    ' no FASTA sequences detected in input file.')
+
+    def validate_length(self):
+        """Confirm whether sequence length is valid."""
+        fasta = self.fasta_list[0]
+        if self.min_length:
+            if len(fasta.aa_seq) < self.min_length:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n Sequence too short'
+                    f' ({len(fasta.aa_seq)}AA).'
+                    f' Minimum length is {self.min_length}AA.')
+        if self.max_length:
+            if len(fasta.aa_seq) > self.max_length:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n'
+                    f' Sequence too long ({len(fasta.aa_seq)}AA).'
+                    f' Maximum length is {self.max_length}AA.')
+
+    def validate_alphabet(self):
+        """Confirm whether the sequence conforms to IUPAC codes.
+
+        If not, report the offending character and its position.
+        """
+        fasta = self.fasta_list[0]
+        for i, char in enumerate(fasta.aa_seq.upper()):
+            if char not in self.iupac_characters:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n Invalid amino acid'
+                    f' found at pos {i}: "{char}"')
+
+    def validate_x(self):
+        """Check for X residues."""
+        fasta = self.fasta_list[0]
+        for i, char in enumerate(fasta.aa_seq.upper()):
+            if char == 'X':
+                raise ValueError(
+                    'Error encountered validating FASTA:\n Unsupported AA code'
+                    f' "X" found at pos {i}')
+
+
+class FastaWriter:
+    def __init__(self) -> None:
+        self.line_wrap = 60
+
+    def write(self, fasta: Fasta):
+        header = fasta.header
+        seq = self.format_sequence(fasta.aa_seq)
+        sys.stdout.write(header + '\n')
+        sys.stdout.write(seq)
+
+    def format_sequence(self, aa_seq: str):
+        formatted_seq = ''
+        for i in range(0, len(aa_seq), self.line_wrap):
+            formatted_seq += aa_seq[i: i + self.line_wrap] + '\n'
+        return formatted_seq.upper()
+
+
+def main():
+    # load fasta file
+    try:
+        args = parse_args()
+        fas = FastaLoader(args.input)
+
+        # validate
+        fv = FastaValidator(
+            min_length=args.min_length,
+            max_length=args.max_length,
+            multiple=args.multimer,
+        )
+        clean_fastas = fv.validate(fas.fastas)
+
+        # write clean data
+        fw = FastaWriter()
+        for fas in clean_fastas:
+            fw.write(fas)
+
+        sys.stderr.write("Validated FASTA sequence(s):\n\n")
+        for fas in clean_fastas:
+            sys.stderr.write(fas.header + '\n')
+            sys.stderr.write(fas.aa_seq + '\n\n')
+
+    except ValueError as exc:
+        sys.stderr.write(f"{exc}\n\n")
+        raise exc
+
+    except Exception as exc:
+        sys.stderr.write(
+            "Input error: FASTA input is invalid. Please check your input.\n\n"
+        )
+        raise exc
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "input",
+        help="input fasta file",
+        type=str
+    )
+    parser.add_argument(
+        "--min_length",
+        dest='min_length',
+        help="Minimum length of input protein sequence (AA)",
+        default=None,
+        type=int,
+    )
+    parser.add_argument(
+        "--max_length",
+        dest='max_length',
+        help="Maximum length of input protein sequence (AA)",
+        default=None,
+        type=int,
+    )
+    parser.add_argument(
+        "--multimer",
+        action='store_true',
+        help="Require multiple input sequences",
+    )
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    main()
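
A comparable sketch for the validator (the length limits here are arbitrary
examples; the wrapper supplies them from ``$ALPHAFOLD_AA_LENGTH_MIN`` and
``$ALPHAFOLD_AA_LENGTH_MAX``):

::

    # Validate an uploaded FASTA file; the cleaned copy goes to stdout
    python3 scripts/validate_fasta.py input.fasta \
        --min_length 30 --max_length 2000 > alphafold.fasta

    # Text pasted into Galaxy arrives with escaped characters (__gt__ for
    # '>', __cn__ for newline); the loader detects and unescapes these
    printf '__gt__seq1__cn__MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ' > pasted.fasta
    python3 scripts/validate_fasta.py pasted.fasta > clean.fasta
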
--- a/validate_fasta.py Fri Mar 10 02:48:07 2023 +0000
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,254 +0,0 @@
(validate_fasta.py deleted: the file moved verbatim to scripts/validate_fasta.py;
the 254 removed lines duplicate the version added above.)