changeset 16:f9eb041c518c draft

planemo upload for repository https://github.com/usegalaxy-au/tools-au commit ee77734f1800350fa2a6ef28b2b8eade304a456f-dirty
author galaxy-australia
date Mon, 03 Apr 2023 01:00:42 +0000
parents a58f7eb0df2c
children 5b85006245f3
files README.rst alphafold.xml outputs.py scripts/outputs.py scripts/validate_fasta.py validate_fasta.py
diffstat 6 files changed, 608 insertions(+), 601 deletions(-) [+]
line wrap: on
line diff
--- a/README.rst	Fri Mar 10 02:48:07 2023 +0000
+++ b/README.rst	Mon Apr 03 01:00:42 2023 +0000
@@ -75,18 +75,20 @@
 ~~~~~~~~~~~~~~
 
 Alphafold needs reference data to run. The wrapper expects this data to
-be present at ``/data/alphafold_databases``. A custom path will be read from
-the ``ALPHAFOLD_DB`` environment variable, if set.
+be present at ``$ALPHAFOLD_DB/TOOL_MINOR_VERSION``, where ``ALPHAFOLD_DB``
+is a custom path read from the ``ALPHAFOLD_DB`` environment variable
+(defaulting to ``/data``) and ``TOOL_MINOR_VERSION`` is the AlphaFold
+version in use, e.g. ``2.3.1``.
 
 To download the AlphaFold reference DBs:
 
 ::
 
    # Set your AlphaFold DB path
-   ALPHAFOLD_DB=/data/alphafold_databases
+   ALPHAFOLD_DB=/data/alphafold_databases/2.3.1
 
    # Set your target AlphaFold version
-   ALPHAFOLD_VERSION=  # e.g. 2.1.2
+   ALPHAFOLD_VERSION=2.3.1
 
    # Download repo
    wget https://github.com/deepmind/alphafold/archive/refs/tags/v${ALPHAFOLD_VERSION}.tar.gz
@@ -110,7 +112,7 @@
    # NOTE: this structure will change between minor AlphaFold versions
    # The tree shown below was updated for v2.3.1
 
-   data/alphafold_databases
+   data/alphafold_databases/2.3.1/
    ├── bfd
    │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffdata
    │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffindex
@@ -176,7 +178,7 @@
 
 ::
 
-   data/alphafold_databases
+   data/alphafold_databases/2.3.1/
    ├── small_bfd
    │   └── bfd-first_non_consensus_sequences.fasta
 
@@ -193,7 +195,7 @@
 If you wish to continue hosting prior versions of the tool, you must maintain
 the reference DBs for each version. The ``ALPHAFOLD_DB`` environment variable
 must then be set respectively for each tool version in your job conf (on Galaxy
-AU this is currently `configured with TPV<https://github.com/usegalaxy-au/infrastructure/blob/master/files/galaxy/dynamic_job_rules/production/total_perspective_vortex/tools.yml#L1515-L1554>`_).
+AU this is currently `configured with TPV <https://github.com/usegalaxy-au/infrastructure/blob/master/files/galaxy/dynamic_job_rules/production/total_perspective_vortex/tools.yml#L1515-L1554>`_).
 
 To minimize redundancy between DB versions, we have symlinked the database
 components that are unchanging between versions. In ``v2.1.2 -> v2.3.1`` the BFD
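For illustration, a minimal sketch of sharing one unchanged component between
DB versions via a symlink (paths are hypothetical; check which components are
actually identical between your versions before linking):

    # Hypothetical: reuse the BFD from an existing v2.1.2 tree for v2.3.1
    mkdir -p /data/alphafold_databases/2.3.1
    ln -s /data/alphafold_databases/2.1.2/bfd /data/alphafold_databases/2.3.1/bfd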
--- a/alphafold.xml	Fri Mar 10 02:48:07 2023 +0000
+++ b/alphafold.xml	Mon Apr 03 01:00:42 2023 +0000
@@ -2,7 +2,8 @@
     <description> - AI-guided 3D structural prediction of proteins</description>
     <macros>
       <token name="@TOOL_VERSION@">2.3.1</token>
-      <token name="@VERSION_SUFFIX@">1</token>
+      <token name="@TOOL_MINOR_VERSION@">2.3</token>
+      <token name="@VERSION_SUFFIX@">2</token>
       <import>macro_output.xml</import>
       <import>macro_test_output.xml</import>
     </macros>
@@ -24,8 +25,11 @@
 ## in planemo's gx_venv_n/bin/activate script. AlphaFold outputs will be copied
 ## from the test-data directory instead of running the tool.
 
-## $ALPHAFOLD_DB variable should point to the location of the AlphaFold
-## databases - defaults to /data
+## The $ALPHAFOLD_DB variable should point to the location containing the
+## versioned AlphaFold databases - defaults to /data. That is, the directory
+## should contain a subdirectory (or symlink) named identically to the value
+## of the TOOL_MINOR_VERSION token, holding the AlphaFold reference data for
+## the corresponding version.
 
 ## Read FASTA input -----------------------------------------------------------
 #if $fasta_or_text.input_mode == 'history':
@@ -34,7 +38,7 @@
     echo '$fasta_or_text.fasta_text' > input.fasta
 #end if
 
-&& python3 '$__tool_directory__/validate_fasta.py' input.fasta
+&& python3 '$__tool_directory__/scripts/validate_fasta.py' input.fasta
 --min_length \${ALPHAFOLD_AA_LENGTH_MIN:-0}
 --max_length \${ALPHAFOLD_AA_LENGTH_MAX:-0}
 #if $model_preset == 'multimer':
@@ -51,26 +55,26 @@
 #if os.environ.get('PLANEMO_TESTING'):
     ## Run in testing mode (mocks a successful AlphaFold run by copying outputs)
     && echo "Creating dummy outputs for model_preset=$model_preset..."
-    && bash '$__tool_directory__/mock_alphafold.sh' $model_preset
+    && bash '$__tool_directory__/scripts/mock_alphafold.sh' $model_preset
 #else:
     ## Run AlphaFold
     && python /app/alphafold/run_alphafold.py
         --fasta_paths alphafold.fasta
         --output_dir output
-        --data_dir \${ALPHAFOLD_DB:-/data}
+        --data_dir \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/
         --model_preset=$model_preset
 
         ## Set reference database paths
-        --uniref90_database_path   \${ALPHAFOLD_DB:-/data}/uniref90/uniref90.fasta
-        --mgnify_database_path     \${ALPHAFOLD_DB:-/data}/mgnify/mgy_clusters_2022_05.fa
-        --template_mmcif_dir       \${ALPHAFOLD_DB:-/data}/pdb_mmcif/mmcif_files
-        --obsolete_pdbs_path       \${ALPHAFOLD_DB:-/data}/pdb_mmcif/obsolete.dat
+        --uniref90_database_path   \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/uniref90/uniref90.fasta
+        --mgnify_database_path     \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/mgnify/mgy_clusters_2022_05.fa
+        --template_mmcif_dir       \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb_mmcif/mmcif_files
+        --obsolete_pdbs_path       \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb_mmcif/obsolete.dat
         #if $dbs == 'full':
-        --bfd_database_path        \${ALPHAFOLD_DB:-/data}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
-        --uniref30_database_path   \${ALPHAFOLD_DB:-/data}/uniref30/UniRef30_2021_03
+        --bfd_database_path        \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
+        --uniref30_database_path   \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/uniref30/UniRef30_2021_03
         #else
         --db_preset=reduced_dbs
-        --small_bfd_database_path  \${ALPHAFOLD_DB:-/data}/small_bfd/bfd-first_non_consensus_sequences.fasta
+        --small_bfd_database_path  \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/small_bfd/bfd-first_non_consensus_sequences.fasta
         #end if
 
         #if $max_template_date:
@@ -82,16 +86,16 @@
         --use_gpu_relax=\${ALPHAFOLD_USE_GPU:-True}  ## introduced in v2.1.2
 
         #if $model_preset == 'multimer':
-        --pdb_seqres_database_path=\${ALPHAFOLD_DB:-/data}/pdb_seqres/pdb_seqres.txt
-        --uniprot_database_path=\${ALPHAFOLD_DB:-/data}/uniprot/uniprot.fasta
+        --pdb_seqres_database_path=\${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb_seqres/pdb_seqres.txt
+        --uniprot_database_path=\${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/uniprot/uniprot.fasta
         --num_multimer_predictions_per_model=1  ## introduced in v2.2.0
         #else
-        --pdb70_database_path \${ALPHAFOLD_DB:-/data}/pdb70/pdb70
+        --pdb70_database_path \${ALPHAFOLD_DB:-/data}/@TOOL_MINOR_VERSION@/pdb70/pdb70
         #end if
 #end if
 
 ## Generate additional outputs ------------------------------------------------
-&& python3 '$__tool_directory__/outputs.py' output/alphafold
+&& python3 '$__tool_directory__/scripts/outputs.py' output/alphafold
 $outputs.plddts
 $outputs.model_pkls
 $outputs.pae_csv
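Taken together, the path changes above resolve the database root at runtime.
A shell sketch of the resulting invocation (the 2.3 component comes from the
@TOOL_MINOR_VERSION@ token; the database flags are abbreviated):

    # ALPHAFOLD_DB falls back to /data when unset; the minor version is appended
    DATA_DIR="${ALPHAFOLD_DB:-/data}/2.3/"
    python /app/alphafold/run_alphafold.py \
        --fasta_paths alphafold.fasta \
        --output_dir output \
        --data_dir "$DATA_DIR" \
        --uniref90_database_path "${DATA_DIR}uniref90/uniref90.fasta"
    # ...the remaining --*_database_path flags are resolved against $DATA_DIR
    # in the same way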
--- a/outputs.py	Fri Mar 10 02:48:07 2023 +0000
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,323 +0,0 @@
-"""Generate additional output files not produced by AlphaFold.
-
-Currently this is includes:
-- model confidence scores
-- per-residue confidence scores (pLDDTs - optional output)
-- model_*.pkl files renamed with rank order
-
-N.B. There have been issues with this script breaking between AlphaFold
-versions due to minor changes in the output directory structure across minor
-versions. It will likely need updating with future releases of AlphaFold.
-
-This code is more complex than you might expect due to the output files
-'moving around' considerably, depending on run parameters. You will see that
-several output paths are determined dynamically.
-"""
-
-import argparse
-import json
-import os
-import pickle as pk
-import shutil
-from matplotlib import pyplot as plt
-from pathlib import Path
-from typing import List
-
-# Output file paths
-OUTPUT_DIR = 'extra'
-OUTPUTS = {
-    'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
-    'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
-    'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
-    'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
-    'plddts': OUTPUT_DIR + '/plddts.tsv',
-    'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
-}
-
-# Keys for accessing confidence data from JSON/pkl files
-# They change depending on whether the run was monomer or multimer
-PLDDT_KEY = {
-    'monomer': 'plddts',
-    'multimer': 'iptm+ptm',
-}
-
-
-class Settings:
-    """Parse and store settings/config."""
-    def __init__(self):
-        self.workdir = None
-        self.output_confidence_scores = True
-        self.output_residue_scores = False
-        self.is_multimer = False
-        self.parse()
-
-    def parse(self) -> None:
-        parser = argparse.ArgumentParser()
-        parser.add_argument(
-            "workdir",
-            help="alphafold output directory",
-            type=str
-        )
-        parser.add_argument(
-            "-p",
-            "--plddts",
-            help="output per-residue confidence scores (pLDDTs)",
-            action="store_true"
-        )
-        parser.add_argument(
-            "-m",
-            "--multimer",
-            help="parse output from AlphaFold multimer",
-            action="store_true"
-        )
-        parser.add_argument(
-            "--pkl",
-            help="rename model pkl outputs with rank order",
-            action="store_true"
-        )
-        parser.add_argument(
-            "--pae",
-            help="extract PAE from pkl files to CSV format",
-            action="store_true"
-        )
-        parser.add_argument(
-            "--plot",
-            help="Plot pLDDT and PAE for each model",
-            action="store_true"
-        )
-        args = parser.parse_args()
-        self.workdir = Path(args.workdir.rstrip('/'))
-        self.output_residue_scores = args.plddts
-        self.output_model_pkls = args.pkl
-        self.output_model_plots = args.plot
-        self.output_pae = args.pae
-        self.is_multimer = args.multimer
-        self.output_dir = self.workdir / OUTPUT_DIR
-        os.makedirs(self.output_dir, exist_ok=True)
-
-
-class ExecutionContext:
-    """Collect file paths etc."""
-    def __init__(self, settings: Settings):
-        self.settings = settings
-        if settings.is_multimer:
-            self.plddt_key = PLDDT_KEY['multimer']
-        else:
-            self.plddt_key = PLDDT_KEY['monomer']
-
-    def get_model_key(self, ix: int) -> str:
-        """Return json key for model index.
-
-        The key format changed between minor AlphaFold versions so this
-        function determines the correct key.
-        """
-        with open(self.ranking_debug) as f:
-            data = json.load(f)
-        model_keys = list(data[self.plddt_key].keys())
-        for k in model_keys:
-            if k.startswith(f"model_{ix}_"):
-                return k
-        return KeyError(
-            f'Could not find key for index={ix} in'
-            ' ranking_debug.json')
-
-    @property
-    def ranking_debug(self) -> str:
-        return self.settings.workdir / 'ranking_debug.json'
-
-    @property
-    def relax_metrics(self) -> str:
-        return self.settings.workdir / 'relax_metrics.json'
-
-    @property
-    def relax_metrics_ranked(self) -> str:
-        return self.settings.workdir / 'relax_metrics_ranked.json'
-
-    @property
-    def model_pkl_paths(self) -> List[str]:
-        return sorted([
-            self.settings.workdir / f
-            for f in os.listdir(self.settings.workdir)
-            if f.startswith('result_model_') and f.endswith('.pkl')
-        ])
-
-
-class ResultModelPrediction:
-    """Load and manipulate data from result_model_*.pkl files."""
-    def __init__(self, path: str, context: ExecutionContext):
-        self.context = context
-        self.path = path
-        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
-        with open(path, 'rb') as path:
-            self.data = pk.load(path)
-
-    @property
-    def plddts(self) -> List[float]:
-        """Return pLDDT scores for each residue."""
-        return list(self.data['plddt'])
-
-
-class ResultRanking:
-    """Load and manipulate data from ranking_debug.json file."""
-
-    def __init__(self, context: ExecutionContext):
-        self.path = context.ranking_debug
-        self.context = context
-        with open(self.path, 'r') as f:
-            self.data = json.load(f)
-
-    @property
-    def order(self) -> List[str]:
-        """Return ordered list of model indexes."""
-        return self.data['order']
-
-    def get_plddt_for_rank(self, rank: int) -> List[float]:
-        """Get pLDDT score for model instance."""
-        return self.data[self.context.plddt_key][self.data['order'][rank - 1]]
-
-    def get_rank_for_model(self, model_name: str) -> int:
-        """Return 0-indexed rank for given model name.
-
-        Model names are expressed in result_model_*.pkl file names.
-        """
-        return self.data['order'].index(model_name)
-
-
-def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
-    """Write per-model confidence scores."""
-    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
-    with open(path, 'w') as f:
-        for rank in range(1, 6):
-            score = ranking.get_plddt_for_rank(rank)
-            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
-
-
-def write_per_residue_scores(
-    ranking: ResultRanking,
-    context: ExecutionContext,
-):
-    """Write per-residue plddts for each model.
-
-    A row of plddt values is written for each model in tabular format.
-    """
-    model_plddts = {}
-    for i, path in enumerate(context.model_pkl_paths):
-        model = ResultModelPrediction(path, context)
-        rank = ranking.get_rank_for_model(model.name)
-        model_plddts[rank] = model.plddts
-
-    path = context.settings.workdir / OUTPUTS['plddts']
-    with open(path, 'w') as f:
-        for i in sorted(list(model_plddts.keys())):
-            row = [f'ranked_{i}'] + [
-                str(x) for x in model_plddts[i]
-            ]
-            f.write('\t'.join(row) + '\n')
-
-
-def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
-    """Rename model.pkl files so the rank order is implicit."""
-    for path in context.model_pkl_paths:
-        model = ResultModelPrediction(path, context)
-        rank = ranking.get_rank_for_model(model.name)
-        new_path = (
-            context.settings.workdir
-            / OUTPUTS['model_pkl'].format(rank=rank)
-        )
-        shutil.copyfile(path, new_path)
-
-
-def extract_pae_to_csv(ranking: ResultRanking, context: ExecutionContext):
-    """Extract predicted alignment error matrix from pickle files.
-
-    Creates a CSV file for each of five ranked models.
-    """
-    for path in context.model_pkl_paths:
-        model = ResultModelPrediction(path, context)
-        rank = ranking.get_rank_for_model(model.name)
-        with open(path, 'rb') as f:
-            data = pk.load(f)
-        if 'predicted_aligned_error' not in data:
-            print("Skipping PAE output"
-                  f" - not found in {path}."
-                  " Running with model_preset=monomer?")
-            return
-        pae = data['predicted_aligned_error']
-        out_path = (
-            context.settings.workdir
-            / OUTPUTS['model_pae'].format(rank=rank)
-        )
-        with open(out_path, 'w') as f:
-            for row in pae:
-                f.write(','.join([str(x) for x in row]) + '\n')
-
-
-def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
-    """Replace keys in relax_metrics.json with 0-indexed rank."""
-    with open(context.relax_metrics) as f:
-        data = json.load(f)
-        for k in list(data.keys()):
-            rank = ranking.get_rank_for_model(k)
-            data[f'ranked_{rank}'] = data.pop(k)
-    new_path = context.settings.workdir / OUTPUTS['relax']
-    with open(new_path, 'w') as f:
-        json.dump(data, f)
-
-
-def plddt_pae_plots(ranking: ResultRanking, context: ExecutionContext):
-    """Generate a pLDDT + PAE plot for each model."""
-    for path in context.model_pkl_paths:
-        num_plots = 2
-        model = ResultModelPrediction(path, context)
-        rank = ranking.get_rank_for_model(model.name)
-        png_path = (
-            context.settings.workdir
-            / OUTPUTS['model_plot'].format(rank=rank)
-        )
-        plddts = model.data['plddt']
-        if 'predicted_aligned_error' in model.data:
-            pae = model.data['predicted_aligned_error']
-            max_pae = model.data['max_predicted_aligned_error']
-        else:
-            num_plots = 1
-
-        plt.figure(figsize=[8 * num_plots, 6])
-        plt.subplot(1, num_plots, 1)
-        plt.plot(plddts)
-        plt.title('Predicted LDDT')
-        plt.xlabel('Residue')
-        plt.ylabel('pLDDT')
-
-        if num_plots == 2:
-            plt.subplot(1, 2, 2)
-            plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
-            plt.colorbar(fraction=0.046, pad=0.04)
-            plt.title('Predicted Aligned Error')
-            plt.xlabel('Scored residue')
-            plt.ylabel('Aligned residue')
-
-        plt.savefig(png_path)
-
-
-def main():
-    """Parse output files and generate additional output files."""
-    settings = Settings()
-    context = ExecutionContext(settings)
-    ranking = ResultRanking(context)
-    write_confidence_scores(ranking, context)
-    rekey_relax_metrics(ranking, context)
-
-    # Optional outputs
-    if settings.output_model_pkls:
-        rename_model_pkls(ranking, context)
-    if settings.output_model_plots:
-        plddt_pae_plots(ranking, context)
-    if settings.output_pae:
-        # Only created by monomer_ptm and multimer models
-        extract_pae_to_csv(ranking, context)
-    if settings.output_residue_scores:
-        write_per_residue_scores(ranking, context)
-
-
-if __name__ == '__main__':
-    main()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/scripts/outputs.py	Mon Apr 03 01:00:42 2023 +0000
@@ -0,0 +1,324 @@
+"""Generate additional output files not produced by AlphaFold.
+
+Currently this includes:
+- model confidence scores
+- per-residue confidence scores (pLDDTs - optional output)
+- model_*.pkl files renamed with rank order
+
+N.B. There have been issues with this script breaking between AlphaFold
+versions due to minor changes in the output directory structure across minor
+versions. It will likely need updating with future releases of AlphaFold.
+
+This code is more complex than you might expect due to the output files
+'moving around' considerably, depending on run parameters. You will see that
+several output paths are determined dynamically.
+"""
+
+import argparse
+import json
+import os
+import pickle as pk
+import shutil
+from pathlib import Path
+from typing import List
+
+from matplotlib import pyplot as plt
+
+# Output file paths
+OUTPUT_DIR = 'extra'
+OUTPUTS = {
+    'model_pkl': OUTPUT_DIR + '/ranked_{rank}.pkl',
+    'model_pae': OUTPUT_DIR + '/pae_ranked_{rank}.csv',
+    'model_plot': OUTPUT_DIR + '/ranked_{rank}.png',
+    'model_confidence_scores': OUTPUT_DIR + '/model_confidence_scores.tsv',
+    'plddts': OUTPUT_DIR + '/plddts.tsv',
+    'relax': OUTPUT_DIR + '/relax_metrics_ranked.json',
+}
+
+# Keys for accessing confidence data from JSON/pkl files
+# They change depending on whether the run was monomer or multimer
+PLDDT_KEY = {
+    'monomer': 'plddts',
+    'multimer': 'iptm+ptm',
+}
+
+
+class Settings:
+    """Parse and store settings/config."""
+    def __init__(self):
+        self.workdir = None
+        self.output_confidence_scores = True
+        self.output_residue_scores = False
+        self.is_multimer = False
+        self.parse()
+
+    def parse(self) -> None:
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "workdir",
+            help="alphafold output directory",
+            type=str
+        )
+        parser.add_argument(
+            "-p",
+            "--plddts",
+            help="output per-residue confidence scores (pLDDTs)",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-m",
+            "--multimer",
+            help="parse output from AlphaFold multimer",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--pkl",
+            help="rename model pkl outputs with rank order",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--pae",
+            help="extract PAE from pkl files to CSV format",
+            action="store_true"
+        )
+        parser.add_argument(
+            "--plot",
+            help="Plot pLDDT and PAE for each model",
+            action="store_true"
+        )
+        args = parser.parse_args()
+        self.workdir = Path(args.workdir.rstrip('/'))
+        self.output_residue_scores = args.plddts
+        self.output_model_pkls = args.pkl
+        self.output_model_plots = args.plot
+        self.output_pae = args.pae
+        self.is_multimer = args.multimer
+        self.output_dir = self.workdir / OUTPUT_DIR
+        os.makedirs(self.output_dir, exist_ok=True)
+
+
+class ExecutionContext:
+    """Collect file paths etc."""
+    def __init__(self, settings: Settings):
+        self.settings = settings
+        if settings.is_multimer:
+            self.plddt_key = PLDDT_KEY['multimer']
+        else:
+            self.plddt_key = PLDDT_KEY['monomer']
+
+    def get_model_key(self, ix: int) -> str:
+        """Return json key for model index.
+
+        The key format changed between minor AlphaFold versions so this
+        function determines the correct key.
+        """
+        with open(self.ranking_debug) as f:
+            data = json.load(f)
+        model_keys = list(data[self.plddt_key].keys())
+        for k in model_keys:
+            if k.startswith(f"model_{ix}_"):
+                return k
+        raise KeyError(
+            f'Could not find key for index={ix} in'
+            ' ranking_debug.json')
+
+    @property
+    def ranking_debug(self) -> str:
+        return self.settings.workdir / 'ranking_debug.json'
+
+    @property
+    def relax_metrics(self) -> str:
+        return self.settings.workdir / 'relax_metrics.json'
+
+    @property
+    def relax_metrics_ranked(self) -> str:
+        return self.settings.workdir / 'relax_metrics_ranked.json'
+
+    @property
+    def model_pkl_paths(self) -> List[str]:
+        return sorted([
+            self.settings.workdir / f
+            for f in os.listdir(self.settings.workdir)
+            if f.startswith('result_model_') and f.endswith('.pkl')
+        ])
+
+
+class ResultModelPrediction:
+    """Load and manipulate data from result_model_*.pkl files."""
+    def __init__(self, path: str, context: ExecutionContext):
+        self.context = context
+        self.path = path
+        self.name = os.path.basename(path).replace('result_', '').split('.')[0]
+        with open(path, 'rb') as f:
+            self.data = pk.load(f)
+
+    @property
+    def plddts(self) -> List[float]:
+        """Return pLDDT scores for each residue."""
+        return list(self.data['plddt'])
+
+
+class ResultRanking:
+    """Load and manipulate data from ranking_debug.json file."""
+
+    def __init__(self, context: ExecutionContext):
+        self.path = context.ranking_debug
+        self.context = context
+        with open(self.path, 'r') as f:
+            self.data = json.load(f)
+
+    @property
+    def order(self) -> List[str]:
+        """Return ordered list of model indexes."""
+        return self.data['order']
+
+    def get_plddt_for_rank(self, rank: int) -> float:
+        """Return the confidence score for the model at the given rank."""
+        return self.data[self.context.plddt_key][self.data['order'][rank - 1]]
+
+    def get_rank_for_model(self, model_name: str) -> int:
+        """Return 0-indexed rank for given model name.
+
+        Model names are expressed in result_model_*.pkl file names.
+        """
+        return self.data['order'].index(model_name)
+
+
+def write_confidence_scores(ranking: ResultRanking, context: ExecutionContext):
+    """Write per-model confidence scores."""
+    path = context.settings.workdir / OUTPUTS['model_confidence_scores']
+    with open(path, 'w') as f:
+        for rank in range(1, 6):
+            score = ranking.get_plddt_for_rank(rank)
+            f.write(f'ranked_{rank - 1}\t{score:.2f}\n')
+
+
+def write_per_residue_scores(
+    ranking: ResultRanking,
+    context: ExecutionContext,
+):
+    """Write per-residue plddts for each model.
+
+    A row of plddt values is written for each model in tabular format.
+    """
+    model_plddts = {}
+    for i, path in enumerate(context.model_pkl_paths):
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        model_plddts[rank] = model.plddts
+
+    path = context.settings.workdir / OUTPUTS['plddts']
+    with open(path, 'w') as f:
+        for i in sorted(list(model_plddts.keys())):
+            row = [f'ranked_{i}'] + [
+                str(x) for x in model_plddts[i]
+            ]
+            f.write('\t'.join(row) + '\n')
+
+
+def rename_model_pkls(ranking: ResultRanking, context: ExecutionContext):
+    """Rename model.pkl files so the rank order is implicit."""
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        new_path = (
+            context.settings.workdir
+            / OUTPUTS['model_pkl'].format(rank=rank)
+        )
+        shutil.copyfile(path, new_path)
+
+
+def extract_pae_to_csv(ranking: ResultRanking, context: ExecutionContext):
+    """Extract predicted alignment error matrix from pickle files.
+
+    Creates a CSV file for each of five ranked models.
+    """
+    for path in context.model_pkl_paths:
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        with open(path, 'rb') as f:
+            data = pk.load(f)
+        if 'predicted_aligned_error' not in data:
+            print("Skipping PAE output"
+                  f" - not found in {path}."
+                  " Running with model_preset=monomer?")
+            return
+        pae = data['predicted_aligned_error']
+        out_path = (
+            context.settings.workdir
+            / OUTPUTS['model_pae'].format(rank=rank)
+        )
+        with open(out_path, 'w') as f:
+            for row in pae:
+                f.write(','.join([str(x) for x in row]) + '\n')
+
+
+def rekey_relax_metrics(ranking: ResultRanking, context: ExecutionContext):
+    """Replace keys in relax_metrics.json with 0-indexed rank."""
+    with open(context.relax_metrics) as f:
+        data = json.load(f)
+        for k in list(data.keys()):
+            rank = ranking.get_rank_for_model(k)
+            data[f'ranked_{rank}'] = data.pop(k)
+    new_path = context.settings.workdir / OUTPUTS['relax']
+    with open(new_path, 'w') as f:
+        json.dump(data, f)
+
+
+def plddt_pae_plots(ranking: ResultRanking, context: ExecutionContext):
+    """Generate a pLDDT + PAE plot for each model."""
+    for path in context.model_pkl_paths:
+        num_plots = 2
+        model = ResultModelPrediction(path, context)
+        rank = ranking.get_rank_for_model(model.name)
+        png_path = (
+            context.settings.workdir
+            / OUTPUTS['model_plot'].format(rank=rank)
+        )
+        plddts = model.data['plddt']
+        if 'predicted_aligned_error' in model.data:
+            pae = model.data['predicted_aligned_error']
+            max_pae = model.data['max_predicted_aligned_error']
+        else:
+            num_plots = 1
+
+        plt.figure(figsize=[8 * num_plots, 6])
+        plt.subplot(1, num_plots, 1)
+        plt.plot(plddts)
+        plt.title('Predicted LDDT')
+        plt.xlabel('Residue')
+        plt.ylabel('pLDDT')
+
+        if num_plots == 2:
+            plt.subplot(1, 2, 2)
+            plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')
+            plt.colorbar(fraction=0.046, pad=0.04)
+            plt.title('Predicted Aligned Error')
+            plt.xlabel('Scored residue')
+            plt.ylabel('Aligned residue')
+
+        plt.savefig(png_path)
+
+
+def main():
+    """Parse output files and generate additional output files."""
+    settings = Settings()
+    context = ExecutionContext(settings)
+    ranking = ResultRanking(context)
+    write_confidence_scores(ranking, context)
+    rekey_relax_metrics(ranking, context)
+
+    # Optional outputs
+    if settings.output_model_pkls:
+        rename_model_pkls(ranking, context)
+    if settings.output_model_plots:
+        plddt_pae_plots(ranking, context)
+    if settings.output_pae:
+        # Only created by monomer_ptm and multimer models
+        extract_pae_to_csv(ranking, context)
+    if settings.output_residue_scores:
+        write_per_residue_scores(ranking, context)
+
+
+if __name__ == '__main__':
+    main()
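For reference, a typical standalone invocation of this script, assuming an
AlphaFold run directory at output/alphafold as produced by the wrapper above:

    # Generate all optional outputs alongside the default confidence scores
    python3 scripts/outputs.py output/alphafold --plddts --pkl --pae --plot
    # Results are written to output/alphafold/extra/, e.g.
    # model_confidence_scores.tsv, plddts.tsv, ranked_{rank}.pkl,
    # pae_ranked_{rank}.csv and ranked_{rank}.png
    # Add --multimer when parsing output from a multimer run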
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/scripts/validate_fasta.py	Mon Apr 03 01:00:42 2023 +0000
@@ -0,0 +1,254 @@
+"""Validate input FASTA sequence."""
+
+import argparse
+import re
+import sys
+from typing import List
+
+MULTIMER_MAX_SEQUENCE_COUNT = 10
+
+
+class Fasta:
+    def __init__(self, header_str: str, seq_str: str):
+        self.header = header_str
+        self.aa_seq = seq_str
+
+
+class FastaLoader:
+    def __init__(self, fasta_path: str):
+        """Initialize from FASTA file."""
+        self.fastas = []
+        self.load(fasta_path)
+
+    def load(self, fasta_path: str):
+        """Load bare or FASTA formatted sequence."""
+        with open(fasta_path, 'r') as f:
+            self.content = f.read()
+
+        if "__cn__" in self.content:
+            # Pasted content with escaped characters
+            self.newline = '__cn__'
+            self.read_caret = '__gt__'
+        else:
+            # Uploaded file with normal content
+            self.newline = '\n'
+            self.read_caret = '>'
+
+        self.lines = self.content.split(self.newline)
+
+        if not self.lines[0].startswith(self.read_caret):
+            # Fasta is headless, load as single sequence
+            self.update_fastas(
+                '', ''.join(self.lines)
+            )
+
+        else:
+            header = None
+            sequence = None
+            for line in self.lines:
+                if line.startswith(self.read_caret):
+                    if header:
+                        self.update_fastas(header, sequence)
+                    header = '>' + self.strip_header(line)
+                    sequence = ''
+                else:
+                    sequence += line.strip('\n ')
+            self.update_fastas(header, sequence)
+
+    def strip_header(self, line):
+        """Strip characters escaped with underscores from pasted text."""
+        return re.sub(r'\_\_.{2}\_\_', '', line).strip('>')
+
+    def update_fastas(self, header: str, sequence: str):
+        # if we have a sequence
+        if sequence:
+            # create generic header if not exists
+            if not header:
+                fasta_count = len(self.fastas)
+                header = f'>sequence_{fasta_count}'
+
+            # Create new Fasta
+            self.fastas.append(Fasta(header, sequence))
+
+
+class FastaValidator:
+    def __init__(
+            self,
+            min_length=None,
+            max_length=None,
+            multiple=False):
+        self.multiple = multiple
+        self.min_length = min_length
+        self.max_length = max_length
+        self.iupac_characters = {
+            'A', 'B', 'C', 'D', 'E', 'F', 'G',
+            'H', 'I', 'K', 'L', 'M', 'N', 'P',
+            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
+            'Y', 'Z', '-'
+        }
+
+    def validate(self, fasta_list: List[Fasta]):
+        """Perform FASTA validation."""
+        self.fasta_list = fasta_list
+        self.validate_num_seqs()
+        self.validate_length()
+        self.validate_alphabet()
+        # not checking for 'X' residues at the moment;
+        # alphafold can throw an error if it doesn't like them.
+        # self.validate_x()
+        return self.fasta_list
+
+    def validate_num_seqs(self) -> None:
+        """Assert that only one sequence has been provided."""
+        fasta_count = len(self.fasta_list)
+
+        if self.multiple:
+            if fasta_count < 2:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n'
+                    'Multimer mode requires multiple input sequences.'
+                    f' Only {fasta_count} sequence(s) were detected in'
+                    ' the provided file.')
+
+            elif fasta_count > MULTIMER_MAX_SEQUENCE_COUNT:
+                sys.stderr.write(
+                    f'WARNING: detected {fasta_count} sequences but the'
+                    f' maximum allowed is {MULTIMER_MAX_SEQUENCE_COUNT}'
+                    ' sequences. The last'
+                    f' {fasta_count - MULTIMER_MAX_SEQUENCE_COUNT} sequence(s)'
+                    ' have been discarded.\n')
+                self.fasta_list = self.fasta_list[:MULTIMER_MAX_SEQUENCE_COUNT]
+        else:
+            if fasta_count > 1:
+                sys.stderr.write(
+                    'WARNING: More than 1 sequence detected.'
+                    ' Using first FASTA sequence as input.\n')
+                self.fasta_list = self.fasta_list[:1]
+
+            elif len(self.fasta_list) == 0:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n'
+                    ' no FASTA sequences detected in input file.')
+
+    def validate_length(self):
+        """Confirm whether sequence length is valid."""
+        fasta = self.fasta_list[0]
+        if self.min_length:
+            if len(fasta.aa_seq) < self.min_length:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n Sequence too short'
+                    f' ({len(fasta.aa_seq)}AA).'
+                    f' Minimum length is {self.min_length}AA.')
+        if self.max_length:
+            if len(fasta.aa_seq) > self.max_length:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n'
+                    f' Sequence too long ({len(fasta.aa_seq)}AA).'
+                    f' Maximum length is {self.max_length}AA.')
+
+    def validate_alphabet(self):
+        """Confirm whether the sequence conforms to IUPAC codes.
+
+        If not, report the offending character and its position.
+        """
+        fasta = self.fasta_list[0]
+        for i, char in enumerate(fasta.aa_seq.upper()):
+            if char not in self.iupac_characters:
+                raise ValueError(
+                    'Error encountered validating FASTA:\n Invalid amino acid'
+                    f' found at pos {i}: "{char}"')
+
+    def validate_x(self):
+        """Check for X bases."""
+        fasta = self.fasta_list[0]
+        for i, char in enumerate(fasta.aa_seq.upper()):
+            if char == 'X':
+                raise ValueError(
+                    'Error encountered validating FASTA:\n Unsupported AA code'
+                    f' "X" found at pos {i}')
+
+
+class FastaWriter:
+    def __init__(self) -> None:
+        self.line_wrap = 60
+
+    def write(self, fasta: Fasta):
+        header = fasta.header
+        seq = self.format_sequence(fasta.aa_seq)
+        sys.stdout.write(header + '\n')
+        sys.stdout.write(seq)
+
+    def format_sequence(self, aa_seq: str):
+        formatted_seq = ''
+        for i in range(0, len(aa_seq), self.line_wrap):
+            formatted_seq += aa_seq[i: i + self.line_wrap] + '\n'
+        return formatted_seq.upper()
+
+
+def main():
+    # load fasta file
+    try:
+        args = parse_args()
+        fas = FastaLoader(args.input)
+
+        # validate
+        fv = FastaValidator(
+            min_length=args.min_length,
+            max_length=args.max_length,
+            multiple=args.multimer,
+        )
+        clean_fastas = fv.validate(fas.fastas)
+
+        # write clean data
+        fw = FastaWriter()
+        for fas in clean_fastas:
+            fw.write(fas)
+
+        sys.stderr.write("Validated FASTA sequence(s):\n\n")
+        for fas in clean_fastas:
+            sys.stderr.write(fas.header + '\n')
+            sys.stderr.write(fas.aa_seq + '\n\n')
+
+    except ValueError as exc:
+        sys.stderr.write(f"{exc}\n\n")
+        raise exc
+
+    except Exception as exc:
+        sys.stderr.write(
+            "Input error: FASTA input is invalid. Please check your input.\n\n"
+        )
+        raise exc
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "input",
+        help="input fasta file",
+        type=str
+    )
+    parser.add_argument(
+        "--min_length",
+        dest='min_length',
+        help="Minimum length of input protein sequence (AA)",
+        default=None,
+        type=int,
+    )
+    parser.add_argument(
+        "--max_length",
+        dest='max_length',
+        help="Maximum length of input protein sequence (AA)",
+        default=None,
+        type=int,
+    )
+    parser.add_argument(
+        "--multimer",
+        action='store_true',
+        help="Require multiple input sequences",
+    )
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    main()
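A corresponding standalone sketch for the validator (the length limits shown
are hypothetical; the wrapper passes ALPHAFOLD_AA_LENGTH_MIN/MAX and
presumably redirects stdout to produce alphafold.fasta):

    # Validate and clean a single-sequence input, capping length at 2000 AA
    python3 scripts/validate_fasta.py input.fasta \
        --min_length 1 --max_length 2000 > alphafold.fasta
    # Add --multimer to require at least two sequences (at most 10 are kept)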
--- a/validate_fasta.py	Fri Mar 10 02:48:07 2023 +0000
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,254 +0,0 @@
-"""Validate input FASTA sequence."""
-
-import argparse
-import re
-import sys
-from typing import List
-
-MULTIMER_MAX_SEQUENCE_COUNT = 10
-
-
-class Fasta:
-    def __init__(self, header_str: str, seq_str: str):
-        self.header = header_str
-        self.aa_seq = seq_str
-
-
-class FastaLoader:
-    def __init__(self, fasta_path: str):
-        """Initialize from FASTA file."""
-        self.fastas = []
-        self.load(fasta_path)
-
-    def load(self, fasta_path: str):
-        """Load bare or FASTA formatted sequence."""
-        with open(fasta_path, 'r') as f:
-            self.content = f.read()
-
-        if "__cn__" in self.content:
-            # Pasted content with escaped characters
-            self.newline = '__cn__'
-            self.read_caret = '__gt__'
-        else:
-            # Uploaded file with normal content
-            self.newline = '\n'
-            self.read_caret = '>'
-
-        self.lines = self.content.split(self.newline)
-
-        if not self.lines[0].startswith(self.read_caret):
-            # Fasta is headless, load as single sequence
-            self.update_fastas(
-                '', ''.join(self.lines)
-            )
-
-        else:
-            header = None
-            sequence = None
-            for line in self.lines:
-                if line.startswith(self.read_caret):
-                    if header:
-                        self.update_fastas(header, sequence)
-                    header = '>' + self.strip_header(line)
-                    sequence = ''
-                else:
-                    sequence += line.strip('\n ')
-            self.update_fastas(header, sequence)
-
-    def strip_header(self, line):
-        """Strip characters escaped with underscores from pasted text."""
-        return re.sub(r'\_\_.{2}\_\_', '', line).strip('>')
-
-    def update_fastas(self, header: str, sequence: str):
-        # if we have a sequence
-        if sequence:
-            # create generic header if not exists
-            if not header:
-                fasta_count = len(self.fastas)
-                header = f'>sequence_{fasta_count}'
-
-            # Create new Fasta
-            self.fastas.append(Fasta(header, sequence))
-
-
-class FastaValidator:
-    def __init__(
-            self,
-            min_length=None,
-            max_length=None,
-            multiple=False):
-        self.multiple = multiple
-        self.min_length = min_length
-        self.max_length = max_length
-        self.iupac_characters = {
-            'A', 'B', 'C', 'D', 'E', 'F', 'G',
-            'H', 'I', 'K', 'L', 'M', 'N', 'P',
-            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
-            'Y', 'Z', '-'
-        }
-
-    def validate(self, fasta_list: List[Fasta]):
-        """Perform FASTA validation."""
-        self.fasta_list = fasta_list
-        self.validate_num_seqs()
-        self.validate_length()
-        self.validate_alphabet()
-        # not checking for 'X' nucleotides at the moment.
-        # alphafold can throw an error if it doesn't like it.
-        # self.validate_x()
-        return self.fasta_list
-
-    def validate_num_seqs(self) -> None:
-        """Assert that only one sequence has been provided."""
-        fasta_count = len(self.fasta_list)
-
-        if self.multiple:
-            if fasta_count < 2:
-                raise ValueError(
-                    'Error encountered validating FASTA:\n'
-                    'Multimer mode requires multiple input sequence.'
-                    f' Only {fasta_count} sequences were detected in'
-                    ' the provided file.')
-                self.fasta_list = self.fasta_list
-
-            elif fasta_count > MULTIMER_MAX_SEQUENCE_COUNT:
-                sys.stderr.write(
-                    f'WARNING: detected {fasta_count} sequences but the'
-                    f' maximum allowed is {MULTIMER_MAX_SEQUENCE_COUNT}'
-                    ' sequences. The last'
-                    f' {fasta_count - MULTIMER_MAX_SEQUENCE_COUNT} sequence(s)'
-                    ' have been discarded.\n')
-                self.fasta_list = self.fasta_list[:MULTIMER_MAX_SEQUENCE_COUNT]
-        else:
-            if fasta_count > 1:
-                sys.stderr.write(
-                    'WARNING: More than 1 sequence detected.'
-                    ' Using first FASTA sequence as input.\n')
-                self.fasta_list = self.fasta_list[:1]
-
-            elif len(self.fasta_list) == 0:
-                raise ValueError(
-                    'Error encountered validating FASTA:\n'
-                    ' no FASTA sequences detected in input file.')
-
-    def validate_length(self):
-        """Confirm whether sequence length is valid."""
-        fasta = self.fasta_list[0]
-        if self.min_length:
-            if len(fasta.aa_seq) < self.min_length:
-                raise ValueError(
-                    'Error encountered validating FASTA:\n Sequence too short'
-                    f' ({len(fasta.aa_seq)}AA).'
-                    f' Minimum length is {self.min_length}AA.')
-        if self.max_length:
-            if len(fasta.aa_seq) > self.max_length:
-                raise ValueError(
-                    'Error encountered validating FASTA:\n'
-                    f' Sequence too long ({len(fasta.aa_seq)}AA).'
-                    f' Maximum length is {self.max_length}AA.')
-
-    def validate_alphabet(self):
-        """Confirm whether the sequence conforms to IUPAC codes.
-
-        If not, report the offending character and its position.
-        """
-        fasta = self.fasta_list[0]
-        for i, char in enumerate(fasta.aa_seq.upper()):
-            if char not in self.iupac_characters:
-                raise ValueError(
-                    'Error encountered validating FASTA:\n Invalid amino acid'
-                    f' found at pos {i}: "{char}"')
-
-    def validate_x(self):
-        """Check for X bases."""
-        fasta = self.fasta_list[0]
-        for i, char in enumerate(fasta.aa_seq.upper()):
-            if char == 'X':
-                raise ValueError(
-                    'Error encountered validating FASTA:\n Unsupported AA code'
-                    f' "X" found at pos {i}')
-
-
-class FastaWriter:
-    def __init__(self) -> None:
-        self.line_wrap = 60
-
-    def write(self, fasta: Fasta):
-        header = fasta.header
-        seq = self.format_sequence(fasta.aa_seq)
-        sys.stdout.write(header + '\n')
-        sys.stdout.write(seq)
-
-    def format_sequence(self, aa_seq: str):
-        formatted_seq = ''
-        for i in range(0, len(aa_seq), self.line_wrap):
-            formatted_seq += aa_seq[i: i + self.line_wrap] + '\n'
-        return formatted_seq.upper()
-
-
-def main():
-    # load fasta file
-    try:
-        args = parse_args()
-        fas = FastaLoader(args.input)
-
-        # validate
-        fv = FastaValidator(
-            min_length=args.min_length,
-            max_length=args.max_length,
-            multiple=args.multimer,
-        )
-        clean_fastas = fv.validate(fas.fastas)
-
-        # write clean data
-        fw = FastaWriter()
-        for fas in clean_fastas:
-            fw.write(fas)
-
-        sys.stderr.write("Validated FASTA sequence(s):\n\n")
-        for fas in clean_fastas:
-            sys.stderr.write(fas.header + '\n')
-            sys.stderr.write(fas.aa_seq + '\n\n')
-
-    except ValueError as exc:
-        sys.stderr.write(f"{exc}\n\n")
-        raise exc
-
-    except Exception as exc:
-        sys.stderr.write(
-            "Input error: FASTA input is invalid. Please check your input.\n\n"
-        )
-        raise exc
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "input",
-        help="input fasta file",
-        type=str
-    )
-    parser.add_argument(
-        "--min_length",
-        dest='min_length',
-        help="Minimum length of input protein sequence (AA)",
-        default=None,
-        type=int,
-    )
-    parser.add_argument(
-        "--max_length",
-        dest='max_length',
-        help="Maximum length of input protein sequence (AA)",
-        default=None,
-        type=int,
-    )
-    parser.add_argument(
-        "--multimer",
-        action='store_true',
-        help="Require multiple input sequences",
-    )
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    main()