Previous changeset: 0:7ae9d78b06f5 (2022-01-28) · Next changeset: 2:abba603c6ef3 (2022-03-03)
Commit message:
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86" |
modified:
README.md alphafold.xml scripts/download_all_data.sh scripts/download_alphafold_params.sh scripts/download_bfd.sh scripts/download_mgnify.sh scripts/download_pdb70.sh scripts/download_pdb_mmcif.sh scripts/download_pdb_seqres.sh scripts/download_small_bfd.sh scripts/download_uniclust30.sh scripts/download_uniprot.sh scripts/download_uniref90.sh validate_fasta.py
added:
README.rst alphafold.fasta docker/Dockerfile docker/README.md docker/alphafold/.dockerignore docker/alphafold/CONTRIBUTING.md docker/alphafold/LICENSE docker/alphafold/README.md docker/alphafold/alphafold/__init__.py docker/alphafold/alphafold/common/__init__.py docker/alphafold/alphafold/common/confidence.py docker/alphafold/alphafold/common/protein.py docker/alphafold/alphafold/common/protein_test.py docker/alphafold/alphafold/common/residue_constants.py docker/alphafold/alphafold/common/residue_constants_test.py docker/alphafold/alphafold/common/testdata/2rbg.pdb docker/alphafold/alphafold/data/__init__.py docker/alphafold/alphafold/data/feature_processing.py docker/alphafold/alphafold/data/mmcif_parsing.py docker/alphafold/alphafold/data/msa_identifiers.py docker/alphafold/alphafold/data/msa_pairing.py docker/alphafold/alphafold/data/parsers.py docker/alphafold/alphafold/data/pipeline.py docker/alphafold/alphafold/data/pipeline_multimer.py docker/alphafold/alphafold/data/templates.py docker/alphafold/alphafold/data/tools/__init__.py docker/alphafold/alphafold/data/tools/hhblits.py docker/alphafold/alphafold/data/tools/hhsearch.py docker/alphafold/alphafold/data/tools/hmmbuild.py docker/alphafold/alphafold/data/tools/hmmsearch.py docker/alphafold/alphafold/data/tools/jackhmmer.py docker/alphafold/alphafold/data/tools/kalign.py docker/alphafold/alphafold/data/tools/utils.py docker/alphafold/alphafold/model/__init__.py docker/alphafold/alphafold/model/all_atom.py docker/alphafold/alphafold/model/all_atom_multimer.py docker/alphafold/alphafold/model/all_atom_test.py docker/alphafold/alphafold/model/common_modules.py docker/alphafold/alphafold/model/config.py docker/alphafold/alphafold/model/data.py docker/alphafold/alphafold/model/features.py docker/alphafold/alphafold/model/folding.py docker/alphafold/alphafold/model/folding_multimer.py docker/alphafold/alphafold/model/geometry/__init__.py docker/alphafold/alphafold/model/geometry/rigid_matrix_vector.py docker/alphafold/alphafold/model/geometry/rotation_matrix.py docker/alphafold/alphafold/model/geometry/struct_of_array.py docker/alphafold/alphafold/model/geometry/test_utils.py docker/alphafold/alphafold/model/geometry/utils.py docker/alphafold/alphafold/model/geometry/vector.py docker/alphafold/alphafold/model/layer_stack.py docker/alphafold/alphafold/model/layer_stack_test.py docker/alphafold/alphafold/model/lddt.py docker/alphafold/alphafold/model/lddt_test.py docker/alphafold/alphafold/model/mapping.py docker/alphafold/alphafold/model/model.py docker/alphafold/alphafold/model/modules.py docker/alphafold/alphafold/model/modules_multimer.py docker/alphafold/alphafold/model/prng.py docker/alphafold/alphafold/model/prng_test.py docker/alphafold/alphafold/model/quat_affine.py docker/alphafold/alphafold/model/quat_affine_test.py docker/alphafold/alphafold/model/r3.py docker/alphafold/alphafold/model/tf/__init__.py docker/alphafold/alphafold/model/tf/data_transforms.py docker/alphafold/alphafold/model/tf/input_pipeline.py docker/alphafold/alphafold/model/tf/protein_features.py docker/alphafold/alphafold/model/tf/protein_features_test.py docker/alphafold/alphafold/model/tf/proteins_dataset.py docker/alphafold/alphafold/model/tf/shape_helpers.py docker/alphafold/alphafold/model/tf/shape_helpers_test.py docker/alphafold/alphafold/model/tf/shape_placeholders.py docker/alphafold/alphafold/model/tf/utils.py docker/alphafold/alphafold/model/utils.py docker/alphafold/alphafold/notebooks/__init__.py 
docker/alphafold/alphafold/notebooks/notebook_utils.py docker/alphafold/alphafold/notebooks/notebook_utils_test.py docker/alphafold/alphafold/relax/__init__.py docker/alphafold/alphafold/relax/amber_minimize.py docker/alphafold/alphafold/relax/amber_minimize_test.py docker/alphafold/alphafold/relax/cleanup.py docker/alphafold/alphafold/relax/cleanup_test.py docker/alphafold/alphafold/relax/relax.py docker/alphafold/alphafold/relax/relax_test.py docker/alphafold/alphafold/relax/testdata/model_output.pdb docker/alphafold/alphafold/relax/testdata/multiple_disulfides_target.pdb docker/alphafold/alphafold/relax/testdata/with_violations.pdb docker/alphafold/alphafold/relax/testdata/with_violations_casp14.pdb docker/alphafold/alphafold/relax/utils.py docker/alphafold/alphafold/relax/utils_test.py docker/alphafold/docker/Dockerfile docker/alphafold/docker/openmm.patch docker/alphafold/docker/requirements.txt docker/alphafold/docker/run_docker.py docker/alphafold/imgs/casp14_predictions.gif docker/alphafold/imgs/header.jpg docker/alphafold/notebooks/AlphaFold.ipynb docker/alphafold/requirements.txt docker/alphafold/run_alphafold.py docker/alphafold/run_alphafold_test.py docker/alphafold/scripts/download_all_data.sh docker/alphafold/scripts/download_alphafold_params.sh docker/alphafold/scripts/download_bfd.sh docker/alphafold/scripts/download_mgnify.sh docker/alphafold/scripts/download_pdb70.sh docker/alphafold/scripts/download_pdb_mmcif.sh docker/alphafold/scripts/download_pdb_seqres.sh docker/alphafold/scripts/download_small_bfd.sh docker/alphafold/scripts/download_uniclust30.sh docker/alphafold/scripts/download_uniprot.sh docker/alphafold/scripts/download_uniref90.sh docker/alphafold/setup.py docker/claremcwhite/Dockerfile docker/claremcwhite/README.md
diff -r 7ae9d78b06f5 -r 6c92e000d684 README.md
--- a/README.md	Fri Jan 28 04:56:29 2022 +0000
+++ b/README.md	Tue Mar 01 02:53:05 2022 +0000
@@ -3,9 +3,9 @@
 ## Overview
 
-Alphafold requires a customised compute environment to run. The machine needs a GPU, and access to a 2.2 Tb reference data store. 
+Alphafold requires a customised compute environment to run. The machine needs a GPU, and access to a 2.2 Tb reference data store.
 
-This document is designed to provide details on the compute environment required for Alphafold operation, and the Galaxy job destination settings to run the wrapper. 
+This document is designed to provide details on the compute environment required for Alphafold operation, and the Galaxy job destination settings to run the wrapper.
 
 For full details on Alphafold requirements, see https://github.com/deepmind/alphafold.
 
@@ -13,11 +13,11 @@
 ### HARDWARE
 
-The machine is recommended to have the following specs: 
+The machine is recommended to have the following specs:
 - 12 cores
 - 80 Gb RAM
 - 2.5 Tb storage
-- A fast Nvidia GPU. 
+- A fast Nvidia GPU.
 
 As a minimum, the Nvidia GPU must have 8Gb RAM. It also requires ***unified memory*** to be switched on. <br> Unified memory is usually enabled by default, but some HPC systems will turn it off so the GPU can be shared between multiple jobs concurrently.
 
@@ -31,7 +31,7 @@
 - [Singularity](https://sylabs.io/guides/3.0/user-guide/installation.html)
 - [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
 
-As Alphafold uses an Nvidia GPU, the NVIDIA Container Toolkit is needed. This makes the GPU available inside the running singularity container. 
+As Alphafold uses an Nvidia GPU, the NVIDIA Container Toolkit is needed. This makes the GPU available inside the running singularity container.
 
 To check that everything has been set up correctly, run the following
 
@@ -79,7 +79,7 @@
 bash scripts/download_all_data.sh /data/alphafold_databases
 ```
 
-This will install the reference data to `/data/alphafold_databases`. To check this has worked, ensure the final folder structure is as follows: 
+This will install the reference data to `/data/alphafold_databases`. To check this has worked, ensure the final folder structure is as follows:
 
 ```
 data/alphafold_databases
 
@@ -128,9 +128,9 @@
 ### JOB DESTINATION
 
-Alphafold needs a custom singularity job destination to run. 
+Alphafold needs a custom singularity job destination to run.
 The destination needs to be configured for singularity, and some
-extra singularity params need to be set as seen below. 
+extra singularity params need to be set as seen below.
 
 Specify the job runner. For example, a local runner
 
@@ -154,4 +154,4 @@
 ### Closing
 
-If you are experiencing technical issues, feel free to write to help@genome.edu.au. We may be able to provide comment on setting up Alphafold on your compute environment.
+If you are experiencing technical issues, feel free to write to help@genome.edu.au. We may be able to provide advice on setting up Alphafold on your compute environment.
diff -r 7ae9d78b06f5 -r 6c92e000d684 README.rst
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/README.rst	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,164 @@
+Alphafold compute setup
+=======================
+
+Overview
+--------
+
+Alphafold requires a customised compute environment to run. The machine
+needs a GPU, and access to a 2.2 Tb reference data store.
+
+This document is designed to provide details on the compute environment
+required for Alphafold operation, and the Galaxy job destination
+settings to run the wrapper.
+
+For full details on Alphafold requirements, see
+https://github.com/deepmind/alphafold.
+
+HARDWARE
+~~~~~~~~
+
+The machine is recommended to have the following specs: - 12 cores - 80
+Gb RAM - 2.5 Tb storage - A fast Nvidia GPU.
+
+As a minimum, the Nvidia GPU must have 8Gb RAM. It also requires
+**unified memory** to be switched on. Unified memory is usually enabled
+by default, but some HPC systems will turn it off so the GPU can be
+shared between multiple jobs concurrently.
+
+ENVIRONMENT
+~~~~~~~~~~~
+
+This wrapper runs Alphafold as a singularity container. The following
+software are needed:
+
+-  `Singularity <https://sylabs.io/guides/3.0/user-guide/installation.html>`_
+-  `NVIDIA Container
+   Toolkit <https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html>`_
+
+As Alphafold uses an Nvidia GPU, the NVIDIA Container Toolkit is needed.
+This makes the GPU available inside the running singularity container.
+
+To check that everything has been set up correctly, run the following
+
+::
+
+   singularity run --nv docker://nvidia/cuda:11.0-base nvidia-smi
+
+If you can see something similar to this output (details depend on your
+GPU), it has been set up correctly.
+
+::
+
+   +-----------------------------------------------------------------------------+
+   | NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
+   |-------------------------------+----------------------+----------------------+
+   | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
+   | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
+   |                               |                      |               MIG M. |
+   |===============================+======================+======================|
+   |   0  Tesla T4            Off  | 00000000:00:05.0 Off |                    0 |
+   | N/A   49C    P0    28W /  70W |      0MiB / 15109MiB |      0%      Default |
+   |                               |                      |                  N/A |
+   +-------------------------------+----------------------+----------------------+
+
+   +-----------------------------------------------------------------------------+
+   | Processes:                                                                  |
+   |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
+   |        ID   ID                                                   Usage      |
+   |=============================================================================|
+   |  No running processes found                                                 |
+   +-----------------------------------------------------------------------------+
+
+REFERENCE DATA
+~~~~~~~~~~~~~~
+
+Alphafold needs reference data to run. The wrapper expects this data to
+be present at ``/data/alphafold_databases``. To download, run the
+following shell script command in the tool directory.
+
+::
+
+   # make folders if needed
+   mkdir /data /data/alphafold_databases
+
+   # download ref data
+   bash scripts/download_all_data.sh /data/alphafold_databases
+
+This will install the reference data to ``/data/alphafold_databases``.
+To check this has worked, ensure the final folder structure is as
+follows:
+
+::
+
+   data/alphafold_databases
+   ├── bfd
+   │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffdata
+   │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_a3m.ffindex
+   │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_cs219.ffdata
+   │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_cs219.ffindex
+   │   ├── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffdata
+   │   └── bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt_hhm.ffindex
+   ├── mgnify
+   │   └── mgy_clusters_2018_12.fa
+   ├── params
+   │   ├── LICENSE
+   │   ├── params_model_1.npz
+   │   ├── params_model_1_ptm.npz
+   │   ├── params_model_2.npz
+   │   ├── params_model_2_ptm.npz
+   │   ├── params_model_3.npz
+   │   ├── params_model_3_ptm.npz
+   │   ├── params_model_4.npz
+   │   ├── params_model_4_ptm.npz
+   │   ├── params_model_5.npz
+   │   └── params_model_5_ptm.npz
+   ├── pdb70
+   │   ├── md5sum
+   │   ├── pdb70_a3m.ffdata
+   │   ├── pdb70_a3m.ffindex
+   │   ├── pdb70_clu.tsv
+   │   ├── pdb70_cs219.ffdata
+   │   ├── pdb70_cs219.ffindex
+   │   ├── pdb70_hhm.ffdata
+   │   ├── pdb70_hhm.ffindex
+   │   └── pdb_filter.dat
+   ├── pdb_mmcif
+   │   ├── mmcif_files
+   │   └── obsolete.dat
+   ├── uniclust30
+   │   └── uniclust30_2018_08
+   └── uniref90
+       └── uniref90.fasta
+
+JOB DESTINATION
+~~~~~~~~~~~~~~~
+
+Alphafold needs a custom singularity job destination to run. The
+destination needs to be configured for singularity, and some extra
+singularity params need to be set as seen below.
+
+Specify the job runner. For example, a local runner
+
+::
+
+   <plugin id="alphafold_runner" type="runner" load="galaxy.jobs.runners.local:LocalJobRunner"/>
+
+Customise the job destination with required singularity settings. The
+settings below are mandatory, but you may include other settings as
+needed.
+
+::
+
+   <destination id="alphafold" runner="alphafold_runner">
+       <param id="dependency_resolution">'none'</param>
+       <param id="singularity_enabled">true</param>
+       <param id="singularity_run_extra_arguments">--nv</param>
+       <param id="singularity_volumes">"$job_directory:ro,$tool_directory:ro,$job_directory/outputs:rw,$working_directory:rw,/data/alphafold_databases:/data:ro"</param>
+   </destination>
+
+Closing
+~~~~~~~
+
+If you are experiencing technical issues, feel free to write to
+help@genome.edu.au. We may be able to provide advice on setting up
+Alphafold on your compute environment.
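To confirm the downloaded reference data matches the tree shown in the README above, a small check script can walk the expected paths. This is an illustrative sketch only (the path names are taken from the tree above; the script itself is not part of this changeset):

```python
#!/usr/bin/env python3
"""Sanity-check the AlphaFold reference-data layout described in the README.
Illustrative sketch - not part of the repository."""
import os
import sys

DB_ROOT = "/data/alphafold_databases"  # the location the wrapper expects
EXPECTED = [
    "bfd",
    "mgnify/mgy_clusters_2018_12.fa",
    "params/params_model_1.npz",
    "pdb70/pdb70_hhm.ffdata",
    "pdb_mmcif/mmcif_files",
    "pdb_mmcif/obsolete.dat",
    "uniclust30/uniclust30_2018_08",
    "uniref90/uniref90.fasta",
]

missing = [p for p in EXPECTED if not os.path.exists(os.path.join(DB_ROOT, p))]
if missing:
    sys.exit("Missing reference data:\n" + "\n".join(missing))
print("Reference data layout looks complete.")
```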
diff -r 7ae9d78b06f5 -r 6c92e000d684 alphafold.fasta
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/alphafold.fasta	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,4 @@
+>AAB30827.1 thyroid-stimulating hormone alpha subunit Homo sapiens
+MDYYRKYAAIFLVTLSVFLHVLHSAPDVQDCPECTLQENPFFSQPGAPILQCMGCCFSRA
+YPTPLRSKKTMLVQKNVTSESTCCVAKSYNRVTVMGGFKVENHTACHCSTCYYHKS
+
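The tool command (see the alphafold.xml diff below) runs `validate_fasta.py` on this input before launching AlphaFold. A minimal sketch of such a check, assuming a single amino-acid sequence is expected (illustrative only; the repository's actual validation rules are not shown in this changeset):

```python
"""Minimal single-record FASTA sanity check, in the spirit of the
validate_fasta.py call in the tool command. Illustrative only."""
import sys

VALID_AA = set("ACDEFGHIKLMNPQRSTVWY")  # assumed 20-letter alphabet

def validate_fasta(path: str) -> None:
    with open(path) as fh:
        lines = [ln.strip() for ln in fh if ln.strip()]
    if not lines or not lines[0].startswith(">"):
        sys.exit("Input must begin with a '>' FASTA header")
    if any(ln.startswith(">") for ln in lines[1:]):
        sys.exit("Expected exactly one sequence record")
    bad = set("".join(lines[1:]).upper()) - VALID_AA
    if bad:
        sys.exit(f"Invalid residue characters: {sorted(bad)}")

if __name__ == "__main__":
    validate_fasta(sys.argv[1])
```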
diff -r 7ae9d78b06f5 -r 6c92e000d684 alphafold.xml
--- a/alphafold.xml	Fri Jan 28 04:56:29 2022 +0000
+++ b/alphafold.xml	Tue Mar 01 02:53:05 2022 +0000
@@ -1,64 +1,68 @@
 <tool id="alphafold" name="alphafold" version="@TOOL_VERSION@+galaxy@VERSION_SUFFIX@" profile="20.01">
     <description>Alphafold v2.0: AI-guided 3D structure prediction of proteins</description>
     <macros>
-        <token name="@TOOL_VERSION@">2.0.0</token>
-        <token name="@VERSION_SUFFIX@">0</token>
+        <token name="@TOOL_VERSION@">2.0.0</token>
+        <token name="@VERSION_SUFFIX@">0</token>
     </macros>
     <edam_topics>
-        <edam_topic>topic_0082</edam_topic>
+        <edam_topic>topic_0082</edam_topic>
     </edam_topics>
     <edam_operations>
-        <edam_operation>operation_0474</edam_operation>
+        <edam_operation>operation_0474</edam_operation>
     </edam_operations>
+    <xrefs>
+        <xref type="bio.tools">alphafold_2.0</xref>
+    </xrefs>
     <requirements>
-        <container type="docker">neoformit/alphafold-galaxy@sha256:6adf7f07062b307d08c11130c39a28abc7c290b23f6c347b09c2c649c054c338</container>
+        <container type="docker">neoformit/alphafold:latest</container>
     </requirements>
     <command detect_errors="exit_code"><![CDATA[
-        ## fasta setup ----------------------------
-        #if $fasta_or_text.input_mode == 'history':
-            cp '$fasta_or_text.fasta_file' input.fasta &&
+
+## $ALPHAFOLD_DB variable should point to the location of the AlphaFold
+## databases - defaults to /data
+
+## fasta setup ----------------------------
+#if $fasta_or_text.input_mode == 'history':
+    cp '$fasta_or_text.fasta_file' input.fasta &&
-        #elif $fasta_or_text.input_mode == 'textbox':
-            echo '$fasta_or_text.fasta_text' > input.fasta &&
-        #end if
+#elif $fasta_or_text.input_mode == 'textbox':
+    echo '$fasta_or_text.fasta_text' > input.fasta &&
+#end if
-        python3 '$__tool_directory__/validate_fasta.py' input.fasta &&
+python3 '$__tool_directory__/validate_fasta.py' input.fasta &&
-        ## env vars -------------------------------
-        export TF_FORCE_UNIFIED_MEMORY=1 &&
-        export XLA_PYTHON_CLIENT_MEM_FRACTION=4.0 &&
-        export DATE=`date +"%Y-%m-%d"` &&
+## env vars -------------------------------
+export TF_FORCE_UNIFIED_MEMORY=1 &&
+export XLA_PYTHON_CLIENT_MEM_FRACTION=4.0 &&
+export DATE=`date +"%Y-%m-%d"` &&
-        ## run alphafold -------------------------
-        ln -s /app/alphafold/alphafold alphafold &&
-        python /app/alphafold/run_alphafold.py
-        --fasta_paths alphafold.fasta
-        --output_dir output
-        --data_dir /data  ## location of the alphafold databases on pulsar node --> could this maybe a env var? $ALPHAFOLD_DB --> \${ALPHAFOLD_DB:-/data}
-        --uniref90_database_path /data/uniref90/uniref90.fasta
-        --mgnify_database_path /data/mgnify/mgy_clusters_2018_12.fa
-        --pdb70_database_path /data/pdb70/pdb70
-        --template_mmcif_dir /data/pdb_mmcif/mmcif_files
-        --obsolete_pdbs_path /data/pdb_mmcif/obsolete.dat
-        --max_template_date=\$DATE
-        --bfd_database_path /data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
-        --uniclust30_database_path /data/uniclust30/uniclust30_2018_08/uniclust30_2018_08
-        &&
+## run alphafold -------------------------
+python /app/alphafold/run_alphafold.py
+--fasta_paths alphafold.fasta
+--output_dir output
+--data_dir \${ALPHAFOLD_DB:-/data}
+--uniref90_database_path \${ALPHAFOLD_DB:-/data}/uniref90/uniref90.fasta
+--mgnify_database_path \${ALPHAFOLD_DB:-/data}/mgnify/mgy_clusters_2018_12.fa
+--pdb70_database_path \${ALPHAFOLD_DB:-/data}/pdb70/pdb70
+--template_mmcif_dir \${ALPHAFOLD_DB:-/data}/pdb_mmcif/mmcif_files
+--obsolete_pdbs_path \${ALPHAFOLD_DB:-/data}/pdb_mmcif/obsolete.dat
+--max_template_date=\$DATE
+--bfd_database_path \${ALPHAFOLD_DB:-/data}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
+--uniclust30_database_path \${ALPHAFOLD_DB:-/data}/uniclust30/uniclust30_2018_08/uniclust30_2018_08
+--use_gpu_relax=True
+&&
-        ## for dry run testing
-        ## cp -r '$__tool_directory__/output' . &&
+## Uncomment for "dummy run" - skip alphafold run and read output from test-data
+## cp -r '$__tool_directory__/output' . &&
-        ## generate extra outputs -----------------
-        ## plddts
-        python3 '$__tool_directory__/gen_extra_outputs.py' output/alphafold $output_plddts &&
+## Generate additional outputs ------------
+python3 '$__tool_directory__/gen_extra_outputs.py' output/alphafold $output_plddts &&
-        ## html
-        mkdir -p '${ html.files_path }' &&
-        cp '$__tool_directory__/alphafold.html' ${html} &&
-        cp output/alphafold/ranked_*.pdb '${html.files_path}' &&
+## HTML output
+mkdir -p '${ html.files_path }' &&
+cp '$__tool_directory__/alphafold.html' '${html}' &&
+cp output/alphafold/ranked_*.pdb '${html.files_path}'
-        ## For some reason the working directory ends up being one level too deep!
-        mv working/* .
     ]]></command>
     <inputs>
         <conditional name="fasta_or_text">
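The substantive change in this hunk is replacing the hard-coded `/data` database root with the `ALPHAFOLD_DB` environment variable, using bash's use-default expansion (`${ALPHAFOLD_DB:-/data}`): if the variable is set, the wrapper uses it; otherwise it falls back to `/data`. For illustration, the equivalent lookup in Python:

```python
import os

# Equivalent of the shell expansion ${ALPHAFOLD_DB:-/data}: use the variable
# when set (and non-empty), otherwise fall back to /data.
db_root = os.environ.get("ALPHAFOLD_DB") or "/data"
uniref90_path = f"{db_root}/uniref90/uniref90.fasta"
```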
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/Dockerfile
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/Dockerfile	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,26 @@
+FROM clairemcwhite/alphafold
+
+ARG CUDA=11.2
+ARG CUDA_FULL=11.2.2
+
+# Copy in updated alphafold repo (last commit 05/11/2021)
+# https://github.com/deepmind/alphafold/tree/be37a41d6f83e4145bd4912cbe8bf6a24af80c29
+RUN rm -rf /app/alphafold/alphafold
+COPY alphafold /app/alphafold/
+
+RUN conda update -qy conda \
+    && conda install -y -c conda-forge \
+      openmm=7.5.1 \
+      cudatoolkit==${CUDA_FULL} \
+      pdbfixer \
+      pip \
+      python=3.7
+
+RUN wget -q -P /app/alphafold/alphafold/common/ \
+  https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
+
+# Fix correct jax version for Cuda 11.2: https://github.com/google/jax/issues/5668
+RUN pip3 install --upgrade pip \
+    && pip3 install -r /app/alphafold/requirements.txt \
+    && pip3 install --upgrade jax jaxlib==0.1.61+cuda112 -f \
+  https://storage.googleapis.com/jax-releases/jax_releases.html
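The pinned `jaxlib==0.1.61+cuda112` wheel must match the CUDA 11.2 toolkit installed above. A quick smoke test that JAX can actually see the GPU inside a built image, as a hedged sketch (the image tag used to run it is whatever you assigned at build time, not something defined in this changeset):

```python
# Run inside the built container, e.g.:
#   docker run --rm --gpus all <your-image-tag> python3 gpu_check.py
# <your-image-tag> is hypothetical - substitute the tag from your build.
import jax

devices = jax.devices()
print("JAX backend devices:", devices)
# Fail loudly if JAX only found CPU devices.
assert any(d.platform == "gpu" for d in devices), "No GPU visible to JAX"
```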
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/README.md
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/README.md	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,12 @@
+# What is this?
+
+These are `alphafold` git repos copied from:
+ - the `clairemcwhite/alphafold` docker container (originates from a fork/branch https://github.com/deisseroth-lab/alphafold/tree/cudnn-runtime)
+ - The upstream https://github.com/deepmind/alphafold
+
+### Diffs
+- According to [the closed pull request](https://github.com/deepmind/alphafold/pull/36), the main diff is updates to Dockerfile Cuda deps in the fork
+- These issues have since been resolved in the upstream
+- Can probably copy the new repo into the image in a new Dockerfile `FROM clairemcwhite/alphafold`
+- And hope that alphafold on pulsar can work with the new container!
+  (There were lots of dependency issues...)
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/.dockerignore
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/.dockerignore	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,3 @@
+.dockerignore
+docker/Dockerfile
+README.md
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/CONTRIBUTING.md
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/CONTRIBUTING.md	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,23 @@
+# How to Contribute
+
+We welcome small patches related to bug fixes and documentation, but we do not
+plan to make any major changes to this repository.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution,
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/LICENSE
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/LICENSE	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,202 @@
[new file: the standard Apache License, Version 2.0 text (202 lines, http://www.apache.org/licenses/LICENSE-2.0) - rendered in this view as a truncated byte string]
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/README.md
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/README.md	Tue Mar 01 02:53:05 2022 +0000
b"@@ -0,0 +1,666 @@\n+![header](imgs/header.jpg)\n+\n+# AlphaFold\n+\n+This package provides an implementation of the inference pipeline of AlphaFold\n+v2.0. This is a completely new model that was entered in CASP14 and published in\n+Nature. For simplicity, we refer to this model as AlphaFold throughout the rest\n+of this document.\n+\n+We also provide an implementation of AlphaFold-Multimer. This represents a work\n+in progress and AlphaFold-Multimer isn't expected to be as stable as our monomer\n+AlphaFold system.\n+[Read the guide](#updating-existing-alphafold-installation-to-include-alphafold-multimers)\n+for how to upgrade and update code.\n+\n+Any publication that discloses findings arising from using this source code or the model parameters should [cite](#citing-this-work) the\n+[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and, if\n+applicable, the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1).\n+\n+Please also refer to the\n+[Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf)\n+for a detailed description of the method.\n+\n+**You can use a slightly simplified version of AlphaFold with\n+[this Colab\n+notebook](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb)**\n+or community-supported versions (see below).\n+\n+![CASP14 predictions](imgs/casp14_predictions.gif)\n+\n+## First time setup\n+\n+The following steps are required in order to run AlphaFold:\n+\n+1. Install [Docker](https://www.docker.com/).\n+ * Install\n+ [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)\n+ for GPU support.\n+ * Setup running\n+ [Docker as a non-root user](https://docs.docker.com/engine/install/linux-postinstall/#manage-docker-as-a-non-root-user).\n+1. Download genetic databases (see below).\n+1. Download model parameters (see below).\n+1. Check that AlphaFold will be able to use a GPU by running:\n+\n+ ```bash\n+ docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi\n+ ```\n+\n+ The output of this command should show a list of your GPUs. 
If it doesn't,\n+ check if you followed all steps correctly when setting up the\n+ [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)\n+ or take a look at the following\n+ [NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573).\n+\n+If you wish to run AlphaFold using Singularity (a common containerization platform on HPC systems) we recommend using some of the\n+third party Singularity setups as linked in\n+https://github.com/deepmind/alphafold/issues/10 or\n+https://github.com/deepmind/alphafold/issues/24.\n+\n+### Genetic databases\n+\n+This step requires `aria2c` to be installed on your machine.\n+\n+AlphaFold needs multiple genetic (sequence) databases to run:\n+\n+* [BFD](https://bfd.mmseqs.com/),\n+* [MGnify](https://www.ebi.ac.uk/metagenomics/),\n+* [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),\n+* [PDB](https://www.rcsb.org/) (structures in the mmCIF format),\n+* [PDB seqres](https://www.rcsb.org/) \xe2\x80\x93 only for AlphaFold-Multimer,\n+* [Uniclust30](https://uniclust.mmseqs.com/),\n+* [UniProt](https://www.uniprot.org/uniprot/) \xe2\x80\x93 only for AlphaFold-Multimer,\n+* [UniRef90](https://www.uniprot.org/help/uniref).\n+\n+We provide a script `scripts/download_all_data.sh` that can be used to download\n+and set up all of these databases:\n+\n+* Default:\n+\n+ ```bash\n+ scripts/download_all_data.sh <DOWNLOAD_DIR>\n+ ```\n+\n+ will download the full databases.\n+\n+* With `reduced_dbs`:\n+\n+ ```bash\n+ scripts/download_all_data.sh <DOWNLOAD_DIR> reduced_dbs\n+ ```\n+\n+ will download a reduced version of the databases to be used with the\n+ `reduced_dbs` database preset.\n+\n+:ledger: **Note: Th"..b'he following separate libraries\n+and packages:\n+\n+* [Abseil](https://github.com/abseil/abseil-py)\n+* [Biopython](https://biopython.org)\n+* [Chex](https://github.com/deepmind/chex)\n+* [Colab](https://research.google.com/colaboratory/)\n+* [Docker](https://www.docker.com)\n+* [HH Suite](https://github.com/soedinglab/hh-suite)\n+* [HMMER Suite](http://eddylab.org/software/hmmer)\n+* [Haiku](https://github.com/deepmind/dm-haiku)\n+* [Immutabledict](https://github.com/corenting/immutabledict)\n+* [JAX](https://github.com/google/jax/)\n+* [Kalign](https://msa.sbc.su.se/cgi-bin/msa.cgi)\n+* [matplotlib](https://matplotlib.org/)\n+* [ML Collections](https://github.com/google/ml_collections)\n+* [NumPy](https://numpy.org)\n+* [OpenMM](https://github.com/openmm/openmm)\n+* [OpenStructure](https://openstructure.org)\n+* [pandas](https://pandas.pydata.org/)\n+* [pymol3d](https://github.com/avirshup/py3dmol)\n+* [SciPy](https://scipy.org)\n+* [Sonnet](https://github.com/deepmind/sonnet)\n+* [TensorFlow](https://github.com/tensorflow/tensorflow)\n+* [Tree](https://github.com/deepmind/tree)\n+* [tqdm](https://github.com/tqdm/tqdm)\n+\n+We thank all their contributors and maintainers!\n+\n+## License and Disclaimer\n+\n+This is not an officially supported Google product.\n+\n+Copyright 2021 DeepMind Technologies Limited.\n+\n+### AlphaFold Code License\n+\n+Licensed under the Apache License, Version 2.0 (the "License"); you may not use\n+this file except in compliance with the License. 
You may obtain a copy of the\n+License at https://www.apache.org/licenses/LICENSE-2.0.\n+\n+Unless required by applicable law or agreed to in writing, software distributed\n+under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR\n+CONDITIONS OF ANY KIND, either express or implied. See the License for the\n+specific language governing permissions and limitations under the License.\n+\n+### Model Parameters License\n+\n+The AlphaFold parameters are made available for non-commercial use only, under\n+the terms of the Creative Commons Attribution-NonCommercial 4.0 International\n+(CC BY-NC 4.0) license. You can find details at:\n+https://creativecommons.org/licenses/by-nc/4.0/legalcode\n+\n+### Third-party software\n+\n+Use of the third-party software, libraries or code referred to in the\n+[Acknowledgements](#acknowledgements) section above may be governed by separate\n+terms and conditions or license provisions. Your use of the third-party\n+software, libraries or code is subject to any such terms and you should check\n+that you can comply with any applicable restrictions or terms and conditions\n+before use.\n+\n+### Mirrored Databases\n+\n+The following databases have been mirrored by DeepMind, and are available with reference to the following:\n+\n+* [BFD](https://bfd.mmseqs.com/) (unmodified), by Steinegger M. and S\xc3\xb6ding J., available under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).\n+\n+* [BFD](https://bfd.mmseqs.com/) (modified), by Steinegger M. and S\xc3\xb6ding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details.\n+\n+* [Uniclust30: v2018_08](http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/) (unmodified), by Mirdita M. et al., available under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).\n+\n+* [MGnify: v2018_12](http://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/current_release/README.txt) (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).\n' |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/__init__.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/__init__.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,14 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""An implementation of the inference pipeline of AlphaFold v2.0."""
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/__init__.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/__init__.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,14 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common data types and constants used within Alphafold."""
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/confidence.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/confidence.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,168 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for processing confidence metrics."""
+
+from typing import Dict, Optional, Tuple
+import numpy as np
+import scipy.special
+
+
+def compute_plddt(logits: np.ndarray) -> np.ndarray:
+  """Computes per-residue pLDDT from logits.
+
+  Args:
+    logits: [num_res, num_bins] output from the PredictedLDDTHead.
+
+  Returns:
+    plddt: [num_res] per-residue pLDDT.
+  """
+  num_bins = logits.shape[-1]
+  bin_width = 1.0 / num_bins
+  bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
+  probs = scipy.special.softmax(logits, axis=-1)
+  predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
+  return predicted_lddt_ca * 100
+
+
+def _calculate_bin_centers(breaks: np.ndarray):
+  """Gets the bin centers from the bin edges.
+
+  Args:
+    breaks: [num_bins - 1] the error bin edges.
+
+  Returns:
+    bin_centers: [num_bins] the error bin centers.
+  """
+  step = (breaks[1] - breaks[0])
+
+  # Add half-step to get the center
+  bin_centers = breaks + step / 2
+  # Add a catch-all bin at the end.
+  bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
+                               axis=0)
+  return bin_centers
+
+
+def _calculate_expected_aligned_error(
+    alignment_confidence_breaks: np.ndarray,
+    aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+  """Calculates expected aligned distance errors for every pair of residues.
+
+  Args:
+    alignment_confidence_breaks: [num_bins - 1] the error bin edges.
+    aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
+      probs for each error bin, for each pair of residues.
+
+  Returns:
+    predicted_aligned_error: [num_res, num_res] the expected aligned distance
+      error for each pair of residues.
+    max_predicted_aligned_error: The maximum predicted error possible.
+  """
+  bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
+
+  # Tuple of expected aligned distance error and max possible error.
+  return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1),
+          np.asarray(bin_centers[-1]))
+
+
+def compute_predicted_aligned_error(
+    logits: np.ndarray,
+    breaks: np.ndarray) -> Dict[str, np.ndarray]:
+  """Computes aligned confidence metrics from logits.
+
+  Args:
+    logits: [num_res, num_res, num_bins] the logits output from
+      PredictedAlignedErrorHead.
+    breaks: [num_bins - 1] the error bin edges.
+
+  Returns:
+    aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
+      aligned error probabilities over bins for each residue pair.
+    predicted_aligned_error: [num_res, num_res] the expected aligned distance
+      error for each pair of residues.
+    max_predicted_aligned_error: The maximum predicted error possible.
+  """
+  aligned_confidence_probs = scipy.special.softmax(
+      logits,
+      axis=-1)
+  predicted_aligned_error, max_predicted_aligned_error = (
+      _calculate_expected_aligned_error(
+          alignment_confidence_breaks=breaks,
+          aligned_distance_error_probs=aligned_confidence_probs))
+  return {
+      'aligned_confidence_probs': aligned_confidence_probs,
+      'predicted_aligned_error': predicted_aligned_error,
+      'max_predicted_aligned_error': max_predicted_aligned_error,
+  }
+
+
+def predicted_tm_score(
+    logits: np.ndarray,
+    breaks: np.ndarray,
+    residue_weights: Optional[np.ndarray] = None,
+    asym_id: Optional[np.ndarray] = None,
+    interface: bool = False) -> np.ndarray:
+  """Computes predicted TM alignment or predicted interface TM alignment score.
+
+  Args:
+    logits: [num_res, num_res, num_bins] the logits output from
+      PredictedAlignedErrorHead.
+    breaks: [num_bins] the error bins.
+    residue_weights: [num_res] the per residue weights to use for the
+      expectation.
+    asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
+      ipTM calculation, i.e. when interface=True.
+    interface: If True, interface predicted TM score is computed.
+
+  Returns:
+    ptm_score: The predicted TM alignment or the predicted iTM score.
+  """
+
+  # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
+  # exp. resolved head's probability.
+  if residue_weights is None:
+    residue_weights = np.ones(logits.shape[0])
+
+  bin_centers = _calculate_bin_centers(breaks)
+
+  num_res = int(np.sum(residue_weights))
+  # Clip num_res to avoid negative/undefined d0.
+  clipped_num_res = max(num_res, 19)
+
+  # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
+  # "Scoring function for automated assessment of protein structure template
+  # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
+  d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
+
+  # Convert logits to probs.
+  probs = scipy.special.softmax(logits, axis=-1)
+
+  # TM-Score term for every bin.
+  tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
+  # E_distances tm(distance).
+  predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)
+
+  pair_mask = np.ones(shape=(num_res, num_res), dtype=bool)
+  if interface:
+    pair_mask *= asym_id[:, None] != asym_id[None, :]
+
+  predicted_tm_term *= pair_mask
+
+  pair_residue_weights = pair_mask * (
+      residue_weights[None, :] * residue_weights[:, None])
+  normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
+      pair_residue_weights, axis=-1, keepdims=True))
+  per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
+  return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/protein.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/protein.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,278 @@
[new file: the upstream alphafold/common/protein.py (278 lines), defining the frozen `Protein` dataclass (atom_positions, aatype, atom_mask, residue_index, chain_index, b_factors; at most PDB_MAX_CHAINS = 62 chains) and the helpers `from_pdb_string(pdb_str, chain_id=None)`, `to_pdb(prot)` (every output line padded to 80 columns), `ideal_atom_mask(prot)` and `from_prediction(features, result, b_factors=None, remove_leading_feature_dimension=True)` - rendered in this view as a truncated byte string]
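A minimal round-trip using the module's `from_pdb_string` and `to_pdb` (signatures as in the upstream file summarized above; the input file path here is illustrative):

```python
"""Round-trip sketch with alphafold.common.protein: parse a single-model PDB
string into a Protein and write it back out. File path is illustrative."""
from alphafold.common import protein

with open('2rbg.pdb') as fh:          # any single-model PDB file
    prot = protein.from_pdb_string(fh.read(), chain_id='A')

print(prot.aatype.shape)              # one entry per parsed residue
pdb_out = protein.to_pdb(prot)        # lines padded to 80 columns
print(pdb_out.splitlines()[0])
```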
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/protein_test.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/protein_test.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,114 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for protein."""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from alphafold.common import protein
+from alphafold.common import residue_constants
+import numpy as np
+# Internal import (7716).
+
+TEST_DATA_DIR = 'alphafold/common/testdata/'
+
+
+class ProteinTest(parameterized.TestCase):
+
+  def _check_shapes(self, prot, num_res):
+    """Check that the processed shapes are correct."""
+    num_atoms = residue_constants.atom_type_num
+    self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape)
+    self.assertEqual((num_res,), prot.aatype.shape)
+    self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
+    self.assertEqual((num_res,), prot.residue_index.shape)
+    self.assertEqual((num_res,), prot.chain_index.shape)
+    self.assertEqual((num_res, num_atoms), prot.b_factors.shape)
+
+  @parameterized.named_parameters(
+      dict(testcase_name='chain_A',
+           pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1),
+      dict(testcase_name='chain_B',
+           pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1),
+      dict(testcase_name='multichain',
+           pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2))
+  def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains):
+    pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
+                            pdb_file)
+    with open(pdb_file) as f:
+      pdb_string = f.read()
+    prot = protein.from_pdb_string(pdb_string, chain_id)
+    self._check_shapes(prot, num_res)
+    self.assertGreaterEqual(prot.aatype.min(), 0)
+    # Allow equal since unknown restypes have index equal to restype_num.
+    self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
+    self.assertLen(np.unique(prot.chain_index), num_chains)
+
+  def test_to_pdb(self):
+    with open(
+        os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
+                     '2rbg.pdb')) as f:
+      pdb_string = f.read()
+    prot = protein.from_pdb_string(pdb_string)
+    pdb_string_reconstr = protein.to_pdb(prot)
+
+    for line in pdb_string_reconstr.splitlines():
+      self.assertLen(line, 80)
+
+    prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)
+
+    np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
+    np.testing.assert_array_almost_equal(
+        prot_reconstr.atom_positions, prot.atom_positions)
+    np.testing.assert_array_almost_equal(
+        prot_reconstr.atom_mask, prot.atom_mask)
+    np.testing.assert_array_equal(
+        prot_reconstr.residue_index, prot.residue_index)
+    np.testing.assert_array_equal(
+        prot_reconstr.chain_index, prot.chain_index)
+    np.testing.assert_array_almost_equal(
+        prot_reconstr.b_factors, prot.b_factors)
+
+  def test_ideal_atom_mask(self):
+    with open(
+        os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
+                     '2rbg.pdb')) as f:
+      pdb_string = f.read()
+    prot = protein.from_pdb_string(pdb_string)
+    ideal_mask = protein.ideal_atom_mask(prot)
+    non_ideal_residues = set([102] + list(range(127, 286)))
+    for i, (res, atom_mask) in enumerate(
+        zip(prot.residue_index, prot.atom_mask)):
+      if res in non_ideal_residues:
+        self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
+      else:
+        self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
+
+  def test_too_many_chains(self):
+    num_res = protein.PDB_MAX_CHAINS + 1
+    num_atom_type = residue_constants.atom_type_num
+    with self.assertRaises(ValueError):
+      _ = protein.Protein(
+          atom_positions=np.random.random([num_res, num_atom_type, 3]),
+          aatype=np.random.randint(0, 21, [num_res]),
+          atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32),
+          residue_index=np.arange(1, num_res+1),
+          chain_index=np.arange(num_res),
+          b_factors=np.random.uniform(1, 100, [num_res]))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/residue_constants.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/residue_constants.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,897 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Constants used in AlphaFold."""
+
+import collections
+import functools
+import os
+from typing import List, Mapping, Tuple
+
+import numpy as np
+import tree
+
+# Internal import (35fd).
+
+
+# Distance from one CA to next CA [trans configuration: omega = 180].
+ca_ca = 3.80209737096
+
+# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# chi angles so their chi angle lists are empty.
+chi_angles_atoms = {
+    'ALA': [],
+    # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
+    'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+            ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
+    'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+    'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
+    'CYS': [['N', 'CA', 'CB', 'SG']],
+    'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+            ['CB', 'CG', 'CD', 'OE1']],
+    'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+            ['CB', 'CG', 'CD', 'OE1']],
+    'GLY': [],
+    'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
+    'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
+    'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+    'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
+            ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
+    'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
+            ['CB', 'CG', 'SD', 'CE']],
+    'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+    'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
+    'SER': [['N', 'CA', 'CB', 'OG']],
+    'THR': [['N', 'CA', 'CB', 'OG1']],
+    'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+    'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
+    'VAL': [['N', 'CA', 'CB', 'CG1']],
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The order is as per restype_order (see below).
+chi_angles_mask = [
+    [0.0, 0.0, 0.0, 0.0],  # ALA
+    [1.0, 1.0, 1.0, 1.0],  # ARG
+    [1.0, 1.0, 0.0, 0.0],  # ASN
+    [1.0, 1.0, 0.0, 0.0],  # ASP
+    [1.0, 0.0, 0.0, 0.0],  # CYS
+    [1.0, 1.0, 1.0, 0.0],  # GLN
+    [1.0, 1.0, 1.0, 0.0],  # GLU
+    [0.0, 0.0, 0.0, 0.0],  # GLY
+    [1.0, 1.0, 0.0, 0.0],  # HIS
+    [1.0, 1.0, 0.0, 0.0],  # ILE
+    [1.0, 1.0, 0.0, 0.0],  # LEU
+    [1.0, 1.0, 1.0, 1.0],  # LYS
+    [1.0, 1.0, 1.0, 0.0],  # MET
+    [1.0, 1.0, 0.0, 0.0],  # PHE
+    [1.0, 1.0, 0.0, 0.0],  # PRO
+    [1.0, 0.0, 0.0, 0.0],  # SER
+    [1.0, 0.0, 0.0, 0.0],  # THR
+    [1.0, 1.0, 0.0, 0.0],  # TRP
+    [1.0, 1.0, 0.0, 0.0],  # TYR
+    [1.0, 0.0, 0.0, 0.0],  # VAL
+]
+
+# The following chi angles are pi periodic: they can be rotated by a multiple
+# of pi without affecting the structure.
+chi_pi_periodic = [
+    [0.0, 0.0, 0.0, 0.0],  # ALA
+    [0.0, 0.0, 0.0, 0.0],  # ARG
+    [0.0, 0.0, 0.0, 0.0],  # ASN
+    [0.0, 1.0, 0.0, 0.0],  # ASP
+    [0.0, 0.0, 0.0, 0.0],  # CYS
+    [0.0, 0.0, 0.0, 0.0],  # GLN
+    [0.0, 0.0, 1.0, 0.0],  # GLU
+    [0.0, 0.0, 0.0, 0.0],  # GLY
+    [0.0, 0.0, 0.0, 0.0],  # HIS
+    [0.0, 0.0, 0.0, 0.0],  # ILE
+    [0.0, 0.0, 0.0, 0.0],  # LEU
+    [0.0, 0.0, 0.0, 0.0],  # LYS
+    [0.0, 0.0, 0.0
[... several hundred lines of this diff are truncated in this view; it resumes mid-line inside _make_rigid_group_constants() ...]
+ :, :] = mat
+
+    # psi-frame to backbone
+    mat = _make_rigid_transformation_4x4(
+        ex=atom_positions['C'] - atom_positions['CA'],
+        ey=atom_positions['CA'] - atom_positions['N'],
+        translation=atom_positions['C'])
+    restype_rigid_group_default_frame[restype, 3, :, :] = mat
+
+    # chi1-frame to backbone
+    if chi_angles_mask[restype][0]:
+      base_atom_names = chi_angles_atoms[resname][0]
+      base_atom_positions = [atom_positions[name] for name in base_atom_names]
+      mat = _make_rigid_transformation_4x4(
+          ex=base_atom_positions[2] - base_atom_positions[1],
+          ey=base_atom_positions[0] - base_atom_positions[1],
+          translation=base_atom_positions[2])
+      restype_rigid_group_default_frame[restype, 4, :, :] = mat
+
+    # chi2-frame to chi1-frame
+    # chi3-frame to chi2-frame
+    # chi4-frame to chi3-frame
+    # luckily all rotation axes for the next frame start at (0,0,0) of the
+    # previous frame
+    for chi_idx in range(1, 4):
+      if chi_angles_mask[restype][chi_idx]:
+        axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
+        axis_end_atom_position = atom_positions[axis_end_atom_name]
+        mat = _make_rigid_transformation_4x4(
+            ex=axis_end_atom_position,
+            ey=np.array([-1., 0., 0.]),
+            translation=axis_end_atom_position)
+        restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
+
+
+_make_rigid_group_constants()
+
+
+def make_atom14_dists_bounds(overlap_tolerance=1.5,
+                             bond_length_tolerance_factor=15):
+  """compute upper and lower bounds for bonds to assess violations."""
+  restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
+  restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
+  restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
+  residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
+  for restype, restype_letter in enumerate(restypes):
+    resname = restype_1to3[restype_letter]
+    atom_list = restype_name_to_atom14_names[resname]
+
+    # create lower and upper bounds for clashes
+    for atom1_idx, atom1_name in enumerate(atom_list):
+      if not atom1_name:
+        continue
+      atom1_radius = van_der_waals_radius[atom1_name[0]]
+      for atom2_idx, atom2_name in enumerate(atom_list):
+        if (not atom2_name) or atom1_idx == atom2_idx:
+          continue
+        atom2_radius = van_der_waals_radius[atom2_name[0]]
+        lower = atom1_radius + atom2_radius - overlap_tolerance
+        upper = 1e10
+        restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
+        restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
+        restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
+        restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
+
+    # overwrite lower and upper bounds for bonds and angles
+    for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
+      atom1_idx = atom_list.index(b.atom1_name)
+      atom2_idx = atom_list.index(b.atom2_name)
+      lower = b.length - bond_length_tolerance_factor * b.stddev
+      upper = b.length + bond_length_tolerance_factor * b.stddev
+      restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
+      restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
+      restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
+      restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
+      restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
+      restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
+  return {'lower_bound': restype_atom14_bond_lower_bound,  # shape (21,14,14)
+          'upper_bound': restype_atom14_bond_upper_bound,  # shape (21,14,14)
+          'stddev': restype_atom14_bond_stddev,  # shape (21,14,14)
+         }
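As a quick illustration of how these constant tables are consumed downstream (a sketch, not part of the changeset; the index values follow restype_order as defined in the file above):

    import numpy as np
    from alphafold.common import residue_constants

    # Per-residue heavy-atom mask for a 3-residue sequence (ALA, GLY, CYS).
    aatype = np.array([0, 7, 4])
    mask = residue_constants.STANDARD_ATOM_MASK[aatype]  # shape (3, 37)

    # Distance bounds used to flag steric violations, e.g. N-CA of ALA
    # (restype 0, atom14 slots 0 and 1).
    bounds = residue_constants.make_atom14_dists_bounds()
    lower = bounds['lower_bound'][0, 0, 1]
    upper = bounds['upper_bound'][0, 0, 1]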
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/residue_constants_test.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/residue_constants_test.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,190 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test that residue_constants generates correct values."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from alphafold.common import residue_constants
+import numpy as np
+
+
+class ResidueConstantsTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      ('ALA', 0),
+      ('CYS', 1),
+      ('HIS', 2),
+      ('MET', 3),
+      ('LYS', 4),
+      ('ARG', 4),
+  )
+  def testChiAnglesAtoms(self, residue_name, chi_num):
+    chi_angles_atoms = residue_constants.chi_angles_atoms[residue_name]
+    self.assertLen(chi_angles_atoms, chi_num)
+    for chi_angle_atoms in chi_angles_atoms:
+      self.assertLen(chi_angle_atoms, 4)
+
+  def testChiGroupsForAtom(self):
+    for k, chi_groups in residue_constants.chi_groups_for_atom.items():
+      res_name, atom_name = k
+      for chi_group_i, atom_i in chi_groups:
+        self.assertEqual(
+            atom_name,
+            residue_constants.chi_angles_atoms[res_name][chi_group_i][atom_i])
+
+  @parameterized.parameters(
+      ('ALA', 5), ('ARG', 11), ('ASN', 8), ('ASP', 8), ('CYS', 6), ('GLN', 9),
+      ('GLU', 9), ('GLY', 4), ('HIS', 10), ('ILE', 8), ('LEU', 8), ('LYS', 9),
+      ('MET', 8), ('PHE', 11), ('PRO', 7), ('SER', 6), ('THR', 7), ('TRP', 14),
+      ('TYR', 12), ('VAL', 7)
+  )
+  def testResidueAtoms(self, atom_name, num_residue_atoms):
+    residue_atoms = residue_constants.residue_atoms[atom_name]
+    self.assertLen(residue_atoms, num_residue_atoms)
+
+  def testStandardAtomMask(self):
+    with self.subTest('Check shape'):
+      self.assertEqual(residue_constants.STANDARD_ATOM_MASK.shape, (21, 37,))
+
+    with self.subTest('Check values'):
+      str_to_row = lambda s: [c == '1' for c in s]  # More clear/concise.
+      np.testing.assert_array_equal(
+          residue_constants.STANDARD_ATOM_MASK,
+          np.array([
+              # NB This was defined by c+p but looks sane.
[... the whitespace padding inside the 37-column mask strings was collapsed in this view; the '1' positions below are as shown in the source ...]
+              str_to_row('11111 '),  # ALA
+              str_to_row('111111 1 1 11 1 '),  # ARG
+              str_to_row('111111 11 '),  # ASP
+              str_to_row('111111 11 '),  # ASN
+              str_to_row('11111 1 '),  # CYS
+              str_to_row('111111 1 11 '),  # GLU
+              str_to_row('111111 1 11 '),  # GLN
+              str_to_row('111 1 '),  # GLY
+              str_to_row('111111 11 1 1 '),  # HIS
+              str_to_row('11111 11 1 '),  # ILE
+              str_to_row('111111 11 '),  # LEU
+              str_to_row('111111 1 1 1 '),  # LYS
+              str_to_row('111111 11 '),  # MET
+              str_to_row('111111 11 11 1 '),  # PHE
+              str_to_row('111111 1 '),  # PRO
+              str_to_row('11111 1 '),  # SER
+              str_to_row('11111 1 1 '),  # THR
+              str_to_row('111111 11 11 1 1 11 '),  # TRP
+              str_to_row('111111 11 11 '),  # TYR
+              str_to_
[... diff truncated in this view; it resumes inside a later test method ...]
+    three_letter_restypes = [
+        residue_constants.restype_1to3[r] for r in residue_constants.restypes]
+    for restype, exp_restype in zip(
+        three_letter_restypes, sorted(residue_constants.restype_1to3.values())):
+      self.assertEqual(restype, exp_restype)
+    self.assertEqual(residue_constants.restype_num, 20)
+
+  def testSequenceToOneHotHHBlits(self):
+    one_hot = residue_constants.sequence_to_onehot(
+        'ABCDEFGHIJKLMNOPQRSTUVWXYZ-', residue_constants.HHBLITS_AA_TO_ID)
+    exp_one_hot = np.array(
+        [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
+    np.testing.assert_array_equal(one_hot, exp_one_hot)
+
+  def testSequenceToOneHotStandard(self):
+    one_hot = residue_constants.sequence_to_onehot(
+        'ARNDCQEGHILKMFPSTWYV', residue_constants.restype_order)
+    np.testing.assert_array_equal(one_hot, np.eye(20))
+
+  def testSequenceToOneHotUnknownMapping(self):
+    seq = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+    expected_out = np.zeros([26, 21])
+    for row, position in enumerate(
+        [0, 20, 4, 3, 6, 13, 7, 8, 9, 20, 11, 10, 12, 2, 20, 14, 5, 1, 15, 16,
+         20, 19, 17, 20, 18, 20]):
+      expected_out[row, position] = 1
+    aa_types = residue_constants.sequence_to_onehot(
+        sequence=seq,
+        mapping=residue_constants.restype_order_with_x,
+        map_unknown_to_x=True)
+    self.assertTrue((aa_types == expected_out).all())
+
+  @parameterized.named_parameters(
+      ('lowercase', 'aaa'),  # Insertions in A3M.
+      ('gaps', '---'),  # Gaps in A3M.
+      ('dots', '...'),  # Gaps in A3M.
+      ('metadata', '>TEST'),  # FASTA metadata line.
+  )
+  def testSequenceToOneHotUnknownMappingError(self, seq):
+    with self.assertRaises(ValueError):
+      residue_constants.sequence_to_onehot(
+          sequence=seq,
+          mapping=residue_constants.restype_order_with_x,
+          map_unknown_to_x=True)
+
+
+if __name__ == '__main__':
+  absltest.main()
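The one-hot helper exercised by these tests can be used directly; a small sketch whose expected behaviour follows the assertions above:

    from alphafold.common import residue_constants

    one_hot = residue_constants.sequence_to_onehot(
        'ARNDCQEGHILKMFPSTWYV', residue_constants.restype_order)
    assert one_hot.shape == (20, 20)  # Identity, per testSequenceToOneHotStandard.

    # Unknown letters map to the X column when map_unknown_to_x is set.
    with_x = residue_constants.sequence_to_onehot(
        'AB', residue_constants.restype_order_with_x, map_unknown_to_x=True)
    assert with_x[1][20] == 1  # 'B' is mapped to X (index 20).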
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/common/testdata/2rbg.pdb
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/common/testdata/2rbg.pdb  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,2784 @@
+HEADER    STRUCTURAL GENOMICS, UNKNOWN FUNCTION   19-SEP-07   2RBG
+TITLE     CRYSTAL STRUCTURE OF HYPOTHETICAL PROTEIN(ST0493) FROM
+TITLE    2 SULFOLOBUS TOKODAII
+COMPND    MOL_ID: 1;
+COMPND   2 MOLECULE: PUTATIVE UNCHARACTERIZED PROTEIN ST0493;
+COMPND   3 CHAIN: A, B;
+COMPND   4 ENGINEERED: YES
+SOURCE    MOL_ID: 1;
+SOURCE   2 ORGANISM_SCIENTIFIC: SULFOLOBUS TOKODAII;
+SOURCE   3 ORGANISM_TAXID: 111955;
+SOURCE   4 STRAIN: STRAIN 7;
+SOURCE   5 EXPRESSION_SYSTEM: ESCHERICHIA COLI;
+SOURCE   6 EXPRESSION_SYSTEM_TAXID: 562;
+SOURCE   7 EXPRESSION_SYSTEM_STRAIN: ROSETTA834(DE3);
+SOURCE   8 EXPRESSION_SYSTEM_VECTOR_TYPE: PLASMID;
+SOURCE   9 EXPRESSION_SYSTEM_PLASMID: PET-21A
+KEYWDS    HYPOTHETICAL PROTEIN, STRUCTURAL GENOMICS, UNKNOWN FUNCTION,
+KEYWDS   2 NPPSFA, NATIONAL PROJECT ON PROTEIN STRUCTURAL AND
+KEYWDS   3 FUNCTIONAL ANALYSES, RIKEN STRUCTURAL GENOMICS/PROTEOMICS
+KEYWDS   4 INITIATIVE, RSGI
+EXPDTA    X-RAY DIFFRACTION
+AUTHOR    J.JEYAKANTHAN,S.KURAMITSU,S.YOKOYAMA,RIKEN STRUCTURAL
+AUTHOR   2 GENOMICS/PROTEOMICS INITIATIVE (RSGI)
+REVDAT   2 24-FEB-09 2RBG    1       VERSN
+REVDAT   1 30-SEP-08 2RBG    0
+JRNL        AUTH   J.JEYAKANTHAN,S.KURAMITSU,S.YOKOYAMA
+JRNL        TITL   CRYSTAL STRUCTURE OF HYPOTHETICAL PROTEIN(ST0493)
+JRNL        TITL 2 FROM SULFOLOBUS TOKODAII
+JRNL        REF    TO BE PUBLISHED
+JRNL        REFN
+REMARK   1
+REMARK   2
+REMARK   2 RESOLUTION. 1.75 ANGSTROMS.
+REMARK   3
+REMARK   3 REFINEMENT.
+REMARK   3   PROGRAM     : CNS 1.1
+REMARK   3   AUTHORS     : BRUNGER,ADAMS,CLORE,DELANO,GROS,GROSSE-
+REMARK   3               : KUNSTLEVE,JIANG,KUSZEWSKI,NILGES, PANNU,
+REMARK   3               : READ,RICE,SIMONSON,WARREN
+REMARK   3
+REMARK   3  REFINEMENT TARGET : ENGH & HUBER
+REMARK   3
+REMARK   3  DATA USED IN REFINEMENT.
+REMARK   3   RESOLUTION RANGE HIGH (ANGSTROMS) : 1.75
+REMARK   3   RESOLUTION RANGE LOW  (ANGSTROMS) : 33.49
+REMARK   3   DATA CUTOFF            (SIGMA(F)) : 0.000
+REMARK   3   DATA CUTOFF HIGH         (ABS(F)) : 2067291.840
+REMARK   3   DATA CUTOFF LOW          (ABS(F)) : 0.0000
+REMARK   3   COMPLETENESS (WORKING+TEST)
[... roughly 2,600 lines of this test-data PDB (remaining remarks and all ATOM records) are truncated in this view; it resumes mid-record below ...]
+ OH B 268      33.268 -11.967  34.839  1.00 44.52           O
+HETATM 2367  O   HOH B 269      21.291   7.640  55.382  1.00 37.07           O
+HETATM 2368  O   HOH B 270      40.543  -6.191  35.086  1.00 46.78           O
+HETATM 2369  O   HOH B 271      36.278   8.494  43.716  1.00 39.94           O
+HETATM 2370  O   HOH B 272      38.077   0.885  44.425  1.00 37.70           O
+HETATM 2371  O   HOH B 273      36.624   2.995  44.072  1.00 44.84           O
+HETATM 2372  O   HOH B 274      47.680  -3.802  54.241  1.00 29.52           O
+HETATM 2373  O   HOH B 275      47.542 -25.183  47.426  1.00 44.28           O
+HETATM 2374  O   HOH B 276      47.958  -0.641  51.434  1.00 41.18           O
+HETATM 2375  O   HOH B 277      48.773  -1.142  45.731  1.00 47.19           O
+HETATM 2376  O   HOH B 278      52.432  -3.449  47.286  1.00 34.07           O
+HETATM 2377  O   HOH B 279      22.927 -20.727  46.764  1.00 43.74           O
+HETATM 2378  O   HOH B 280      19.895 -12.192  66.540  1.00 37.79           O
+HETATM 2379  O   HOH B 281      41.198  10.198  58.267  1.00 48.98           O
+HETATM 2380  O   HOH B 282      44.205  11.703  55.646  1.00 52.92           O
+HETATM 2381  O   HOH B 283      42.359   7.497  60.196  1.00 46.88           O
+HETATM 2382  O   HOH B 284      43.862 -18.935  38.363  1.00 32.12           O
+HETATM 2383  O   HOH B 285      44.692 -12.023  39.188  1.00 33.96           O
+CONECT  769  996
+CONECT  821  830
+CONECT  830  821  831
+CONECT  831  830  832  834
+CONECT  832  831  833  838
+CONECT  833  832
+CONECT  834  831  835
+CONECT  835  834  836
+CONECT  836  835  837
+CONECT  837  836
+CONECT  838  832
+CONECT  996  769
+CONECT 1800 2027
+CONECT 1852 1861
+CONECT 1861 1852 1862
+CONECT 1862 1861 1863 1865
+CONECT 1863 1862 1864 1869
+CONECT 1864 1863
+CONECT 1865 1862 1866
+CONECT 1866 1865 1867
+CONECT 1867 1866 1868
+CONECT 1868 1867
+CONECT 1869 1863
+CONECT 2027 1800
+CONECT 2063 2064 2065 2066 2067
+CONECT 2064 2063
+CONECT 2065 2063
+CONECT 2066 2063
+CONECT 2067 2063
+MASTER      266    0    3   17   10    0    2    6 2381    2   29   20
+END
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/__init__.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/__init__.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,14 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data pipeline for model features."""
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/feature_processing.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/feature_processing.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,231 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Feature processing logic for multimer data pipeline."""
+
+from typing import Iterable, MutableMapping, List
+
+from alphafold.common import residue_constants
+from alphafold.data import msa_pairing
+from alphafold.data import pipeline
+import numpy as np
+
+REQUIRED_FEATURES = frozenset({
+    'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
+    'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
+    'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
+    'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
+    'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
+    'num_templates', 'queue_size', 'residue_index', 'resolution',
+    'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
+    'template_all_atom_mask', 'template_all_atom_positions'
+})
+
+MAX_TEMPLATES = 4
+MSA_CROP_SIZE = 2048
+
+
+def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool:
+  """Checks if a list of chains represents a homomer/monomer example."""
+  # Note that an entity_id of 0 indicates padding.
+  num_unique_chains = len(np.unique(np.concatenate(
+      [np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
+       chain in chains])))
+  return num_unique_chains == 1
+
+
+def pair_and_merge(
+    all_chain_features: MutableMapping[str, pipeline.FeatureDict],
+    is_prokaryote: bool) -> pipeline.FeatureDict:
+  """Runs processing on features to augment, pair and merge.
+
+  Args:
+    all_chain_features: A MutableMap of dictionaries of features for each chain.
+    is_prokaryote: Whether the target complex is from a prokaryotic or
+      eukaryotic organism.
+
+  Returns:
+    A dictionary of features.
+  """
+
+  process_unmerged_features(all_chain_features)
+
+  np_chains_list = list(all_chain_features.values())
+
+  pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
+
+  if pair_msa_sequences:
+    np_chains_list = msa_pairing.create_paired_features(
+        chains=np_chains_list, prokaryotic=is_prokaryote)
+    np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
+  np_chains_list = crop_chains(
+      np_chains_list,
+      msa_crop_size=MSA_CROP_SIZE,
+      pair_msa_sequences=pair_msa_sequences,
+      max_templates=MAX_TEMPLATES)
+  np_example = msa_pairing.merge_chain_features(
+      np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
+      max_templates=MAX_TEMPLATES)
+  np_example = process_final(np_example)
+  return np_example
+
+
+def crop_chains(
+    chains_list: List[pipeline.FeatureDict],
+    msa_crop_size: int,
+    pair_msa_sequences: bool,
+    max_templates: int) -> List[pipeline.FeatureDict]:
+  """Crops the MSAs for a set of chains.
+
+  Args:
+    chains_list: A list of chains to be cropped.
+    msa_crop_size: The total number of sequences to crop from the MSA.
+    pair_msa_sequences: Whether we are operating in sequence-pairing mode.
+    max_templates: The maximum templates to use per chain.
+
+  Returns:
+    The chains cropped.
+  """
+
+  # Apply the cropping.
+  cropped_chains = []
+  for chain in chains_list:
+    cropped_chain = _crop_single_chain(
+        chain,
+        msa_crop_size=msa_crop_size,
+        pair_msa_sequences=pair_msa_sequences,
+        max_templates=max_templates)
+    cropped_chains.append(c
[... diff truncated in this view; it resumes mid-line inside _crop_single_chain() ...]
+ msa_crop_size_all_seq)
+
+    # Restrict the unpaired crop size so that paired+unpaired sequences do not
+    # exceed msa_seqs_per_chain for each chain.
+    max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
+    msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
+  else:
+    msa_crop_size = np.minimum(msa_size, msa_crop_size)
+
+  include_templates = 'template_aatype' in chain and max_templates
+  if include_templates:
+    num_templates = chain['template_aatype'].shape[0]
+    templates_crop_size = np.minimum(num_templates, max_templates)
+
+  for k in chain:
+    k_split = k.split('_all_seq')[0]
+    if k_split in msa_pairing.TEMPLATE_FEATURES:
+      chain[k] = chain[k][:templates_crop_size, :]
+    elif k_split in msa_pairing.MSA_FEATURES:
+      if '_all_seq' in k and pair_msa_sequences:
+        chain[k] = chain[k][:msa_crop_size_all_seq, :]
+      else:
+        chain[k] = chain[k][:msa_crop_size, :]
+
+  chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
+  if include_templates:
+    chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
+  if pair_msa_sequences:
+    chain['num_alignments_all_seq'] = np.asarray(
+        msa_crop_size_all_seq, dtype=np.int32)
+  return chain
+
+
+def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
+  """Final processing steps in data pipeline, after merging and pairing."""
+  np_example = _correct_msa_restypes(np_example)
+  np_example = _make_seq_mask(np_example)
+  np_example = _make_msa_mask(np_example)
+  np_example = _filter_features(np_example)
+  return np_example
+
+
+def _correct_msa_restypes(np_example):
+  """Correct MSA restype to have the same order as residue_constants."""
+  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+  np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
+  np_example['msa'] = np_example['msa'].astype(np.int32)
+  return np_example
+
+
+def _make_seq_mask(np_example):
+  np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
+  return np_example
+
+
+def _make_msa_mask(np_example):
+  """Mask features are all ones, but will later be zero-padded."""
+
+  np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)
+
+  seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
+  np_example['msa_mask'] *= seq_mask[None]
+
+  return np_example
+
+
+def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict:
+  """Filters features of example to only those requested."""
+  return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
+
+
+def process_unmerged_features(
+    all_chain_features: MutableMapping[str, pipeline.FeatureDict]):
+  """Postprocessing stage for per-chain features before merging."""
+  num_chains = len(all_chain_features)
+  for chain_features in all_chain_features.values():
+    # Convert deletion matrices to float.
+    chain_features['deletion_matrix'] = np.asarray(
+        chain_features.pop('deletion_matrix_int'), dtype=np.float32)
+    if 'deletion_matrix_int_all_seq' in chain_features:
+      chain_features['deletion_matrix_all_seq'] = np.asarray(
+          chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)
+
+    chain_features['deletion_mean'] = np.mean(
+        chain_features['deletion_matrix'], axis=0)
+
+    # Add all_atom_mask and dummy all_atom_positions based on aatype.
+    all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
+        chain_features['aatype']]
+    chain_features['all_atom_mask'] = all_atom_mask
+    chain_features['all_atom_positions'] = np.zeros(
+        list(all_atom_mask.shape) + [3])
+
+    # Add assembly_num_chains.
+    chain_features['assembly_num_chains'] = np.asarray(num_chains)
+
+  # Add entity_mask.
+  for chain_features in all_chain_features.values():
+    chain_features['entity_mask'] = (
+        chain_features['entity_id'] != 0).astype(np.int32)
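A compact sketch of the homomer check that drives the pairing decision above. The arrays are illustrative only, and `_is_homomer_or_monomer` is a private helper, so this is just to make the logic concrete:

    import numpy as np
    from alphafold.data import feature_processing

    chain_a = {'entity_id': np.array([1, 1, 1])}
    chain_b = {'entity_id': np.array([1, 1, 1])}
    # Same entity everywhere -> homomer, so MSA pairing is skipped.
    assert feature_processing._is_homomer_or_monomer([chain_a, chain_b])

    chain_c = {'entity_id': np.array([2, 2, 2])}
    # Two distinct entities -> heteromer, so pair_and_merge will pair MSAs.
    assert not feature_processing._is_homomer_or_monomer([chain_a, chain_c])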
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/mmcif_parsing.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/mmcif_parsing.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,386 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Parses the mmCIF file format."""
+import collections
+import dataclasses
+import functools
+import io
+from typing import Any, Mapping, Optional, Sequence, Tuple
+
+from absl import logging
+from Bio import PDB
+from Bio.Data import SCOPData
+
+# Type aliases:
+ChainId = str
+PdbHeader = Mapping[str, Any]
+PdbStructure = PDB.Structure.Structure
+SeqRes = str
+MmCIFDict = Mapping[str, Sequence[str]]
+
+
+@dataclasses.dataclass(frozen=True)
+class Monomer:
+  id: str
+  num: int
+
+
+# Note - mmCIF format provides no guarantees on the type of author-assigned
+# sequence numbers. They need not be integers.
+@dataclasses.dataclass(frozen=True)
+class AtomSite:
+  residue_name: str
+  author_chain_id: str
+  mmcif_chain_id: str
+  author_seq_num: str
+  mmcif_seq_num: int
+  insertion_code: str
+  hetatm_atom: str
+  model_num: int
+
+
+# Used to map SEQRES index to a residue in the structure.
+@dataclasses.dataclass(frozen=True)
+class ResiduePosition:
+  chain_id: str
+  residue_number: int
+  insertion_code: str
+
+
+@dataclasses.dataclass(frozen=True)
+class ResidueAtPosition:
+  position: Optional[ResiduePosition]
+  name: str
+  is_missing: bool
+  hetflag: str
+
+
+@dataclasses.dataclass(frozen=True)
+class MmcifObject:
+  """Representation of a parsed mmCIF file.
+
+  Contains:
+    file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
+      files being processed.
+    header: Biopython header.
+    structure: Biopython structure.
+    chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
+      {'A': 'ABCDEFG'}
+    seqres_to_structure: Dict; for each chain_id contains a mapping between
+      SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
+                                                        1: ResidueAtPosition,
+                                                        ...}}
+    raw_string: The raw string used to construct the MmcifObject.
+  """
+  file_id: str
+  header: PdbHeader
+  structure: PdbStructure
+  chain_to_seqres: Mapping[ChainId, SeqRes]
+  seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
+  raw_string: Any
+
+
+@dataclasses.dataclass(frozen=True)
+class ParsingResult:
+  """Returned by the parse function.
+
+  Contains:
+    mmcif_object: A MmcifObject, may be None if no chain could be successfully
+      parsed.
+    errors: A dict mapping (file_id, chain_id) to any exception generated.
+  """
+  mmcif_object: Optional[MmcifObject]
+  errors: Mapping[Tuple[str, str], Any]
+
+
+class ParseError(Exception):
+  """An error indicating that an mmCIF file could not be parsed."""
+
+
+def mmcif_loop_to_list(prefix: str,
+                       parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
+  """Extracts loop associated with a prefix from mmCIF data as a list.
+
+  Reference for loop_ in mmCIF:
+    http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
+
+  Args:
+    prefix: Prefix shared by each of the data items in the loop.
+      e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
+      _entity_poly_seq.mon_id. Should include the trailing period.
+    parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
+      parser.
+
+  Returns:
+    Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
+  """
+  cols =
[... diff truncated in this view; it resumes below ...]
+
+
+def get_release_date(parsed_info: MmCIFDict) -> str:
+  """Returns the oldest revision date."""
+  revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
+  return min(revision_dates)
+
+
+def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
+  """Returns a basic header containing method, release date and resolution."""
+  header = {}
+
+  experiments = mmcif_loop_to_list('_exptl.', parsed_info)
+  header['structure_method'] = ','.join([
+      experiment['_exptl.method'].lower() for experiment in experiments])
+
+  # Note: The release_date here corresponds to the oldest revision. We prefer to
+  # use this for dataset filtering over the deposition_date.
+  if '_pdbx_audit_revision_history.revision_date' in parsed_info:
+    header['release_date'] = get_release_date(parsed_info)
+  else:
+    logging.warning('Could not determine release_date: %s',
+                    parsed_info['_entry.id'])
+
+  header['resolution'] = 0.00
+  for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution',
+                  '_reflns.d_resolution_high'):
+    if res_key in parsed_info:
+      try:
+        raw_resolution = parsed_info[res_key][0]
+        header['resolution'] = float(raw_resolution)
+      except ValueError:
+        logging.debug('Invalid resolution format: %s', parsed_info[res_key])
+
+  return header
+
+
+def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
+  """Returns list of atom sites; contains data not present in the structure."""
+  return [AtomSite(*site) for site in zip(  # pylint:disable=g-complex-comprehension
+      parsed_info['_atom_site.label_comp_id'],
+      parsed_info['_atom_site.auth_asym_id'],
+      parsed_info['_atom_site.label_asym_id'],
+      parsed_info['_atom_site.auth_seq_id'],
+      parsed_info['_atom_site.label_seq_id'],
+      parsed_info['_atom_site.pdbx_PDB_ins_code'],
+      parsed_info['_atom_site.group_PDB'],
+      parsed_info['_atom_site.pdbx_PDB_model_num'],
+  )]
+
+
+def _get_protein_chains(
+    *, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
+  """Extracts polymer information for protein chains only.
+
+  Args:
+    parsed_info: _mmcif_dict produced by the Biopython parser.
+
+  Returns:
+    A dict mapping mmcif chain id to a list of Monomers.
+  """
+  # Get polymer information for each entity in the structure.
+  entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
+
+  polymers = collections.defaultdict(list)
+  for entity_poly_seq in entity_poly_seqs:
+    polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
+        Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
+                num=int(entity_poly_seq['_entity_poly_seq.num'])))
+
+  # Get chemical compositions. Will allow us to identify which of these polymers
+  # are proteins.
+  chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
+
+  # Get chains information for each entity. Necessary so that we can return a
+  # dict keyed on chain id rather than entity.
+  struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
+
+  entity_to_mmcif_chains = collections.defaultdict(list)
+  for struct_asym in struct_asyms:
+    chain_id = struct_asym['_struct_asym.id']
+    entity_id = struct_asym['_struct_asym.entity_id']
+    entity_to_mmcif_chains[entity_id].append(chain_id)
+
+  # Identify and return the valid protein chains.
+  valid_chains = {}
+  for entity_id, seq_info in polymers.items():
+    chain_ids = entity_to_mmcif_chains[entity_id]
+
+    # Reject polymers without any peptide-like components, such as DNA/RNA.
+    if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
+            for monomer in seq_info]):
+      for chain_id in chain_ids:
+        valid_chains[chain_id] = seq_info
+  return valid_chains
+
+
+def _is_set(data: str) -> bool:
+  """Returns False if data is a special mmCIF character indicating 'unset'."""
+  return data not in ('.', '?')
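The loop-extraction helper lends itself to a short example. The `parsed_info` fragment below is a hypothetical stand-in for Biopython's `_mmcif_dict`, and the expected output follows the docstring above (the function body itself is truncated in this view):

    from alphafold.data import mmcif_parsing

    parsed_info = {
        '_entity_poly_seq.entity_id': ['1', '1'],
        '_entity_poly_seq.num': ['1', '2'],
        '_entity_poly_seq.mon_id': ['MET', 'ALA'],
    }
    rows = mmcif_parsing.mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
    # rows[0] -> {'_entity_poly_seq.entity_id': '1',
    #             '_entity_poly_seq.num': '1',
    #             '_entity_poly_seq.mon_id': 'MET'}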
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/msa_identifiers.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/msa_identifiers.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,92 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for extracting identifiers from MSA sequence descriptions."""
+
+import dataclasses
+import re
+from typing import Optional
+
+
+# Sequences coming from UniProtKB database come in the
+# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
+# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
+_UNIPROT_PATTERN = re.compile(
+    r"""
+    ^
+    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
+    (?:tr|sp)
+    \|
+    # A primary accession number of the UniProtKB entry.
+    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
+    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
+    (?:_\d)?
+    \|
+    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
+    # protein ID code.
+    (?:[A-Za-z0-9]+)
+    _
+    # A mnemonic species identification code.
+    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
+    # Small BFD uses a final value after an underscore, which we ignore.
+    (?:_\d+)?
+    $
+    """,
+    re.VERBOSE)
+
+
+@dataclasses.dataclass(frozen=True)
+class Identifiers:
+  uniprot_accession_id: str = ''
+  species_id: str = ''
+
+
+def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
+  """Gets accession id and species from an msa sequence identifier.
+
+  The sequence identifier has the format specified by
+  _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
+  An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
+
+  Args:
+    msa_sequence_identifier: a sequence identifier.
+
+  Returns:
+    An `Identifiers` instance with a uniprot_accession_id and species_id. These
+    can be empty in the case where no identifier was found.
+  """
+  matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
+  if matches:
+    return Identifiers(
+        uniprot_accession_id=matches.group('AccessionIdentifier'),
+        species_id=matches.group('SpeciesIdentifier'))
+  return Identifiers()
+
+
+def _extract_sequence_identifier(description: str) -> Optional[str]:
+  """Extracts sequence identifier from description. Returns None if no match."""
+  split_description = description.split()
+  if split_description:
+    return split_description[0].partition('/')[0]
+  else:
+    return None
+
+
+def get_identifiers(description: str) -> Identifiers:
+  """Computes extra MSA features from the description."""
+  sequence_identifier = _extract_sequence_identifier(description)
+  if sequence_identifier is None:
+    return Identifiers()
+  else:
+    return _parse_sequence_identifier(sequence_identifier)
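A short usage sketch of the identifier extraction; the outputs follow directly from the regex and helpers above:

    from alphafold.data import msa_identifiers

    ids = msa_identifiers.get_identifiers(
        'tr|A0A146SKV9|A0A146SKV9_FUNHE/11-87 some free text')
    # ids.uniprot_accession_id == 'A0A146SKV9', ids.species_id == 'FUNHE'

    # Non-UniProtKB descriptions yield empty identifiers rather than errors.
    empty = msa_identifiers.get_identifiers('custom_db|XYZ123')
    # empty.uniprot_accession_id == '' and empty.species_id == ''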
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/msa_pairing.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/msa_pairing.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,638 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pairing logic for multimer data pipeline."""
+
+import collections
+import functools
+import re
+import string
+from typing import Any, Dict, Iterable, List, Sequence
+
+from alphafold.common import residue_constants
+from alphafold.data import pipeline
+import numpy as np
+import pandas as pd
+import scipy.linalg
+
+ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
+ALPHANUM_ACCESSION_ID_MAP = {
+    chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
+}  # A-Z,0-9
+NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)}  # 0-9
+
+MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
+SEQUENCE_GAP_CUTOFF = 0.5
+SEQUENCE_SIMILARITY_CUTOFF = 0.9
+
+MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
+                  'msa_mask_all_seq': 1,
+                  'deletion_matrix_all_seq': 0,
+                  'deletion_matrix_int_all_seq': 0,
+                  'msa': MSA_GAP_IDX,
+                  'msa_mask': 1,
+                  'deletion_matrix': 0,
+                  'deletion_matrix_int': 0}
+
+MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
+SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
+                'all_atom_mask', 'seq_mask', 'between_segment_residues',
+                'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
+                'sym_id', 'entity_mask', 'deletion_mean',
+                'prediction_atom_mask',
+                'literature_positions', 'atom_indices_to_group_indices',
+                'rigid_group_default_frame')
+TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
+                     'template_all_atom_mask')
+CHAIN_FEATURES = ('num_alignments', 'seq_length')
+
+
+domain_name_pattern = re.compile(
+    r'''^(?P<pdb>[a-z\d]{4})
+        \{(?P<bioassembly>[\d+(\+\d+)?])\}
+        (?P<chain>[a-zA-Z\d]+)
+        \{(?P<transform_index>\d+)\}$
+    ''', re.VERBOSE)
+
+
+def create_paired_features(
+    chains: Iterable[pipeline.FeatureDict],
+    prokaryotic: bool,
+    ) -> List[pipeline.FeatureDict]:
+  """Returns the original chains with paired NUM_SEQ features.
+
+  Args:
+    chains: A list of feature dictionaries for each chain.
+    prokaryotic: Whether the target complex is from a prokaryotic organism.
+      Used to determine the distance metric for pairing.
+
+  Returns:
+    A list of feature dictionaries with sequence features including only
+    rows to be paired.
+  """
+  chains = list(chains)
+  chain_keys = chains[0].keys()
+
+  if len(chains) < 2:
+    return chains
+  else:
+    updated_chains = []
+    paired_chains_to_paired_row_indices = pair_sequences(
+        chains, prokaryotic)
+    paired_rows = reorder_paired_rows(
+        paired_chains_to_paired_row_indices)
+
+    for chain_num, chain in enumerate(chains):
+      new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
+      for feature_name in chain_keys:
+        if feature_name.endswith('_all_seq'):
+          feats_padded = pad_features(chain[feature_name], feature_name)
+          new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
+      new_chain['num_alignments_all_seq'] = np.asarray(
+          len(paired_rows[:, chain_num]))
+      updated_chains.append(new_chain)
+    return updated_chains
+
+
+def pad_features(feature: np
[... several hundred lines of this diff are truncated in this view; it resumes mid-line inside _merge_features_from_multiple_chains() ...]
+ 1)
+    elif feature_name_split in CHAIN_FEATURES:
+      merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32)
+    else:
+      merged_example[feature_name] = feats[0]
+  return merged_example
+
+
+def _merge_homomers_dense_msa(
+    chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]:
+  """Merge all identical chains, making the resulting MSA dense.
+
+  Args:
+    chains: An iterable of features for each chain.
+
+  Returns:
+    A list of feature dictionaries.  All features with the same entity_id
+    will be merged - MSA features will be concatenated along the num_res
+    dimension - making them dense.
+  """
+  entity_chains = collections.defaultdict(list)
+  for chain in chains:
+    entity_id = chain['entity_id'][0]
+    entity_chains[entity_id].append(chain)
+
+  grouped_chains = []
+  for entity_id in sorted(entity_chains):
+    chains = entity_chains[entity_id]
+    grouped_chains.append(chains)
+  chains = [
+      _merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
+      for chains in grouped_chains]
+  return chains
+
+
+def _concatenate_paired_and_unpaired_features(
+    example: pipeline.FeatureDict) -> pipeline.FeatureDict:
+  """Merges paired and block-diagonalised features."""
+  features = MSA_FEATURES
+  for feature_name in features:
+    if feature_name in example:
+      feat = example[feature_name]
+      feat_all_seq = example[feature_name + '_all_seq']
+      merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
+      example[feature_name] = merged_feat
+  example['num_alignments'] = np.array(example['msa'].shape[0],
+                                       dtype=np.int32)
+  return example
+
+
+def merge_chain_features(np_chains_list: List[pipeline.FeatureDict],
+                         pair_msa_sequences: bool,
+                         max_templates: int) -> pipeline.FeatureDict:
+  """Merges features for multiple chains to single FeatureDict.
+
+  Args:
+    np_chains_list: List of FeatureDicts for each chain.
+    pair_msa_sequences: Whether to merge paired MSAs.
+    max_templates: The maximum number of templates to include.
+
+  Returns:
+    Single FeatureDict for entire complex.
+  """
+  np_chains_list = _pad_templates(
+      np_chains_list, max_templates=max_templates)
+  np_chains_list = _merge_homomers_dense_msa(np_chains_list)
+  # Unpaired MSA features will be always block-diagonalised; paired MSA
+  # features will be concatenated.
+  np_example = _merge_features_from_multiple_chains(
+      np_chains_list, pair_msa_sequences=False)
+  if pair_msa_sequences:
+    np_example = _concatenate_paired_and_unpaired_features(np_example)
+  np_example = _correct_post_merged_feats(
+      np_example=np_example,
+      np_chains_list=np_chains_list,
+      pair_msa_sequences=pair_msa_sequences)
+
+  return np_example
+
+
+def deduplicate_unpaired_sequences(
+    np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]:
+  """Removes unpaired sequences which duplicate a paired sequence."""
+
+  feature_names = np_chains[0].keys()
+  msa_features = MSA_FEATURES
+
+  for chain in np_chains:
+    sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
+    keep_rows = []
+    # Go through unpaired MSA seqs and remove any rows that correspond to the
+    # sequences that are already present in the paired MSA.
+    for row_num, seq in enumerate(chain['msa']):
+      if tuple(seq) not in sequence_set:
+        keep_rows.append(row_num)
+    for feature_name in feature_names:
+      if feature_name in msa_features:
+        if keep_rows:
+          chain[feature_name] = chain[feature_name][keep_rows]
+        else:
+          new_shape = list(chain[feature_name].shape)
+          new_shape[0] = 0
+          chain[feature_name] = np.zeros(new_shape,
+                                         dtype=chain[feature_name].dtype)
+    chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
+  return np_chains
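To make the deduplication step concrete, a small sketch with toy MSAs; only the keys the function actually touches are included, and the array contents are illustrative:

    import numpy as np
    from alphafold.data import msa_pairing

    chain = {
        'msa_all_seq': np.array([[0, 1, 2]]),     # Paired MSA: one row.
        'msa': np.array([[0, 1, 2], [3, 4, 5]]),  # Row 0 duplicates it.
        'msa_mask': np.ones((2, 3), dtype=np.float32),
        'deletion_matrix': np.zeros((2, 3)),
        'deletion_matrix_int': np.zeros((2, 3), dtype=np.int32),
        'num_alignments': np.array(2, dtype=np.int32),
    }
    [chain] = msa_pairing.deduplicate_unpaired_sequences([chain])
    # chain['msa'] is now [[3, 4, 5]] and chain['num_alignments'] == 1.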
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/parsers.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/parsers.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,607 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for parsing various file formats."""
+import collections
+import dataclasses
+import itertools
+import re
+import string
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
+
+DeletionMatrix = Sequence[Sequence[int]]
+
+
+@dataclasses.dataclass(frozen=True)
+class Msa:
+  """Class representing a parsed MSA file."""
+  sequences: Sequence[str]
+  deletion_matrix: DeletionMatrix
+  descriptions: Sequence[str]
+
+  def __post_init__(self):
+    if not (len(self.sequences) ==
+            len(self.deletion_matrix) ==
+            len(self.descriptions)):
+      raise ValueError(
+          'All fields for an MSA must have the same length. '
+          f'Got {len(self.sequences)} sequences, '
+          f'{len(self.deletion_matrix)} rows in the deletion matrix and '
+          f'{len(self.descriptions)} descriptions.')
+
+  def __len__(self):
+    return len(self.sequences)
+
+  def truncate(self, max_seqs: int):
+    return Msa(sequences=self.sequences[:max_seqs],
+               deletion_matrix=self.deletion_matrix[:max_seqs],
+               descriptions=self.descriptions[:max_seqs])
+
+
+@dataclasses.dataclass(frozen=True)
+class TemplateHit:
+  """Class representing a template hit."""
+  index: int
+  name: str
+  aligned_cols: int
+  sum_probs: Optional[float]
+  query: str
+  hit_sequence: str
+  indices_query: List[int]
+  indices_hit: List[int]
+
+
+def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
+  """Parses FASTA string and returns list of strings with amino-acid sequences.
+
+  Arguments:
+    fasta_string: The string contents of a FASTA file.
+
+  Returns:
+    A tuple of two lists:
+    * A list of sequences.
+    * A list of sequence descriptions taken from the comment lines. In the
+      same order as the sequences.
+  """
+  sequences = []
+  descriptions = []
+  index = -1
+  for line in fasta_string.splitlines():
+    line = line.strip()
+    if line.startswith('>'):
+      index += 1
+      descriptions.append(line[1:])  # Remove the '>' at the beginning.
+      sequences.append('')
+      continue
+    elif not line:
+      continue  # Skip blank lines.
+    sequences[index] += line
+
+  return sequences, descriptions
+
+
+def parse_stockholm(stockholm_string: str) -> Msa:
+  """Parses sequences and deletion matrix from stockholm format alignment.
+
+  Args:
+    stockholm_string: The string contents of a stockholm file. The first
+      sequence in the file should be the query sequence.
+
+  Returns:
+    A tuple of:
+      * A list of sequences that have been aligned to the query. These
+        might contain duplicates.
+      * The deletion matrix for the alignment as a list of lists. The element
+        at `deletion_matrix[i][j]` is the number of residues deleted from
+        the aligned sequence i at residue position j.
+      * The names of the targets matched, including the jackhmmer subsequence
+        suffix.
+  """
+  name_to_sequence = collections.OrderedDict()
+  for line in stockholm_string.splitlines():
+    line = line.strip()
+    if not line or line.startswith(('#', '//')):
+      continue
+    name, sequence = line.split()
+    if name not in name_to_sequence:
+      name_to_sequence[name] = ''
+    name_to_sequence[name] += sequence
+
+  msa = []
+  deletion_matrix = []
+
+  query = ''
+  keep_columns = []
+  f
[... several hundred lines of this diff are truncated in this view; it resumes mid-comment inside a later parsing function ...]
+ a results table, then has a sequence of hit
+  # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
+  # iterate through each paragraph to parse each hit.
+
+  block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]
+
+  hits = []
+  if block_starts:
+    block_starts.append(len(lines))  # Add the end of the final block.
+    for i in range(len(block_starts) - 1):
+      hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
+  return hits
+
+
+def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
+  """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
+  e_values = {'query': 0}
+  lines = [line for line in tblout.splitlines() if line[0] != '#']
+  # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
+  # space-delimited. Relevant fields are (1) target name: and
+  # (5) E-value (full sequence) (numbering from 1).
+  for line in lines:
+    fields = line.split()
+    e_value = fields[4]
+    target_name = fields[0]
+    e_values[target_name] = float(e_value)
+  return e_values
+
+
+def _get_indices(sequence: str, start: int) -> List[int]:
+  """Returns indices for non-gap/insert residues starting at the given index."""
+  indices = []
+  counter = start
+  for symbol in sequence:
+    # Skip gaps but add a placeholder so that the alignment is preserved.
+    if symbol == '-':
+      indices.append(-1)
+    # Skip deleted residues, but increase the counter.
+    elif symbol.islower():
+      counter += 1
+    # Normal aligned residue. Increase the counter and append to indices.
+    else:
+      indices.append(counter)
+      counter += 1
+  return indices
+
+
+@dataclasses.dataclass(frozen=True)
+class HitMetadata:
+  pdb_id: str
+  chain: str
+  start: int
+  end: int
+  length: int
+  text: str
+
+
+def _parse_hmmsearch_description(description: str) -> HitMetadata:
+  """Parses the hmmsearch A3M sequence description line."""
+  # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217  Free text
+  # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
+  match = re.match(
+      r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
+      description.strip())
+
+  if not match:
+    raise ValueError(f'Could not parse description: "{description}".')
+
+  return HitMetadata(
+      pdb_id=match[1],
+      chain=match[2],
+      start=int(match[3]),
+      end=int(match[4]),
+      length=int(match[5]),
+      text=match[6])
+
+
+def parse_hmmsearch_a3m(query_sequence: str,
+                        a3m_string: str,
+                        skip_first: bool = True) -> Sequence[TemplateHit]:
+  """Parses an a3m string produced by hmmsearch.
+
+  Args:
+    query_sequence: The query sequence.
+    a3m_string: The a3m string produced by hmmsearch.
+    skip_first: Whether to skip the first sequence in the a3m string.
+
+  Returns:
+    A sequence of `TemplateHit` results.
+  """
+  # Zip the descriptions and MSAs together, skip the first query sequence.
+  parsed_a3m = list(zip(*parse_fasta(a3m_string)))
+  if skip_first:
+    parsed_a3m = parsed_a3m[1:]
+
+  indices_query = _get_indices(query_sequence, start=0)
+
+  hits = []
+  for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
+    if 'mol:protein' not in hit_description:
+      continue  # Skip non-protein chains.
+    metadata = _parse_hmmsearch_description(hit_description)
+    # Aligned columns are only the match states.
+    aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
+    indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
+
+    hit = TemplateHit(
+        index=i,
+        name=f'{metadata.pdb_id}_{metadata.chain}',
+        aligned_cols=aligned_cols,
+        sum_probs=None,
+        query=query_sequence,
+        hit_sequence=hit_sequence.upper(),
+        indices_query=indices_query,
+        indices_hit=indices_hit,
+    )
+    hits.append(hit)
+
+  return hits
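For example, `parse_fasta` concatenates wrapped sequence lines and strips the `>` from descriptions, so a two-entry FASTA string round-trips as follows:

    from alphafold.data import parsers

    fasta = '>query desc\nMKT\nLLV\n>hit_1\nMKAILV\n'
    sequences, descriptions = parsers.parse_fasta(fasta)
    # sequences == ['MKTLLV', 'MKAILV']
    # descriptions == ['query desc', 'hit_1']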
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/pipeline.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/pipeline.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,230 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for building the input features for the AlphaFold model."""
+
+import os
+from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
+from absl import logging
+from alphafold.common import residue_constants
+from alphafold.data import msa_identifiers
+from alphafold.data import parsers
+from alphafold.data import templates
+from alphafold.data.tools import hhblits
+from alphafold.data.tools import hhsearch
+from alphafold.data.tools import hmmsearch
+from alphafold.data.tools import jackhmmer
+import numpy as np
+
+# Internal import (7716).
+
+FeatureDict = MutableMapping[str, np.ndarray]
+TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
+
+
+def make_sequence_features(
+    sequence: str, description: str, num_res: int) -> FeatureDict:
+  """Constructs a feature dict of sequence features."""
+  features = {}
+  features['aatype'] = residue_constants.sequence_to_onehot(
+      sequence=sequence,
+      mapping=residue_constants.restype_order_with_x,
+      map_unknown_to_x=True)
+  features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
+  features['domain_name'] = np.array([description.encode('utf-8')],
+                                     dtype=np.object_)
+  features['residue_index'] = np.array(range(num_res), dtype=np.int32)
+  features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
+  features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
+  return features
+
+
+def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
+  """Constructs a feature dict of MSA features."""
+  if not msas:
+    raise ValueError('At least one MSA must be provided.')
+
+  int_msa = []
+  deletion_matrix = []
+  uniprot_accession_ids = []
+  species_ids = []
+  seen_sequences = set()
+  for msa_index, msa in enumerate(msas):
+    if not msa:
+      raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
+    for sequence_index, sequence in enumerate(msa.sequences):
+      if sequence in seen_sequences:
+        continue
+      seen_sequences.add(sequence)
+      int_msa.append(
+          [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
+      deletion_matrix.append(msa.deletion_matrix[sequence_index])
+      identifiers = msa_identifiers.get_identifiers(
+          msa.descriptions[sequence_index])
+      uniprot_accession_ids.append(
+          identifiers.uniprot_accession_id.encode('utf-8'))
+      species_ids.append(identifiers.species_id.encode('utf-8'))
+
+  num_res = len(msas[0].sequences[0])
+  num_alignments = len(int_msa)
+  features = {}
+  features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
+  features['msa'] = np.array(int_msa, dtype=np.int32)
+  features['num_alignments'] = np.array(
+      [num_alignments] * num_res, dtype=np.int32)
+  features['msa_uniprot_accession_identifiers'] = np.array(
+      uniprot_accession_ids, dtype=np.object_)
+  features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
+  return features
+
+
+def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
+                 msa_format: str, use_precomputed_msas: bool,
+                 ) -> Mapping[str, Any]:
+  """Runs an MSA tool, checking if output already exists first."""
+  if not use_precomputed_msas or not os.path.exists(msa_out_path):
[... diff truncated by the changeset viewer ...]
+      input_fasta_str = f.read()
+    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+    if len(input_seqs) != 1:
+      raise ValueError(
+          f'More than one input sequence found in {input_fasta_path}.')
+    input_sequence = input_seqs[0]
+    input_description = input_descs[0]
+    num_res = len(input_sequence)
+
+    uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
+    jackhmmer_uniref90_result = run_msa_tool(
+        self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
+        'sto', self.use_precomputed_msas)
+    mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
+    jackhmmer_mgnify_result = run_msa_tool(
+        self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
+        self.use_precomputed_msas)
+
+    msa_for_templates = jackhmmer_uniref90_result['sto']
+    msa_for_templates = parsers.truncate_stockholm_msa(
+        msa_for_templates, max_sequences=self.uniref_max_hits)
+    msa_for_templates = parsers.deduplicate_stockholm_msa(
+        msa_for_templates)
+    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
+        msa_for_templates)
+
+    if self.template_searcher.input_format == 'sto':
+      pdb_templates_result = self.template_searcher.query(msa_for_templates)
+    elif self.template_searcher.input_format == 'a3m':
+      uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
+      pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
+    else:
+      raise ValueError('Unrecognized template input format: '
+                       f'{self.template_searcher.input_format}')
+
+    pdb_hits_out_path = os.path.join(
+        msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
+    with open(pdb_hits_out_path, 'w') as f:
+      f.write(pdb_templates_result)
+
+    uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
+    uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
+    mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
+    mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
+
+    pdb_template_hits = self.template_searcher.get_template_hits(
+        output_string=pdb_templates_result, input_sequence=input_sequence)
+
+    if self._use_small_bfd:
+      bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
+      jackhmmer_small_bfd_result = run_msa_tool(
+          self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path,
+          'sto', self.use_precomputed_msas)
+      bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
+    else:
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
+      hhblits_bfd_uniclust_result = run_msa_tool(
+          self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
+          'a3m', self.use_precomputed_msas)
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+
+    templates_result = self.template_featurizer.get_templates(
+        query_sequence=input_sequence,
+        hits=pdb_template_hits)
+
+    sequence_features = make_sequence_features(
+        sequence=input_sequence,
+        description=input_description,
+        num_res=num_res)
+
+    msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))
+
+    logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
+    logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
+    logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
+    logging.info('Final (deduplicated) MSA size: %d sequences.',
+                 msa_features['num_alignments'][0])
+    logging.info('Total number of templates (NB: this can include bad '
+                 'templates and is later filtered to top 4): %d.',
+                 templates_result.features['template_domain_names'].shape[0])
+
+    return {**sequence_features, **msa_features, **templates_result.features}
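For orientation, a minimal sketch of calling make_sequence_features from the hunk above; the only assumption is that the alphafold package added by this changeset is importable:

from alphafold.data import pipeline

seq = 'MKTAYIAKQR'  # toy 10-residue sequence
feats = pipeline.make_sequence_features(
    sequence=seq, description='toy_protein', num_res=len(seq))
print(feats['aatype'].shape)   # (10, 21): one-hot over 20 residue types + X
print(feats['seq_length'][0])  # 10, replicated once per residue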
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/pipeline_multimer.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/pipeline_multimer.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,288 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for building the features for the AlphaFold multimer model."""
+
+import collections
+import contextlib
+import copy
+import dataclasses
+import json
+import os
+import tempfile
+from typing import Mapping, MutableMapping, Sequence
+
+from absl import logging
+from alphafold.common import protein
+from alphafold.common import residue_constants
+from alphafold.data import feature_processing
+from alphafold.data import msa_pairing
+from alphafold.data import parsers
+from alphafold.data import pipeline
+from alphafold.data.tools import jackhmmer
+import numpy as np
+
+# Internal import (7716).
+
+
+@dataclasses.dataclass(frozen=True)
+class _FastaChain:
+  sequence: str
+  description: str
+
+
+def _make_chain_id_map(*,
+                       sequences: Sequence[str],
+                       descriptions: Sequence[str],
+                       ) -> Mapping[str, _FastaChain]:
+  """Makes a mapping from PDB-format chain ID to sequence and description."""
+  if len(sequences) != len(descriptions):
+    raise ValueError('sequences and descriptions must have equal length. '
+                     f'Got {len(sequences)} != {len(descriptions)}.')
+  if len(sequences) > protein.PDB_MAX_CHAINS:
+    raise ValueError('Cannot process more chains than the PDB format supports. '
+                     f'Got {len(sequences)} chains.')
+  chain_id_map = {}
+  for chain_id, sequence, description in zip(
+      protein.PDB_CHAIN_IDS, sequences, descriptions):
+    chain_id_map[chain_id] = _FastaChain(
+        sequence=sequence, description=description)
+  return chain_id_map
+
+
+@contextlib.contextmanager
+def temp_fasta_file(fasta_str: str):
+  with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
+    fasta_file.write(fasta_str)
+    fasta_file.seek(0)
+    yield fasta_file.name
+
+
+def convert_monomer_features(
+    monomer_features: pipeline.FeatureDict,
+    chain_id: str) -> pipeline.FeatureDict:
+  """Reshapes and modifies monomer features for multimer models."""
+  converted = {}
+  converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
+  unnecessary_leading_dim_feats = {
+      'sequence', 'domain_name', 'num_alignments', 'seq_length'}
+  for feature_name, feature in monomer_features.items():
+    if feature_name in unnecessary_leading_dim_feats:
+      # asarray ensures it's a np.ndarray.
+      feature = np.asarray(feature[0], dtype=feature.dtype)
+    elif feature_name == 'aatype':
+      # The multimer model performs the one-hot operation itself.
+      feature = np.argmax(feature, axis=-1).astype(np.int32)
+    elif feature_name == 'template_aatype':
+      feature = np.argmax(feature, axis=-1).astype(np.int32)
+      new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+      feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
+    elif feature_name == 'template_all_atom_masks':
+      feature_name = 'template_all_atom_mask'
+    converted[feature_name] = feature
+  return converted
+
+
+def int_id_to_str_id(num: int) -> str:
+  """Encodes a number as a string, using reverse spreadsheet style naming.
+
+  Args:
+    num: A positive integer.
+
+  Returns:
+    A string that encodes the positive integer using reverse spreadsheet style,
+    naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
+    usual way to encode cha
[... diff truncated by the changeset viewer ...]
+
+  def _process_single_chain(
+      self,
+      chain_id: str,
+      sequence: str,
+      description: str,
+      msa_output_dir: str,
+      is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
+    """Runs the monomer pipeline on a single chain."""
+    chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
+    chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
+    if not os.path.exists(chain_msa_output_dir):
+      os.makedirs(chain_msa_output_dir)
+    with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
+      logging.info('Running monomer pipeline on chain %s: %s',
+                   chain_id, description)
+      chain_features = self._monomer_data_pipeline.process(
+          input_fasta_path=chain_fasta_path,
+          msa_output_dir=chain_msa_output_dir)
+
+      # We only construct the pairing features if there are 2 or more unique
+      # sequences.
+      if not is_homomer_or_monomer:
+        all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
+                                                          chain_msa_output_dir)
+        chain_features.update(all_seq_msa_features)
+    return chain_features
+
+  def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
+    """Get MSA features for unclustered uniprot, for pairing."""
+    out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
+    result = pipeline.run_msa_tool(
+        self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
+        self.use_precomputed_msas)
+    msa = parsers.parse_stockholm(result['sto'])
+    msa = msa.truncate(max_seqs=self._max_uniprot_hits)
+    all_seq_features = pipeline.make_msa_features([msa])
+    valid_feats = msa_pairing.MSA_FEATURES + (
+        'msa_uniprot_accession_identifiers',
+        'msa_species_identifiers',
+    )
+    feats = {f'{k}_all_seq': v for k, v in all_seq_features.items()
+             if k in valid_feats}
+    return feats
+
+  def process(self,
+              input_fasta_path: str,
+              msa_output_dir: str,
+              is_prokaryote: bool = False) -> pipeline.FeatureDict:
+    """Runs alignment tools on the input sequences and creates features."""
+    with open(input_fasta_path) as f:
+      input_fasta_str = f.read()
+    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+
+    chain_id_map = _make_chain_id_map(sequences=input_seqs,
+                                      descriptions=input_descs)
+    chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
+    with open(chain_id_map_path, 'w') as f:
+      chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain)
+                           for chain_id, fasta_chain in chain_id_map.items()}
+      json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)
+
+    all_chain_features = {}
+    sequence_features = {}
+    is_homomer_or_monomer = len(set(input_seqs)) == 1
+    for chain_id, fasta_chain in chain_id_map.items():
+      if fasta_chain.sequence in sequence_features:
+        all_chain_features[chain_id] = copy.deepcopy(
+            sequence_features[fasta_chain.sequence])
+        continue
+      chain_features = self._process_single_chain(
+          chain_id=chain_id,
+          sequence=fasta_chain.sequence,
+          description=fasta_chain.description,
+          msa_output_dir=msa_output_dir,
+          is_homomer_or_monomer=is_homomer_or_monomer)
+
+      chain_features = convert_monomer_features(chain_features,
+                                                chain_id=chain_id)
+      all_chain_features[chain_id] = chain_features
+      sequence_features[fasta_chain.sequence] = chain_features
+
+    all_chain_features = add_assembly_features(all_chain_features)
+
+    np_example = feature_processing.pair_and_merge(
+        all_chain_features=all_chain_features,
+        is_prokaryote=is_prokaryote,
+    )
+
+    # Pad MSA to avoid zero-sized extra_msa.
+    np_example = pad_msa(np_example, 512)
+
+    return np_example
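The body of int_id_to_str_id falls inside the truncated middle of this diff. The sketch below is an assumption that implements the mapping its docstring specifies (1 = A, ..., 26 = Z, 27 = AA, 28 = BA, 29 = CA, least-significant letter first); it is not the committed code:

import string

def int_id_to_str_id(num: int) -> str:
    # Reverse spreadsheet naming per the docstring above (hypothetical
    # re-implementation, not the committed body).
    if num <= 0:
        raise ValueError(f'Only positive integers allowed, got {num}.')
    num -= 1  # 1-based to 0-based
    output = []
    while num >= 0:
        output.append(string.ascii_uppercase[num % 26])
        num = num // 26 - 1
    return ''.join(output)

print([int_id_to_str_id(n) for n in (1, 26, 27, 28, 29)])
# ['A', 'Z', 'AA', 'BA', 'CA']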
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/templates.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/templates.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,1010 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for getting templates and calculating template features."""
+import abc
+import dataclasses
+import datetime
+import functools
+import glob
+import os
+import re
+from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
+
+from absl import logging
+from alphafold.common import residue_constants
+from alphafold.data import mmcif_parsing
+from alphafold.data import parsers
+from alphafold.data.tools import kalign
+import numpy as np
+
+# Internal import (7716).
+
+
+class Error(Exception):
+  """Base class for exceptions."""
+
+
+class NoChainsError(Error):
+  """An error indicating that template mmCIF didn't have any chains."""
+
+
+class SequenceNotInTemplateError(Error):
+  """An error indicating that template mmCIF didn't contain the sequence."""
+
+
+class NoAtomDataInTemplateError(Error):
+  """An error indicating that template mmCIF didn't contain atom positions."""
+
+
+class TemplateAtomMaskAllZerosError(Error):
+  """An error indicating that template mmCIF had all atom positions masked."""
+
+
+class QueryToTemplateAlignError(Error):
+  """An error indicating that the query can't be aligned to the template."""
+
+
+class CaDistanceError(Error):
+  """An error indicating that a CA atom distance exceeds a threshold."""
+
+
+class MultipleChainsError(Error):
+  """An error indicating that multiple chains were found for a given ID."""
+
+
+# Prefilter exceptions.
+class PrefilterError(Exception):
+  """A base class for template prefilter exceptions."""
+
+
+class DateError(PrefilterError):
+  """An error indicating that the hit date was after the max allowed date."""
+
+
+class AlignRatioError(PrefilterError):
+  """An error indicating that the hit align ratio to the query was too small."""
+
+
+class DuplicateError(PrefilterError):
+  """An error indicating that the hit was an exact subsequence of the query."""
+
+
+class LengthError(PrefilterError):
+  """An error indicating that the hit was too short."""
+
+
+TEMPLATE_FEATURES = {
+    'template_aatype': np.float32,
+    'template_all_atom_masks': np.float32,
+    'template_all_atom_positions': np.float32,
+    'template_domain_names': np.object,
+    'template_sequence': np.object,
+    'template_sum_probs': np.float32,
+}
+
+
+def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
+  """Returns PDB id and chain id for an HHSearch Hit."""
+  # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
+  id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
+  if not id_match:
+    raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
+  pdb_id, chain_id = id_match.group(0).split('_')
+  return pdb_id.lower(), chain_id
+
+
+def _is_after_cutoff(
+    pdb_id: str,
+    release_dates: Mapping[str, datetime.datetime],
+    release_date_cutoff: Optional[datetime.datetime]) -> bool:
+  """Checks if the template date is after the release date cutoff.
+
+  Args:
+    pdb_id: 4 letter pdb code.
+    release_dates: Dictionary mapping PDB ids to their structure release dates.
+    release_date_cutoff: Max release date that is valid for this query.
+
+  Returns:
+    True if the template release date is after the cutoff, False otherwise.
+  """
+  if release_date_cutoff is None:
+    raise ValueError('The release_date_cutoff must not be None.')
+  if pdb_id in release_dates:
+    retur
[... diff truncated by the changeset viewer ...]
+      if result.features is None:
+        logging.info('Skipped invalid hit %s, error: %s, warning: %s',
+                     hit.name, result.error, result.warning)
+      else:
+        # Increment the hit counter, since we got features out of this hit.
+        num_hits += 1
+        for k in template_features:
+          template_features[k].append(result.features[k])
+
+    for name in template_features:
+      if num_hits > 0:
+        template_features[name] = np.stack(
+            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
+      else:
+        # Make sure the feature has correct dtype even if empty.
+        template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name])
+
+    return TemplateSearchResult(
+        features=template_features, errors=errors, warnings=warnings)
+
+
+class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
+  """A class for turning a3m hits from hmmsearch to template features."""
+
+  def get_templates(
+      self,
+      query_sequence: str,
+      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
+    """Computes the templates for given query sequence (more details above)."""
+    logging.info('Searching for template for: %s', query_sequence)
+
+    template_features = {}
+    for template_feature_name in TEMPLATE_FEATURES:
+      template_features[template_feature_name] = []
+
+    already_seen = set()
+    errors = []
+    warnings = []
+
+    if not hits or hits[0].sum_probs is None:
+      sorted_hits = hits
+    else:
+      sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
+
+    for hit in sorted_hits:
+      # We got all the templates we wanted, stop processing hits.
+      if len(already_seen) >= self._max_hits:
+        break
+
+      result = _process_single_hit(
+          query_sequence=query_sequence,
+          hit=hit,
+          mmcif_dir=self._mmcif_dir,
+          max_template_date=self._max_template_date,
+          release_dates=self._release_dates,
+          obsolete_pdbs=self._obsolete_pdbs,
+          strict_error_check=self._strict_error_check,
+          kalign_binary_path=self._kalign_binary_path)
+
+      if result.error:
+        errors.append(result.error)
+
+      # There could be an error even if there are some results, e.g. thrown by
+      # other unparsable chains in the same mmCIF file.
+      if result.warning:
+        warnings.append(result.warning)
+
+      if result.features is None:
+        logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
+                      hit.name, result.error, result.warning)
+      else:
+        already_seen_key = result.features['template_sequence']
+        if already_seen_key in already_seen:
+          continue
+        # Increment the hit counter, since we got features out of this hit.
+        already_seen.add(already_seen_key)
+        for k in template_features:
+          template_features[k].append(result.features[k])
+
+    if already_seen:
+      for name in template_features:
+        template_features[name] = np.stack(
+            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
+    else:
+      num_res = len(query_sequence)
+      # Construct a default template with all zeros.
+      template_features = {
+          'template_aatype': np.zeros(
+              (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
+              np.float32),
+          'template_all_atom_masks': np.zeros(
+              (1, num_res, residue_constants.atom_type_num), np.float32),
+          'template_all_atom_positions': np.zeros(
+              (1, num_res, residue_constants.atom_type_num, 3), np.float32),
+          'template_domain_names': np.array([''.encode()], dtype=np.object),
+          'template_sequence': np.array([''.encode()], dtype=np.object),
+          'template_sum_probs': np.array([0], dtype=np.float32)
+      }
+    return TemplateSearchResult(
+        features=template_features, errors=errors, warnings=warnings)
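Most of the featurizer internals are truncated above, but the hit-name convention is fully visible; a small demo of the same regex that _get_pdb_id_and_chain applies, run on a hypothetical hit name:

import re

# The same pattern used by _get_pdb_id_and_chain; '4pqx_A/2-217' is a
# made-up hit name in the PDBID_chain style the comment describes.
id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', '4pqx_A/2-217')
pdb_id, chain_id = id_match.group(0).split('_')
print(pdb_id.lower(), chain_id)  # 4pqx A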
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/__init__.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/__init__.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,14 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Python wrappers for third party tools."""
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/hhblits.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/hhblits.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,155 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to run HHblits from Python."""
+
+import glob
+import os
+import subprocess
+from typing import Any, List, Mapping, Optional, Sequence
+
+from absl import logging
+from alphafold.data.tools import utils
+# Internal import (7716).
+
+
+_HHBLITS_DEFAULT_P = 20
+_HHBLITS_DEFAULT_Z = 500
+
+
+class HHBlits:
+  """Python wrapper of the HHblits binary."""
+
+  def __init__(self,
+               *,
+               binary_path: str,
+               databases: Sequence[str],
+               n_cpu: int = 4,
+               n_iter: int = 3,
+               e_value: float = 0.001,
+               maxseq: int = 1_000_000,
+               realign_max: int = 100_000,
+               maxfilt: int = 100_000,
+               min_prefilter_hits: int = 1000,
+               all_seqs: bool = False,
+               alt: Optional[int] = None,
+               p: int = _HHBLITS_DEFAULT_P,
+               z: int = _HHBLITS_DEFAULT_Z):
+    """Initializes the Python HHblits wrapper.
+
+    Args:
+      binary_path: The path to the HHblits executable.
+      databases: A sequence of HHblits database paths. This should be the
+        common prefix for the database files (i.e. up to but not including
+        _hhm.ffindex etc.)
+      n_cpu: The number of CPUs to give HHblits.
+      n_iter: The number of HHblits iterations.
+      e_value: The E-value, see HHblits docs for more details.
+      maxseq: The maximum number of rows in an input alignment. Note that this
+        parameter is only supported in HHBlits version 3.1 and higher.
+      realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
+      maxfilt: Max number of hits allowed to pass the 2nd prefilter.
+        HHblits default: 20000.
+      min_prefilter_hits: Min number of hits to pass prefilter.
+        HHblits default: 100.
+      all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
+        HHblits default: False.
+      alt: Show up to this many alternative alignments.
+      p: Minimum Prob for a hit to be included in the output hhr file.
+        HHblits default: 20.
+      z: Hard cap on number of hits reported in the hhr file.
+        HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
+
+    Raises:
+      RuntimeError: If HHblits binary not found within the path.
+    """
+    self.binary_path = binary_path
+    self.databases = databases
+
+    for database_path in self.databases:
+      if not glob.glob(database_path + '_*'):
+        logging.error('Could not find HHBlits database %s', database_path)
+        raise ValueError(f'Could not find HHBlits database {database_path}')
+
+    self.n_cpu = n_cpu
+    self.n_iter = n_iter
+    self.e_value = e_value
+    self.maxseq = maxseq
+    self.realign_max = realign_max
+    self.maxfilt = maxfilt
+    self.min_prefilter_hits = min_prefilter_hits
+    self.all_seqs = all_seqs
+    self.alt = alt
+    self.p = p
+    self.z = z
+
+  def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
+    """Queries the database using HHblits."""
+    with utils.tmpdir_manager() as query_tmp_dir:
+      a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+
+      db_cmd = []
+      for db_path in self.databases:
+        db_cmd.append('-d')
+        db_cmd.append(db_path)
+      cmd = [
+          self.binary_path,
+          '-i', input_fasta_path,
+          '-cpu', str(self.n_cpu),
+          '-oa3m', a3m_path,
+          '-o', '/dev/null',
+          '-n', str(self.n_iter),
+          '-e', str(self.e_value),
+          '-maxseq', str(self.maxseq),
+          '-realign_max', str(self.realign_max),
+          '-maxfilt', str(self.maxfilt),
+          '-min_prefilter_hits', str(self.min_prefilter_hits)]
+      if self.all_seqs:
+        cmd += ['-all']
+      if self.alt:
+        cmd += ['-alt', str(self.alt)]
+      if self.p != _HHBLITS_DEFAULT_P:
+        cmd += ['-p', str(self.p)]
+      if self.z != _HHBLITS_DEFAULT_Z:
+        cmd += ['-Z', str(self.z)]
+      cmd += db_cmd
+
+      logging.info('Launching subprocess "%s"', ' '.join(cmd))
+      process = subprocess.Popen(
+          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+      with utils.timing('HHblits query'):
+        stdout, stderr = process.communicate()
+        retcode = process.wait()
+
+      if retcode:
+        # Logs have a 15k character limit, so log HHblits error line by line.
+        logging.error('HHblits failed. HHblits stderr begin:')
+        for error_line in stderr.decode('utf-8').splitlines():
+          if error_line.strip():
+            logging.error(error_line.strip())
+        logging.error('HHblits stderr end')
+        raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
+            stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
+
+      with open(a3m_path) as f:
+        a3m = f.read()
+
+      raw_output = dict(
+          a3m=a3m,
+          output=stdout,
+          stderr=stderr,
+          n_iter=self.n_iter,
+          e_value=self.e_value)
+    return [raw_output]
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/hhsearch.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/hhsearch.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,107 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to run HHsearch from Python."""
+
+import glob
+import os
+import subprocess
+from typing import Sequence
+
+from absl import logging
+
+from alphafold.data import parsers
+from alphafold.data.tools import utils
+# Internal import (7716).
+
+
+class HHSearch:
+  """Python wrapper of the HHsearch binary."""
+
+  def __init__(self,
+               *,
+               binary_path: str,
+               databases: Sequence[str],
+               maxseq: int = 1_000_000):
+    """Initializes the Python HHsearch wrapper.
+
+    Args:
+      binary_path: The path to the HHsearch executable.
+      databases: A sequence of HHsearch database paths. This should be the
+        common prefix for the database files (i.e. up to but not including
+        _hhm.ffindex etc.)
+      maxseq: The maximum number of rows in an input alignment. Note that this
+        parameter is only supported in HHBlits version 3.1 and higher.
+
+    Raises:
+      RuntimeError: If HHsearch binary not found within the path.
+    """
+    self.binary_path = binary_path
+    self.databases = databases
+    self.maxseq = maxseq
+
+    for database_path in self.databases:
+      if not glob.glob(database_path + '_*'):
+        logging.error('Could not find HHsearch database %s', database_path)
+        raise ValueError(f'Could not find HHsearch database {database_path}')
+
+  @property
+  def output_format(self) -> str:
+    return 'hhr'
+
+  @property
+  def input_format(self) -> str:
+    return 'a3m'
+
+  def query(self, a3m: str) -> str:
+    """Queries the database using HHsearch using a given a3m."""
+    with utils.tmpdir_manager() as query_tmp_dir:
+      input_path = os.path.join(query_tmp_dir, 'query.a3m')
+      hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
+      with open(input_path, 'w') as f:
+        f.write(a3m)
+
+      db_cmd = []
+      for db_path in self.databases:
+        db_cmd.append('-d')
+        db_cmd.append(db_path)
+      cmd = [self.binary_path,
+             '-i', input_path,
+             '-o', hhr_path,
+             '-maxseq', str(self.maxseq)
+             ] + db_cmd
+
+      logging.info('Launching subprocess "%s"', ' '.join(cmd))
+      process = subprocess.Popen(
+          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+      with utils.timing('HHsearch query'):
+        stdout, stderr = process.communicate()
+        retcode = process.wait()
+
+      if retcode:
+        # Stderr is truncated to prevent proto size errors in Beam.
+        raise RuntimeError(
+            'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
+                stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
+
+      with open(hhr_path) as f:
+        hhr = f.read()
+    return hhr
+
+  def get_template_hits(self,
+                        output_string: str,
+                        input_sequence: str) -> Sequence[parsers.TemplateHit]:
+    """Gets parsed template hits from the raw string output by the tool."""
+    del input_sequence  # Used by hmmsearch but not needed for hhsearch.
+    return parsers.parse_hhr(output_string)
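A minimal usage sketch of the HHSearch wrapper; the binary path, database prefix, and input A3M file are illustrative assumptions. Note that get_template_hits discards its input_sequence argument, so any string may be passed there:

from alphafold.data.tools import hhsearch

searcher = hhsearch.HHSearch(
    binary_path='/usr/bin/hhsearch',   # illustrative path
    databases=['/data/pdb70/pdb70'])   # illustrative database prefix
with open('uniref90_hits.a3m') as f:   # an existing A3M MSA (illustrative)
    hhr = searcher.query(f.read())
hits = searcher.get_template_hits(hhr, input_sequence='')  # input ignored here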
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/hmmbuild.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/hmmbuild.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,138 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
+
+import os
+import re
+import subprocess
+
+from absl import logging
+from alphafold.data.tools import utils
+# Internal import (7716).
+
+
+class Hmmbuild(object):
+  """Python wrapper of the hmmbuild binary."""
+
+  def __init__(self,
+               *,
+               binary_path: str,
+               singlemx: bool = False):
+    """Initializes the Python hmmbuild wrapper.
+
+    Args:
+      binary_path: The path to the hmmbuild executable.
+      singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
+        just use a common substitution score matrix.
+
+    Raises:
+      RuntimeError: If hmmbuild binary not found within the path.
+    """
+    self.binary_path = binary_path
+    self.singlemx = singlemx
+
+  def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
+    """Builds an HMM for the aligned sequences given as a Stockholm string.
+
+    Args:
+      sto: A string with the aligned sequences in the Stockholm format.
+      model_construction: Whether to use reference annotation in the msa to
+        determine consensus columns ('hand') or default ('fast').
+
+    Returns:
+      A string with the profile in the HMM format.
+
+    Raises:
+      RuntimeError: If hmmbuild fails.
+    """
+    return self._build_profile(sto, model_construction=model_construction)
+
+  def build_profile_from_a3m(self, a3m: str) -> str:
+    """Builds an HMM for the aligned sequences given as an A3M string.
+
+    Args:
+      a3m: A string with the aligned sequences in the A3M format.
+
+    Returns:
+      A string with the profile in the HMM format.
+
+    Raises:
+      RuntimeError: If hmmbuild fails.
+    """
+    lines = []
+    for line in a3m.splitlines():
+      if not line.startswith('>'):
+        line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
+      lines.append(line + '\n')
+    msa = ''.join(lines)
+    return self._build_profile(msa, model_construction='fast')
+
+  def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
+    """Builds an HMM for the aligned sequences given as an MSA string.
+
+    Args:
+      msa: A string with the aligned sequences, in A3M or STO format.
+      model_construction: Whether to use reference annotation in the msa to
+        determine consensus columns ('hand') or default ('fast').
+
+    Returns:
+      A string with the profile in the HMM format.
+
+    Raises:
+      RuntimeError: If hmmbuild fails.
+      ValueError: If unspecified arguments are provided.
+    """
+    if model_construction not in {'hand', 'fast'}:
+      raise ValueError(f'Invalid model_construction {model_construction} - '
+                       'only hand and fast supported.')
+
+    with utils.tmpdir_manager() as query_tmp_dir:
+      input_query = os.path.join(query_tmp_dir, 'query.msa')
+      output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
+
+      with open(input_query, 'w') as f:
+        f.write(msa)
+
+      cmd = [self.binary_path]
+      # If adding flags, we have to do so before the output and input:
+
+      if model_construction == 'hand':
+        cmd.append(f'--{model_construction}')
+      if self.singlemx:
+        cmd.append('--singlemx')
+      cmd.extend([
+          '--amino',
+          output_hmm_path,
+          input_query,
+      ])
+
+      logging.info('Launching subprocess %s', cmd)
+      process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+                                 stderr=subprocess.PIPE)
+
+      with utils.timing('hmmbuild query'):
+        stdout, stderr = process.communicate()
+        retcode = process.wait()
+        logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
+                     stdout.decode('utf-8'), stderr.decode('utf-8'))
+
+      if retcode:
+        raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
+                           % (stdout.decode('utf-8'), stderr.decode('utf-8')))
+
+      with open(output_hmm_path, encoding='utf-8') as f:
+        hmm = f.read()
+
+    return hmm
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/hmmsearch.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/hmmsearch.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,131 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Python wrapper for hmmsearch - search profile against a sequence db."""
+
+import os
+import subprocess
+from typing import Optional, Sequence
+
+from absl import logging
+from alphafold.data import parsers
+from alphafold.data.tools import hmmbuild
+from alphafold.data.tools import utils
+# Internal import (7716).
+
+
+class Hmmsearch(object):
+  """Python wrapper of the hmmsearch binary."""
+
+  def __init__(self,
+               *,
+               binary_path: str,
+               hmmbuild_binary_path: str,
+               database_path: str,
+               flags: Optional[Sequence[str]] = None):
+    """Initializes the Python hmmsearch wrapper.
+
+    Args:
+      binary_path: The path to the hmmsearch executable.
+      hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
+        an hmm from an input a3m.
+      database_path: The path to the hmmsearch database (FASTA format).
+      flags: List of flags to be used by hmmsearch.
+
+    Raises:
+      RuntimeError: If hmmsearch binary not found within the path.
+    """
+    self.binary_path = binary_path
+    self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
+    self.database_path = database_path
+    if flags is None:
+      # Default hmmsearch run settings.
+      flags = ['--F1', '0.1',
+               '--F2', '0.1',
+               '--F3', '0.1',
+               '--incE', '100',
+               '-E', '100',
+               '--domE', '100',
+               '--incdomE', '100']
+    self.flags = flags
+
+    if not os.path.exists(self.database_path):
+      logging.error('Could not find hmmsearch database %s', database_path)
+      raise ValueError(f'Could not find hmmsearch database {database_path}')
+
+  @property
+  def output_format(self) -> str:
+    return 'sto'
+
+  @property
+  def input_format(self) -> str:
+    return 'sto'
+
+  def query(self, msa_sto: str) -> str:
+    """Queries the database using hmmsearch using a given stockholm msa."""
+    hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
+                                                      model_construction='hand')
+    return self.query_with_hmm(hmm)
+
+  def query_with_hmm(self, hmm: str) -> str:
+    """Queries the database using hmmsearch using a given hmm."""
+    with utils.tmpdir_manager() as query_tmp_dir:
+      hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
+      out_path = os.path.join(query_tmp_dir, 'output.sto')
+      with open(hmm_input_path, 'w') as f:
+        f.write(hmm)
+
+      cmd = [
+          self.binary_path,
+          '--noali',  # Don't include the alignment in stdout.
+          '--cpu', '8'
+      ]
+      # If adding flags, we have to do so before the output and input:
+      if self.flags:
+        cmd.extend(self.flags)
+      cmd.extend([
+          '-A', out_path,
+          hmm_input_path,
+          self.database_path,
+      ])
+
+      logging.info('Launching sub-process %s', cmd)
+      process = subprocess.Popen(
+          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+      with utils.timing(
+          f'hmmsearch ({os.path.basename(self.database_path)}) query'):
+        stdout, stderr = process.communicate()
+        retcode = process.wait()
+
+      if retcode:
+        raise RuntimeError(
+            'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
+                stdout.decode('utf-8'), stderr.decode('utf-8')))
+
+      with open(out_path) as f:
+        out_msa = f.read()
+
+    return out_msa
+
+  def get_template_hits(self,
+                        output_string: str,
+                        input_sequence: str) -> Sequence[parsers.TemplateHit]:
+    """Gets parsed template hits from the raw string output by the tool."""
+    a3m_string = parsers.convert_stockholm_to_a3m(output_string,
+                                                  remove_first_row_gaps=False)
+    template_hits = parsers.parse_hmmsearch_a3m(
+        query_sequence=input_sequence,
+        a3m_string=a3m_string,
+        skip_first=False)
+    return template_hits
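A minimal usage sketch of the Hmmsearch wrapper; the three paths and the input files are illustrative assumptions, not values from this changeset:

from alphafold.data.tools import hmmsearch

searcher = hmmsearch.Hmmsearch(
    binary_path='/usr/bin/hmmsearch',                  # illustrative paths
    hmmbuild_binary_path='/usr/bin/hmmbuild',
    database_path='/data/pdb_seqres/pdb_seqres.txt')
with open('uniref90_hits.sto') as f:                   # illustrative input MSA
    sto = searcher.query(f.read())                     # Stockholm hits
hits = searcher.get_template_hits(sto, input_sequence='MKTAYIAKQR')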
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/jackhmmer.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/jackhmmer.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,201 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to run Jackhmmer from Python."""
+
+from concurrent import futures
+import glob
+import os
+import subprocess
+from typing import Any, Callable, Mapping, Optional, Sequence
+from urllib import request
+
+from absl import logging
+
+from alphafold.data.tools import utils
+# Internal import (7716).
+
+
+class Jackhmmer:
+  """Python wrapper of the Jackhmmer binary."""
+
+  def __init__(self,
+               *,
+               binary_path: str,
+               database_path: str,
+               n_cpu: int = 8,
+               n_iter: int = 1,
+               e_value: float = 0.0001,
+               z_value: Optional[int] = None,
+               get_tblout: bool = False,
+               filter_f1: float = 0.0005,
+               filter_f2: float = 0.00005,
+               filter_f3: float = 0.0000005,
+               incdom_e: Optional[float] = None,
+               dom_e: Optional[float] = None,
+               num_streamed_chunks: Optional[int] = None,
+               streaming_callback: Optional[Callable[[int], None]] = None):
+    """Initializes the Python Jackhmmer wrapper.
+
+    Args:
+      binary_path: The path to the jackhmmer executable.
+      database_path: The path to the jackhmmer database (FASTA format).
+      n_cpu: The number of CPUs to give Jackhmmer.
+      n_iter: The number of Jackhmmer iterations.
+      e_value: The E-value, see Jackhmmer docs for more details.
+      z_value: The Z-value, see Jackhmmer docs for more details.
+      get_tblout: Whether to save tblout string.
+      filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
+      filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
+      filter_f3: Forward pre-filter, set to >1.0 to turn off.
+      incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
+        round.
+      dom_e: Domain e-value criteria for inclusion in tblout.
+      num_streamed_chunks: Number of database chunks to stream over.
+      streaming_callback: Callback function run after each chunk iteration with
+        the iteration number as argument.
+    """
+    self.binary_path = binary_path
+    self.database_path = database_path
+    self.num_streamed_chunks = num_streamed_chunks
+
+    if not os.path.exists(self.database_path) and num_streamed_chunks is None:
+      logging.error('Could not find Jackhmmer database %s', database_path)
+      raise ValueError(f'Could not find Jackhmmer database {database_path}')
+
+    self.n_cpu = n_cpu
+    self.n_iter = n_iter
+    self.e_value = e_value
+    self.z_value = z_value
+    self.filter_f1 = filter_f1
+    self.filter_f2 = filter_f2
+    self.filter_f3 = filter_f3
+    self.incdom_e = incdom_e
+    self.dom_e = dom_e
+    self.get_tblout = get_tblout
+    self.streaming_callback = streaming_callback
+
+  def _query_chunk(self, input_fasta_path: str, database_path: str
+                   ) -> Mapping[str, Any]:
+    """Queries the database chunk using Jackhmmer."""
+    with utils.tmpdir_manager() as query_tmp_dir:
+      sto_path = os.path.join(query_tmp_dir, 'output.sto')
+
+      # The F1/F2/F3 are the expected proportion to pass each of the filtering
+      # stages (which get progressively more expensive), reducing these
+      # speeds up the pipeline at the expense of sensitivity. They are
+      # currently set very low to make querying Mgnify run in a reasonable
+      # amount of time.
+      cmd_flags = [
+          # Don't pollute stdout with Jackhmmer output.
+          '-o', '/dev/null',
+          '-A', sto_path,
+          '--noali',
+          '--F1', str(self.filter_f1),
+          '--F2', str(self.filter_f2),
+          '--F3', str(self.filter_f3),
+          '--incE', str(self.e_value),
+          # Report only sequences with E-values <= x in per-sequence output.
+          '-E', str(self.e_value),
+          '--cpu', str(self.n_cpu),
+          '-N', str(self.n_iter)
+      ]
+      if self.get_tblout:
+        tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
+        cmd_flags.extend(['--tblout', tblout_path])
+
+      if self.z_value:
+        cmd_flags.extend(['-Z', str(self.z_value)])
+
+      if self.dom_e is not None:
+        cmd_flags.extend(['--domE', str(self.dom_e)])
+
+      if self.incdom_e is not None:
+        cmd_flags.extend(['--incdomE', str(self.incdom_e)])
+
+      cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
+                                              database_path]
+
+      logging.info('Launching subprocess "%s"', ' '.join(cmd))
+      process = subprocess.Popen(
+          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+      with utils.timing(
+          f'Jackhmmer ({os.path.basename(database_path)}) query'):
+        _, stderr = process.communicate()
+        retcode = process.wait()
+
+      if retcode:
+        raise RuntimeError(
+            'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
+
+      # Get e-values for each target name
+      tbl = ''
+      if self.get_tblout:
+        with open(tblout_path) as f:
+          tbl = f.read()
+
+      with open(sto_path) as f:
+        sto = f.read()
+
+    raw_output = dict(
+        sto=sto,
+        tbl=tbl,
+        stderr=stderr,
+        n_iter=self.n_iter,
+        e_value=self.e_value)
+
+    return raw_output
+
+  def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+    """Queries the database using Jackhmmer."""
+    if self.num_streamed_chunks is None:
+      return [self._query_chunk(input_fasta_path, self.database_path)]
+
+    db_basename = os.path.basename(self.database_path)
+    db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
+    db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
+
+    # Remove existing files to prevent OOM
+    for f in glob.glob(db_local_chunk('[0-9]*')):
+      try:
+        os.remove(f)
+      except OSError:
+        print(f'OSError while deleting {f}')
+
+    # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
+    with futures.ThreadPoolExecutor(max_workers=2) as executor:
+      chunked_output = []
+      for i in range(1, self.num_streamed_chunks + 1):
+        # Copy the chunk locally
+        if i == 1:
+          future = executor.submit(
+              request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
+        if i < self.num_streamed_chunks:
+          next_future = executor.submit(
+              request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
+
+        # Run Jackhmmer with the chunk
+        future.result()
+        chunked_output.append(
+            self._query_chunk(input_fasta_path, db_local_chunk(i)))
+
+        # Remove the local copy of the chunk
+        os.remove(db_local_chunk(i))
+        # Do not set next_future for the last chunk so that this works even for
+        # databases with only 1 chunk.
+        if i < self.num_streamed_chunks:
+          future = next_future
+        if self.streaming_callback:
+          self.streaming_callback(i)
+    return chunked_output
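A minimal usage sketch of the Jackhmmer wrapper in its simple non-streaming mode; the binary and database paths are illustrative assumptions:

from alphafold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',               # illustrative path
    database_path='/data/uniref90/uniref90.fasta',  # illustrative path
    get_tblout=True)
result = runner.query('query.fasta')[0]
sto_msa = result['sto']  # Stockholm alignment of all hits
tbl = result['tbl']      # raw tblout text; see parsers.parse_e_values_from_tblout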
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/kalign.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/kalign.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,104 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Python wrapper for Kalign."""
+import os
+import subprocess
+from typing import Sequence
+
+from absl import logging
+
+from alphafold.data.tools import utils
+# Internal import (7716).
+
+
+def _to_a3m(sequences: Sequence[str]) -> str:
+  """Converts sequences to an a3m file."""
+  names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
+  a3m = []
+  for sequence, name in zip(sequences, names):
+    a3m.append(u'>' + name + u'\n')
+    a3m.append(sequence + u'\n')
+  return ''.join(a3m)
+
+
+class Kalign:
+  """Python wrapper of the Kalign binary."""
+
+  def __init__(self, *, binary_path: str):
+    """Initializes the Python Kalign wrapper.
+
+    Args:
+      binary_path: The path to the Kalign binary.
+
+    Raises:
+      RuntimeError: If Kalign binary not found within the path.
+    """
+    self.binary_path = binary_path
+
+  def align(self, sequences: Sequence[str]) -> str:
+    """Aligns the sequences and returns the alignment as an A3M string.
+
+    Args:
+      sequences: A list of query sequence strings. The sequences have to be at
+        least 6 residues long (Kalign requires this). Note that the order in
+        which you give the sequences might alter the output slightly as
+        a different alignment tree might get constructed.
+
+    Returns:
+      A string with the alignment in a3m format.
+
+    Raises:
+      RuntimeError: If Kalign fails.
+      ValueError: If any of the sequences is less than 6 residues long.
+    """
+    logging.info('Aligning %d sequences', len(sequences))
+
+    for s in sequences:
+      if len(s) < 6:
+        raise ValueError('Kalign requires all sequences to be at least 6 '
+                         'residues long. Got %s (%d residues).' % (s, len(s)))
+
+    with utils.tmpdir_manager() as query_tmp_dir:
+      input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
+      output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+
+      with open(input_fasta_path, 'w') as f:
+        f.write(_to_a3m(sequences))
+
+      cmd = [
+          self.binary_path,
+          '-i', input_fasta_path,
+          '-o', output_a3m_path,
+          '-format', 'fasta',
+      ]
+
+      logging.info('Launching subprocess "%s"', ' '.join(cmd))
+      process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+                                 stderr=subprocess.PIPE)
+
+      with utils.timing('Kalign query'):
+        stdout, stderr = process.communicate()
+        retcode = process.wait()
+        logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
+                     stdout.decode('utf-8'), stderr.decode('utf-8'))
+
+      if retcode:
+        raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
+                           % (stdout.decode('utf-8'), stderr.decode('utf-8')))
+
+      with open(output_a3m_path) as f:
+        a3m = f.read()
+
+      return a3m
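A minimal usage sketch of the Kalign wrapper; only the binary path is an assumption. Each input sequence must be at least 6 residues long, as enforced above:

from alphafold.data.tools import kalign

aligner = kalign.Kalign(binary_path='/usr/bin/kalign')  # illustrative path
a3m = aligner.align(['MKTAYIAKQR', 'MKTAYIARQR', 'MKTAYLAKQQ'])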
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/data/tools/utils.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/data/tools/utils.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,40 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Common utilities for data pipeline tools."""
+import contextlib
+import shutil
+import tempfile
+import time
+from typing import Optional
+
+from absl import logging
+
+
+@contextlib.contextmanager
+def tmpdir_manager(base_dir: Optional[str] = None):
+  """Context manager that deletes a temporary directory on exit."""
+  tmpdir = tempfile.mkdtemp(dir=base_dir)
+  try:
+    yield tmpdir
+  finally:
+    shutil.rmtree(tmpdir, ignore_errors=True)
+
+
+@contextlib.contextmanager
+def timing(msg: str):
+  logging.info('Started %s', msg)
+  tic = time.time()
+  yield
+  toc = time.time()
+  logging.info('Finished %s in %.3f seconds', msg, toc - tic)
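The two context managers compose naturally; a small sketch with no assumptions beyond the module itself:

import os
from alphafold.data.tools import utils

with utils.timing('toy workload'):        # logs start/finish and elapsed seconds
    with utils.tmpdir_manager() as tmp_dir:
        scratch = os.path.join(tmp_dir, 'scratch.txt')
        with open(scratch, 'w') as f:
            f.write('intermediate data')
# tmp_dir and its contents are removed here, even if the body had raised.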
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/__init__.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/__init__.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,14 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Alphafold model."""
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/all_atom.py
--- /dev/null  Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/all_atom.py  Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,1141 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Ops for all atom representations.
+
+Generally we employ two different representations for all atom coordinates,
+one is atom37 where each heavy atom corresponds to a given position in a 37
+dimensional array, This mapping is non amino acid specific, but each slot
+corresponds to an atom of a given name, for example slot 12 always corresponds
+to 'C delta 1', positions that are not present for a given amino acid are
+zeroed out and denoted by a mask.
+The other representation we employ is called atom14, this is a more dense way
+of representing atoms with 14 slots. Here a given slot will correspond to a
+different kind of atom depending on amino acid type, for example slot 5
+corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine.
+14 is chosen because it is the maximum number of heavy atoms for any standard
+amino acid.
+The order of slots can be found in 'residue_constants.residue_atoms'.
+Internally the model uses the atom14 representation because it is
+computationally more efficient.
+The internal atom14 representation is turned into the atom37 at the output of
+the network to facilitate easier conversion to existing protein datastructures.
+"""
+
+from typing import Dict, Optional
+from alphafold.common import residue_constants
+
+from alphafold.model import r3
+from alphafold.model import utils
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def squared_difference(x, y):
+  return jnp.square(x - y)
+
+
+def get_chi_atom_indices():
+  """Returns atom indices needed to compute chi angles for all residue types.
+
+  Returns:
+    A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+    in the order specified in residue_constants.restypes + unknown residue type
+    at the end. For chi angles which are not defined on the residue, the
+    positions indices are by default set to 0.
+  """
+  chi_atom_indices = []
+  for residue_name in residue_constants.restypes:
+    residue_name = residue_constants.restype_1to3[residue_name]
+    residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
+    atom_indices = []
+    for chi_angle in residue_chi_angles:
+      atom_indices.append(
+          [residue_constants.atom_order[atom] for atom in chi_angle])
+    for _ in range(4 - len(atom_indices)):
+      atom_indices.append([0, 0, 0, 0])  # For chi angles not defined on the AA.
+    chi_atom_indices.append(atom_indices)
+
+  chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.
+
+  return jnp.asarray(chi_atom_indices)
+
+
+def atom14_to_atom37(atom14_data: jnp.ndarray,  # (N, 14, ...)
+                     batch: Dict[str, jnp.ndarray]
+                     ) -> jnp.ndarray:  # (N, 37, ...)
+  """Convert atom14 to atom37 representation."""
+  assert len(atom14_data.shape) in [2, 3]
+  assert 'residx_atom37_to_atom14' in batch
+  assert 'atom37_atom_exists' in batch
+
+  atom37_data = utils.batched_gather(atom14_data,
+                                     batch['residx_atom37_to_atom14'],
+                                     batch_dims=1)
+  if len(atom14_data.shape) == 2:
+    atom37_data *= batch['atom37_atom_exists']
+  elif len(atom14_data.shape) == 3:
+    atom37_data *= batch['atom37_atom_exists'][:, :,
+                                               None].astype(atom37_data.dtype)
+  return atom37_data
+
+
+def atom37
[...]
+  assert target_frames.rot.xx.ndim == 1
+  assert frames_mask.ndim == 1, frames_mask.ndim
+  assert pred_positions.x.ndim == 1
+  assert target_positions.x.ndim == 1
+  assert positions_mask.ndim == 1
+
+  # Compute array of predicted positions in the predicted frames.
+  # r3.Vecs (num_frames, num_positions)
+  local_pred_pos = r3.rigids_mul_vecs(
+      jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)),
+      jax.tree_map(lambda x: x[None, :], pred_positions))
+
+  # Compute array of target positions in the target frames.
+  # r3.Vecs (num_frames, num_positions)
+  local_target_pos = r3.rigids_mul_vecs(
+      jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)),
+      jax.tree_map(lambda x: x[None, :], target_positions))
+
+  # Compute errors between the structures.
+  # jnp.ndarray (num_frames, num_positions)
+  error_dist = jnp.sqrt(
+      r3.vecs_squared_distance(local_pred_pos, local_target_pos)
+      + epsilon)
+
+  if l1_clamp_distance:
+    error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)
+
+  normed_error = error_dist / length_scale
+  normed_error *= jnp.expand_dims(frames_mask, axis=-1)
+  normed_error *= jnp.expand_dims(positions_mask, axis=-2)
+
+  normalization_factor = (
+      jnp.sum(frames_mask, axis=-1) *
+      jnp.sum(positions_mask, axis=-1))
+  return (jnp.sum(normed_error, axis=(-2, -1)) /
+          (epsilon + normalization_factor))
+
+
+def _make_renaming_matrices():
+  """Matrices to map atoms to symmetry partners in ambiguous case."""
+  # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+  # alternative groundtruth coordinates where the naming is swapped
+  restype_3 = [
+      residue_constants.restype_1to3[res] for res in residue_constants.restypes
+  ]
+  restype_3 += ['UNK']
+  # Matrices for renaming ambiguous atoms.
+  all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
+  for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
+    correspondences = np.arange(14)
+    for source_atom_swap, target_atom_swap in swap.items():
+      source_index = residue_constants.restype_name_to_atom14_names[
+          resname].index(source_atom_swap)
+      target_index = residue_constants.restype_name_to_atom14_names[
+          resname].index(target_atom_swap)
+      correspondences[source_index] = target_index
+      correspondences[target_index] = source_index
+    renaming_matrix = np.zeros((14, 14), dtype=np.float32)
+    for index, correspondence in enumerate(correspondences):
+      renaming_matrix[index, correspondence] = 1.
+    all_matrices[resname] = renaming_matrix.astype(np.float32)
+  renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
+  return renaming_matrices
+
+
+RENAMING_MATRICES = _make_renaming_matrices()
+
+
+def get_alt_atom14(aatype, positions, mask):
+  """Get alternative atom14 positions.
+
+  Constructs renamed atom positions for ambiguous residues.
+
+  Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree-
+  rotation-symmetry"
+
+  Args:
+    aatype: Amino acid at given position
+    positions: Atom positions as r3.Vecs in atom14 representation, (N, 14)
+    mask: Atom masks in atom14 representation, (N, 14)
+  Returns:
+    renamed atom positions, renamed atom mask
+  """
+  # pick the transformation matrices for the given residue sequence
+  # shape (num_res, 14, 14)
+  renaming_transform = utils.batched_gather(
+      jnp.asarray(RENAMING_MATRICES), aatype)
+
+  positions = jax.tree_map(lambda x: x[:, :, None], positions)
+  alternative_positions = jax.tree_map(
+      lambda x: jnp.sum(x, axis=1), positions * renaming_transform)
+
+  # Create the mask for the alternative ground truth (differs from the
+  # ground truth mask, if only one of the atoms in an ambiguous pair has a
+  # ground truth position)
+  alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)
+
+  return alternative_positions, alternative_mask
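The atom14-to-atom37 conversion above is just a per-residue index gather followed by masking. A toy numpy sketch of the same mechanics (the index map and existence mask are fabricated here; the real tables come from `residue_constants`):

```python
import numpy as np

num_res = 2
atom14 = np.arange(num_res * 14 * 3, dtype=np.float32).reshape(num_res, 14, 3)

# Fabricated index map: atom37 slot j is filled from atom14 slot map[res, j].
residx_atom37_to_atom14 = np.zeros((num_res, 37), dtype=np.int64)
residx_atom37_to_atom14[:, :14] = np.arange(14)

# Fabricated mask: pretend only the first 14 atom37 slots hold real atoms.
atom37_atom_exists = np.zeros((num_res, 37), dtype=np.float32)
atom37_atom_exists[:, :14] = 1.0

atom37 = np.take_along_axis(atom14, residx_atom37_to_atom14[..., None], axis=1)
atom37 *= atom37_atom_exists[..., None]  # zero out non-existent slots
assert atom37.shape == (num_res, 37, 3)
```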
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/all_atom_multimer.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/all_atom_multimer.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,966 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ops for all atom representations."""
+
+from typing import Dict, Text
+
+from alphafold.common import residue_constants
+from alphafold.model import geometry
+from alphafold.model import utils
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def squared_difference(x, y):
+  return jnp.square(x - y)
+
+
+def _make_chi_atom_indices():
+  """Returns atom indices needed to compute chi angles for all residue types.
+
+  Returns:
+    A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
+    in the order specified in residue_constants.restypes + unknown residue type
+    at the end. For chi angles which are not defined on the residue, the
+    positions indices are by default set to 0.
+  """
+  chi_atom_indices = []
+  for residue_name in residue_constants.restypes:
+    residue_name = residue_constants.restype_1to3[residue_name]
+    residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
+    atom_indices = []
+    for chi_angle in residue_chi_angles:
+      atom_indices.append(
+          [residue_constants.atom_order[atom] for atom in chi_angle])
+    for _ in range(4 - len(atom_indices)):
+      atom_indices.append([0, 0, 0, 0])  # For chi angles not defined on the AA.
+    chi_atom_indices.append(atom_indices)
+
+  chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.
+
+  return np.array(chi_atom_indices)
+
+
+def _make_renaming_matrices():
+  """Matrices to map atoms to symmetry partners in ambiguous case."""
+  # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
+  # alternative groundtruth coordinates where the naming is swapped
+  restype_3 = [
+      residue_constants.restype_1to3[res] for res in residue_constants.restypes
+  ]
+  restype_3 += ['UNK']
+  # Matrices for renaming ambiguous atoms.
+  all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
+  for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
+    correspondences = np.arange(14)
+    for source_atom_swap, target_atom_swap in swap.items():
+      source_index = residue_constants.restype_name_to_atom14_names[
+          resname].index(source_atom_swap)
+      target_index = residue_constants.restype_name_to_atom14_names[
+          resname].index(target_atom_swap)
+      correspondences[source_index] = target_index
+      correspondences[target_index] = source_index
+    renaming_matrix = np.zeros((14, 14), dtype=np.float32)
+    for index, correspondence in enumerate(correspondences):
+      renaming_matrix[index, correspondence] = 1.
+    all_matrices[resname] = renaming_matrix.astype(np.float32)
+  renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
+  return renaming_matrices
+
+
+def _make_restype_atom37_mask():
+  """Mask of which atoms are present for which residue type in atom37."""
+  # create the corresponding mask
+  restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+  for restype, restype_letter in enumerate(residue_constants.restypes):
+    restype_name = residue_constants.restype_1to3[restype_letter]
+    atom_names = residue_constants.residue_atoms[restype_name]
+    for atom_name in atom_names:
+      atom_type = residue_constants.atom_order[atom_name]
+      restype_atom37_mask[restype, atom_type] = 1
+  return restype_atom37_mask
+
+
+def _make_restype_atom14_
[...]
+  """Computes the chi angles given all atom positions and the amino acid type.
+
+  Args:
+    positions: A Vec3Array of shape
+      [num_res, residue_constants.atom_type_num], with positions of
+      atoms needed to calculate chi angles. Supports up to 1 batch dimension.
+    mask: An optional tensor of shape
+      [num_res, residue_constants.atom_type_num] that masks which atom
+      positions are set for each residue. If given, then the chi mask will be
+      set to 1 for a chi angle only if the amino acid has that chi angle and all
+      the chi atoms needed to calculate that chi angle are set. If not given
+      (set to None), the chi mask will be set to 1 for a chi angle if the amino
+      acid has that chi angle and whether the actual atoms needed to calculate
+      it were set will be ignored.
+    aatype: A tensor of shape [num_res] with amino acid type integer
+      code (0 to 21). Supports up to 1 batch dimension.
+
+  Returns:
+    A tuple of tensors (chi_angles, mask), where both have shape
+    [num_res, 4]. The mask masks out unused chi angles for amino acid
+    types that have less than 4 chi angles. If atom_positions_mask is set, the
+    chi mask will also mask out uncomputable chi angles.
+  """
+
+  # Don't assert on the num_res and batch dimensions as they might be unknown.
+  assert positions.shape[-1] == residue_constants.atom_type_num
+  assert mask.shape[-1] == residue_constants.atom_type_num
+
+  # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
+  chi_atom_indices = get_chi_atom_indices()
+  # Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4].
+  atom_indices = utils.batched_gather(
+      params=chi_atom_indices, indices=aatype, axis=0)
+  # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
+  chi_angle_atoms = jax.tree_map(
+      lambda x: utils.batched_gather(  # pylint: disable=g-long-lambda
+          params=x, indices=atom_indices, axis=-1, batch_dims=1), positions)
+  a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
+
+  chi_angles = geometry.dihedral_angle(a, b, c, d)
+
+  # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
+  chi_angles_mask = list(residue_constants.chi_angles_mask)
+  chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
+  chi_angles_mask = jnp.asarray(chi_angles_mask)
+  # Compute the chi angle mask. Shape [num_res, chis=4].
+  chi_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
+                                  axis=0)
+
+  # The chi_mask is set to 1 only when all necessary chi angle atoms were set.
+  # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
+  chi_angle_atoms_mask = utils.batched_gather(
+      params=mask, indices=atom_indices, axis=-1, batch_dims=1)
+  # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
+  chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
+  chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32)
+
+  return chi_angles, chi_mask
+
+
+def make_transform_from_reference(
+    a_xyz: geometry.Vec3Array,
+    b_xyz: geometry.Vec3Array,
+    c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array:
+  """Returns rotation and translation matrices to convert from reference.
+
+  Note that this method does not take care of symmetries. If you provide the
+  coordinates in the non-standard way, the A atom will end up in the negative
+  y-axis rather than in the positive y-axis. You need to take care of such
+  cases in your code.
+
+  Args:
+    a_xyz: A Vec3Array.
+    b_xyz: A Vec3Array.
+    c_xyz: A Vec3Array.
+
+  Returns:
+    A Rigid3Array which, when applied to coordinates in a canonicalized
+    reference frame, will give coordinates approximately equal
+    the original coordinates (in the global frame).
+  """
+  rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz,
+                                                 a_xyz - b_xyz)
+  return geometry.Rigid3Array(rotation, b_xyz)
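`make_transform_from_reference` builds a backbone frame by Gram-Schmidt orthogonalization, which is what `Rot3Array.from_two_vectors` performs internally. A plain-numpy sketch of the same construction, with made-up N/CA/C coordinates (a = N, b = CA as the origin, c = C):

```python
import numpy as np

a = np.array([1.458, 0.0, 0.0])     # N (made-up coordinates)
b = np.array([0.0, 0.0, 0.0])       # CA
c = np.array([-0.525, 1.363, 0.0])  # C

e0 = c - b
e0 = e0 / np.linalg.norm(e0)       # x-axis along CA -> C
e1 = a - b
e1 = e1 - e1.dot(e0) * e0          # remove the component along e0
e1 = e1 / np.linalg.norm(e1)       # y-axis, so N lands at positive y
e2 = np.cross(e0, e1)              # z-axis completes the right-handed frame

rotation = np.stack([e0, e1, e2], axis=-1)  # columns are the frame axes
translation = b

# Applying (rotation, translation) to local-frame points gives global
# coordinates: x_global = rotation @ x_local + translation.
np.testing.assert_allclose(rotation @ rotation.T, np.eye(3), atol=1e-9)
```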
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/all_atom_test.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/all_atom_test.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,135 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for all_atom."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from alphafold.model import all_atom
+from alphafold.model import r3
+import numpy as np
+
+L1_CLAMP_DISTANCE = 10
+
+
+def get_identity_rigid(shape):
+  """Returns identity rigid transform."""
+
+  ones = np.ones(shape)
+  zeros = np.zeros(shape)
+  rot = r3.Rots(ones, zeros, zeros,
+                zeros, ones, zeros,
+                zeros, zeros, ones)
+  trans = r3.Vecs(zeros, zeros, zeros)
+  return r3.Rigids(rot, trans)
+
+
+def get_global_rigid_transform(rot_angle, translation, bcast_dims):
+  """Returns rigid transform that globally rotates/translates by same amount."""
+
+  rot_angle = np.asarray(rot_angle)
+  translation = np.asarray(translation)
+  if bcast_dims:
+    for _ in range(bcast_dims):
+      rot_angle = np.expand_dims(rot_angle, 0)
+      translation = np.expand_dims(translation, 0)
+  sin_angle = np.sin(np.deg2rad(rot_angle))
+  cos_angle = np.cos(np.deg2rad(rot_angle))
+  ones = np.ones_like(sin_angle)
+  zeros = np.zeros_like(sin_angle)
+  rot = r3.Rots(ones, zeros, zeros,
+                zeros, cos_angle, -sin_angle,
+                zeros, sin_angle, cos_angle)
+  trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2])
+  return r3.Rigids(rot, trans)
+
+
+class AllAtomTest(parameterized.TestCase, absltest.TestCase):
+
+  @parameterized.named_parameters(
+      ('identity', 0, [0, 0, 0]),
+      ('rot_90', 90, [0, 0, 0]),
+      ('trans_10', 0, [0, 0, 10]),
+      ('rot_174_trans_1', 174, [1, 1, 1]))
+  def test_frame_aligned_point_error_perfect_on_global_transform(
+      self, rot_angle, translation):
+    """Tests global transform between target and preds gives perfect score."""
+
+    # pylint: disable=bad-whitespace
+    target_positions = np.array(
+        [[ 21.182,  23.095,  19.731],
+         [ 22.055,  20.919,  17.294],
+         [ 24.599,  20.005,  15.041],
+         [ 25.567,  18.214,  12.166],
+         [ 28.063,  17.082,  10.043],
+         [ 28.779,  15.569,   6.985],
+         [ 30.581,  13.815,   4.612],
+         [ 29.258,  12.193,   2.296]])
+    # pylint: enable=bad-whitespace
+    global_rigid_transform = get_global_rigid_transform(
+        rot_angle, translation, 1)
+
+    target_positions = r3.vecs_from_tensor(target_positions)
+    pred_positions = r3.rigids_mul_vecs(
+        global_rigid_transform, target_positions)
+    positions_mask = np.ones(target_positions.x.shape[0])
+
+    target_frames = get_identity_rigid(10)
+    pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
+    frames_mask = np.ones(10)
+
+    fape = all_atom.frame_aligned_point_error(
+        pred_frames, target_frames, frames_mask, pred_positions,
+        target_positions, positions_mask, L1_CLAMP_DISTANCE,
+        L1_CLAMP_DISTANCE, epsilon=0)
+    self.assertAlmostEqual(fape, 0.)
+
+  @parameterized.named_parameters(
+      ('identity',
+       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
+       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
+       0.),
+      ('shift_2.5',
+       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
+       [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
+       0.25),
+      ('shift_5',
+       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
+       [[5, 0, 0], [10, 0, 0], [15, 0, 0]],
+       0.5),
+      ('shift_10',
+       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
+       [[10, 0, 0], [15, 0, 0], [0, 0, 0]],
+       1.))
+  def test_frame_aligned_point_error_matches_expected(
+      self, target_positions, pred_positions, expected_alddt):
+    """Tests score matches expected."""
+
+    target_frames = get_identity_rigid(2)
+    pred_frames = target_frames
+    frames_mask = np.ones(2)
+
+    target_positions = r3.vecs_from_tensor(np.array(target_positions))
+    pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
+    positions_mask = np.ones(target_positions.x.shape[0])
+
+    alddt = all_atom.frame_aligned_point_error(
+        pred_frames, target_frames, frames_mask, pred_positions,
+        target_positions, positions_mask, L1_CLAMP_DISTANCE,
+        L1_CLAMP_DISTANCE, epsilon=0)
+    self.assertAlmostEqual(alddt, expected_alddt)
+
+
+if __name__ == '__main__':
+  absltest.main()
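The expected scores in the second test follow directly from the FAPE formula in all_atom.py above: with identity frames the per-position error is just the Euclidean shift, so normed_error = min(shift, l1_clamp_distance) / length_scale = min(shift, 10) / 10. A uniform 2.5 Å shift therefore scores 0.25, a 5 Å shift 0.5, and anything at or beyond the 10 Å clamp saturates at 1.0.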
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/common_modules.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/common_modules.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,130 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A collection of common Haiku modules for use in protein folding."""
+import numbers
+from typing import Union, Sequence
+
+import haiku as hk
+import jax.numpy as jnp
+import numpy as np
+
+
+# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
+                                            dtype=np.float32)
+
+
+def get_initializer_scale(initializer_name, input_shape):
+  """Get Initializer for weights and scale to multiply activations by."""
+
+  if initializer_name == 'zeros':
+    w_init = hk.initializers.Constant(0.0)
+  else:
+    # fan-in scaling
+    scale = 1.
+    for channel_dim in input_shape:
+      scale /= channel_dim
+    if initializer_name == 'relu':
+      scale *= 2
+
+    noise_scale = scale
+
+    stddev = np.sqrt(noise_scale)
+    # Adjust stddev for truncation.
+    stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
+    w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
+
+  return w_init
+
+
+class Linear(hk.Module):
+  """Protein folding specific Linear module.
+
+  This differs from the standard Haiku Linear in a few ways:
+    * It supports inputs and outputs of arbitrary rank
+    * Initializers are specified by strings
+  """
+
+  def __init__(self,
+               num_output: Union[int, Sequence[int]],
+               initializer: str = 'linear',
+               num_input_dims: int = 1,
+               use_bias: bool = True,
+               bias_init: float = 0.,
+               precision=None,
+               name: str = 'linear'):
+    """Constructs Linear Module.
+
+    Args:
+      num_output: Number of output channels. Can be tuple when outputting
+          multiple dimensions.
+      initializer: What initializer to use, should be one of {'linear', 'relu',
+        'zeros'}
+      num_input_dims: Number of dimensions from the end to project.
+      use_bias: Whether to include trainable bias
+      bias_init: Value used to initialize bias.
+      precision: What precision to use for matrix multiplication, defaults
+        to None.
+      name: Name of module, used for name scopes.
+    """
+    super().__init__(name=name)
+    if isinstance(num_output, numbers.Integral):
+      self.output_shape = (num_output,)
+    else:
+      self.output_shape = tuple(num_output)
+    self.initializer = initializer
+    self.use_bias = use_bias
+    self.bias_init = bias_init
+    self.num_input_dims = num_input_dims
+    self.num_output_dims = len(self.output_shape)
+    self.precision = precision
+
+  def __call__(self, inputs):
+    """Connects Module.
+
+    Args:
+      inputs: Tensor with at least num_input_dims dimensions.
+
+    Returns:
+      output of shape [...] + num_output.
+    """
+
+    num_input_dims = self.num_input_dims
+
+    if self.num_input_dims > 0:
+      in_shape = inputs.shape[-self.num_input_dims:]
+    else:
+      in_shape = ()
+
+    weight_init = get_initializer_scale(self.initializer, in_shape)
+
+    in_letters = 'abcde'[:self.num_input_dims]
+    out_letters = 'hijkl'[:self.num_output_dims]
+
+    weight_shape = in_shape + self.output_shape
+    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
+                               weight_init)
+
+    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
+
+    output = jnp.einsum(equation, inputs, weights, precision=self.precision)
+
+    if self.use_bias:
+      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
+                              hk.initializers.Constant(self.bias_init))
+      output += bias
+
+    return output
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/config.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/config.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,657 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Model config."""
+
+import copy
+from alphafold.model.tf import shape_placeholders
+import ml_collections
+
+NUM_RES = shape_placeholders.NUM_RES
+NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
+NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
+NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
+
+
+def model_config(name: str) -> ml_collections.ConfigDict:
+  """Get the ConfigDict of a CASP14 model."""
+
+  if 'multimer' in name:
+    return CONFIG_MULTIMER
+
+  if name not in CONFIG_DIFFS:
+    raise ValueError(f'Invalid model name {name}.')
+  cfg = copy.deepcopy(CONFIG)
+  cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
+  return cfg
+
+
+MODEL_PRESETS = {
+    'monomer': (
+        'model_1',
+        'model_2',
+        'model_3',
+        'model_4',
+        'model_5',
+    ),
+    'monomer_ptm': (
+        'model_1_ptm',
+        'model_2_ptm',
+        'model_3_ptm',
+        'model_4_ptm',
+        'model_5_ptm',
+    ),
+    'multimer': (
+        'model_1_multimer',
+        'model_2_multimer',
+        'model_3_multimer',
+        'model_4_multimer',
+        'model_5_multimer',
+    ),
+}
+MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
+
+
+CONFIG_DIFFS = {
+    'model_1': {
+        # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
+        'data.common.max_extra_msa': 5120,
+        'data.common.reduce_msa_clusters_by_max_templates': True,
+        'data.common.use_templates': True,
+        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+        'model.embeddings_and_evoformer.template.enabled': True
+    },
+    'model_2': {
+        # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
+        'data.common.reduce_msa_clusters_by_max_templates': True,
+        'data.common.use_templates': True,
+        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+        'model.embeddings_and_evoformer.template.enabled': True
+    },
+    'model_3': {
+        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
+        'data.common.max_extra_msa': 5120,
+    },
+    'model_4': {
+        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
+        'data.common.max_extra_msa': 5120,
+    },
+    'model_5': {
+        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
+    },
+
+    # The following models are fine-tuned from the corresponding models above
+    # with an additional predicted_aligned_error head that can produce
+    # predicted TM-score (pTM) and predicted aligned errors.
+    'model_1_ptm': {
+        'data.common.max_extra_msa': 5120,
+        'data.common.reduce_msa_clusters_by_max_templates': True,
+        'data.common.use_templates': True,
+        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+        'model.embeddings_and_evoformer.template.enabled': True,
+        'model.heads.predicted_aligned_error.weight': 0.1
+    },
+    'model_2_ptm': {
+        'data.common.reduce_msa_clusters_by_max_templates': True,
+        'data.common.use_templates': True,
+        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
+        'model.embeddings_and_evoformer.template.enabled': True,
+        'model.heads.predicted_aligned_error.weight': 0.1
+    },
+    'model_3_ptm': {
+        'data.common.max_extra_msa': 5120,
+        'model.heads.predicted_aligned_error.weight': 0.1
+    },
+    'model_4_ptm': {
[...]
+                        'dropout_rate': 0.25,
+                        'gating': True,
+                        'num_head': 4,
+                        'orientation': 'per_column',
+                        'shared_dropout': True
+                    },
+                    'triangle_attention_starting_node': {
+                        'dropout_rate': 0.25,
+                        'gating': True,
+                        'num_head': 4,
+                        'orientation': 'per_row',
+                        'shared_dropout': True
+                    },
+                    'triangle_multiplication_incoming': {
+                        'dropout_rate': 0.25,
+                        'equation': 'kjc,kic->ijc',
+                        'num_intermediate_channel': 64,
+                        'orientation': 'per_row',
+                        'shared_dropout': True
+                    },
+                    'triangle_multiplication_outgoing': {
+                        'dropout_rate': 0.25,
+                        'equation': 'ikc,jkc->ijc',
+                        'num_intermediate_channel': 64,
+                        'orientation': 'per_row',
+                        'shared_dropout': True
+                    }
+                }
+            },
+        },
+        'global_config': {
+            'deterministic': False,
+            'multimer_mode': True,
+            'subbatch_size': 4,
+            'use_remat': False,
+            'zero_init': True
+        },
+        'heads': {
+            'distogram': {
+                'first_break': 2.3125,
+                'last_break': 21.6875,
+                'num_bins': 64,
+                'weight': 0.3
+            },
+            'experimentally_resolved': {
+                'filter_by_resolution': True,
+                'max_resolution': 3.0,
+                'min_resolution': 0.1,
+                'weight': 0.01
+            },
+            'masked_msa': {
+                'weight': 2.0
+            },
+            'predicted_aligned_error': {
+                'filter_by_resolution': True,
+                'max_error_bin': 31.0,
+                'max_resolution': 3.0,
+                'min_resolution': 0.1,
+                'num_bins': 64,
+                'num_channels': 128,
+                'weight': 0.1
+            },
+            'predicted_lddt': {
+                'filter_by_resolution': True,
+                'max_resolution': 3.0,
+                'min_resolution': 0.1,
+                'num_bins': 50,
+                'num_channels': 128,
+                'weight': 0.01
+            },
+            'structure_module': {
+                'angle_norm_weight': 0.01,
+                'chi_weight': 0.5,
+                'clash_overlap_tolerance': 1.5,
+                'dropout': 0.1,
+                'interface_fape': {
+                    'atom_clamp_distance': 1000.0,
+                    'loss_unit_distance': 20.0
+                },
+                'intra_chain_fape': {
+                    'atom_clamp_distance': 10.0,
+                    'loss_unit_distance': 10.0
+                },
+                'num_channel': 384,
+                'num_head': 12,
+                'num_layer': 8,
+                'num_layer_in_transition': 3,
+                'num_point_qk': 4,
+                'num_point_v': 8,
+                'num_scalar_qk': 16,
+                'num_scalar_v': 16,
+                'position_scale': 20.0,
+                'sidechain': {
+                    'atom_clamp_distance': 10.0,
+                    'loss_unit_distance': 10.0,
+                    'num_channel': 128,
+                    'num_residual_block': 2,
+                    'weight_frac': 0.5
+                },
+                'structural_violation_loss_weight': 1.0,
+                'violation_tolerance_factor': 12.0,
+                'weight': 1.0
+            }
+        },
+        'num_ensemble_eval': 1,
+        'num_recycle': 3,
+        'resample_msa_in_recycling': True
+    }
+})
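Downstream code selects one of these presets by name. A short usage sketch (the `num_recycle` override is illustrative, assuming the monomer CONFIG exposes that field as the multimer config above does):

```python
from alphafold.model import config

cfg = config.model_config('model_1_ptm')
print(cfg.data.common.max_extra_msa)                   # 5120 per the diff above
print(cfg.model.heads.predicted_aligned_error.weight)  # 0.1 for *_ptm models

# ConfigDicts stay mutable, e.g. to request more recycling iterations.
cfg.model.num_recycle = 6
```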
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/data.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/data.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,39 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convenience functions for reading data."""
+
+import io
+import os
+from typing import List
+from alphafold.model import utils
+import haiku as hk
+import numpy as np
+# Internal import (7716).
+
+
+def casp_model_names(data_dir: str) -> List[str]:
+  params = os.listdir(os.path.join(data_dir, 'params'))
+  return [os.path.splitext(filename)[0] for filename in params]
+
+
+def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
+  """Get the Haiku parameters from a model name."""
+
+  path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
+
+  with open(path, 'rb') as f:
+    params = np.load(io.BytesIO(f.read()), allow_pickle=False)
+
+  return utils.flat_params_to_haiku(params)
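A usage sketch for these helpers; `DATA_DIR` is a hypothetical path to a directory holding the downloaded `params/` folder:

```python
from alphafold.model import data

DATA_DIR = '/data/alphafold'  # hypothetical download location

print(data.casp_model_names(DATA_DIR))  # e.g. ['params_model_1', ...]
params = data.get_model_haiku_params(model_name='model_1', data_dir=DATA_DIR)
```

Note the naming convention: `get_model_haiku_params` takes the bare model name ('model_1') and prepends 'params_' itself when building the .npz path.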
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/features.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/features.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,104 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code to generate processed features."""
+import copy
+from typing import List, Mapping, Tuple
+
+from alphafold.model.tf import input_pipeline
+from alphafold.model.tf import proteins_dataset
+
+import ml_collections
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+FeatureDict = Mapping[str, np.ndarray]
+
+
+def make_data_config(
+    config: ml_collections.ConfigDict,
+    num_res: int,
+    ) -> Tuple[ml_collections.ConfigDict, List[str]]:
+  """Makes a data config for the input pipeline."""
+  cfg = copy.deepcopy(config.data)
+
+  feature_names = cfg.common.unsupervised_features
+  if cfg.common.use_templates:
+    feature_names += cfg.common.template_features
+
+  with cfg.unlocked():
+    cfg.eval.crop_size = num_res
+
+  return cfg, feature_names
+
+
+def tf_example_to_features(tf_example: tf.train.Example,
+                           config: ml_collections.ConfigDict,
+                           random_seed: int = 0) -> FeatureDict:
+  """Converts tf_example to numpy feature dictionary."""
+  num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
+  cfg, feature_names = make_data_config(config, num_res=num_res)
+
+  if 'deletion_matrix_int' in set(tf_example.features.feature):
+    deletion_matrix_int = (
+        tf_example.features.feature['deletion_matrix_int'].int64_list.value)
+    feat = tf.train.Feature(float_list=tf.train.FloatList(
+        value=map(float, deletion_matrix_int)))
+    tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
+    del tf_example.features.feature['deletion_matrix_int']
+
+  tf_graph = tf.Graph()
+  with tf_graph.as_default(), tf.device('/device:CPU:0'):
+    tf.compat.v1.set_random_seed(random_seed)
+    tensor_dict = proteins_dataset.create_tensor_dict(
+        raw_data=tf_example.SerializeToString(),
+        features=feature_names)
+    processed_batch = input_pipeline.process_tensors_from_config(
+        tensor_dict, cfg)
+
+  tf_graph.finalize()
+
+  with tf.Session(graph=tf_graph) as sess:
+    features = sess.run(processed_batch)
+
+  return {k: v for k, v in features.items() if v.dtype != 'O'}
+
+
+def np_example_to_features(np_example: FeatureDict,
+                           config: ml_collections.ConfigDict,
+                           random_seed: int = 0) -> FeatureDict:
+  """Preprocesses NumPy feature dict using TF pipeline."""
+  np_example = dict(np_example)
+  num_res = int(np_example['seq_length'][0])
+  cfg, feature_names = make_data_config(config, num_res=num_res)
+
+  if 'deletion_matrix_int' in np_example:
+    np_example['deletion_matrix'] = (
+        np_example.pop('deletion_matrix_int').astype(np.float32))
+
+  tf_graph = tf.Graph()
+  with tf_graph.as_default(), tf.device('/device:CPU:0'):
+    tf.compat.v1.set_random_seed(random_seed)
+    tensor_dict = proteins_dataset.np_to_tensor_dict(
+        np_example=np_example, features=feature_names)
+
+    processed_batch = input_pipeline.process_tensors_from_config(
+        tensor_dict, cfg)
+
+  tf_graph.finalize()
+
+  with tf.Session(graph=tf_graph) as sess:
+    features = sess.run(processed_batch)
+
+  return {k: v for k, v in features.items() if v.dtype != 'O'}
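Both converters first promote the integer deletion matrix to float. That step, sketched standalone on a fabricated example (a real `np_example` additionally needs the full feature set produced by the data pipeline before `features.np_example_to_features(raw_features, cfg, random_seed=0)` can run):

```python
import numpy as np

# Fabricated two-residue, two-sequence deletion counts.
np_example = {'deletion_matrix_int': np.array([[0, 2], [1, 0]], dtype=np.int64)}
if 'deletion_matrix_int' in np_example:
    np_example['deletion_matrix'] = (
        np_example.pop('deletion_matrix_int').astype(np.float32))
assert np_example['deletion_matrix'].dtype == np.float32
```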
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/folding.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/folding.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,1009 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modules and utilities for the structure module."""
+
+import functools
+from typing import Dict
+from alphafold.common import residue_constants
+from alphafold.model import all_atom
+from alphafold.model import common_modules
+from alphafold.model import prng
+from alphafold.model import quat_affine
+from alphafold.model import r3
+from alphafold.model import utils
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+
+
+def squared_difference(x, y):
+  return jnp.square(x - y)
+
+
+class InvariantPointAttention(hk.Module):
+  """Invariant Point attention module.
+
+  The high-level idea is that this attention module works over a set of points
+  and associated orientations in 3D space (e.g. protein residues).
+
+  Each residue outputs a set of queries and keys as points in their local
+  reference frame. The attention is then defined as the euclidean distance
+  between the queries and keys in the global frame.
+
+  Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
+  """
+
+  def __init__(self,
+               config,
+               global_config,
+               dist_epsilon=1e-8,
+               name='invariant_point_attention'):
+    """Initialize.
+
+    Args:
+      config: Structure Module Config
+      global_config: Global Config of Model.
+      dist_epsilon: Small value to avoid NaN in distance calculation.
+      name: Haiku Module name.
+    """
+    super().__init__(name=name)
+
+    self._dist_epsilon = dist_epsilon
+    self._zero_initialize_last = global_config.zero_init
+
+    self.config = config
+
+    self.global_config = global_config
+
+  def __call__(self, inputs_1d, inputs_2d, mask, affine):
+    """Compute geometry-aware attention.
+
+    Given a set of query residues (defined by affines and associated scalar
+    features), this function computes geometry-aware attention between the
+    query residues and target residues.
+
+    The residues produce points in their local reference frame, which
+    are converted into the global frame in order to compute attention via
+    euclidean distance.
+
+    Equivalently, the target residues produce points in their local frame to be
+    used as attention values, which are converted into the query residues'
+    local frames.
+
+    Args:
+      inputs_1d: (N, C) 1D input embedding that is the basis for the
+        scalar queries.
+      inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
+      mask: (N, 1) mask to indicate which elements of inputs_1d participate
+        in the attention.
+      affine: QuatAffine object describing the position and orientation of
+        every element in inputs_1d.
+
+    Returns:
+      Transformation of the input embedding.
+    """
+    num_residues, _ = inputs_1d.shape
+
+    # Improve readability by removing a large number of 'self's.
+    num_head = self.config.num_head
+    num_scalar_qk = self.config.num_scalar_qk
+    num_point_qk = self.config.num_point_qk
+    num_scalar_v = self.config.num_scalar_v
+    num_point_v = self.config.num_point_v
+    num_output = self.config.num_channel
+
+    assert num_scalar_qk > 0
+    assert num_point_qk > 0
+    assert num_point_v > 0
+
+    # Construct scalar queries of shape:
+    # [num_query_residues, num_head, num_points]
+    q_scalar = common_modules.Linear(
[...]
+      squared_difference(true_chi_shifted, pred_angles), -1)
+  sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)
+
+  sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
+  ret['chi_loss'] = sq_chi_loss
+  ret['loss'] += config.chi_weight * sq_chi_loss
+  unnormed_angles = jnp.reshape(
+      value['sidechains']['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2])
+  angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
+  norm_error = jnp.abs(angle_norm - 1.)
+  angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None],
+                                    value=norm_error)
+
+  ret['angle_norm_loss'] = angle_norm_loss
+  ret['loss'] += config.angle_norm_weight * angle_norm_loss
+
+
+def generate_new_affine(sequence_mask):
+  num_residues, _ = sequence_mask.shape
+  quaternion = jnp.tile(
+      jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4]),
+      [num_residues, 1])
+
+  translation = jnp.zeros([num_residues, 3])
+  return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True)
+
+
+def l2_normalize(x, axis=-1, epsilon=1e-12):
+  return x / jnp.sqrt(
+      jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))
+
+
+class MultiRigidSidechain(hk.Module):
+  """Class to make side chain atoms."""
+
+  def __init__(self, config, global_config, name='rigid_sidechain'):
+    super().__init__(name=name)
+    self.config = config
+    self.global_config = global_config
+
+  def __call__(self, affine, representations_list, aatype):
+    """Predict side chains using multi-rigid representations.
+
+    Args:
+      affine: The affines for each residue (translations in angstroms).
+      representations_list: A list of activations to predict side chains from.
+      aatype: Amino acid types.
+
+    Returns:
+      Dict containing atom positions and frames (in angstroms).
+    """
+    act = [
+        common_modules.Linear(  # pylint: disable=g-complex-comprehension
+            self.config.num_channel,
+            name='input_projection')(jax.nn.relu(x))
+        for x in representations_list
+    ]
+    # Sum the activation list (equivalent to concat then Linear).
+    act = sum(act)
+
+    final_init = 'zeros' if self.global_config.zero_init else 'linear'
+
+    # Mapping with some residual blocks.
+    for _ in range(self.config.num_residual_block):
+      old_act = act
+      act = common_modules.Linear(
+          self.config.num_channel,
+          initializer='relu',
+          name='resblock1')(
+              jax.nn.relu(act))
+      act = common_modules.Linear(
+          self.config.num_channel,
+          initializer=final_init,
+          name='resblock2')(
+              jax.nn.relu(act))
+      act += old_act
+
+    # Map activations to torsion angles. Shape: (num_res, 14).
+    num_res = act.shape[0]
+    unnormalized_angles = common_modules.Linear(
+        14, name='unnormalized_angles')(
+            jax.nn.relu(act))
+    unnormalized_angles = jnp.reshape(
+        unnormalized_angles, [num_res, 7, 2])
+    angles = l2_normalize(unnormalized_angles, axis=-1)
+
+    outputs = {
+        'angles_sin_cos': angles,  # jnp.ndarray (N, 7, 2)
+        'unnormalized_angles_sin_cos':
+            unnormalized_angles,  # jnp.ndarray (N, 7, 2)
+    }
+
+    # Map torsion angles to frames.
+    backb_to_global = r3.rigids_from_quataffine(affine)
+
+    # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"
+
+    # r3.Rigids with shape (N, 8).
+    all_frames_to_global = all_atom.torsion_angles_to_frames(
+        aatype,
+        backb_to_global,
+        angles)
+
+    # Use frames and literature positions to create the final atom coordinates.
+    # r3.Vecs with shape (N, 14).
+    pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos(
+        aatype, all_frames_to_global)
+
+    outputs.update({
+        'atom_pos': pred_positions,  # r3.Vecs (N, 14)
+        'frames': all_frames_to_global,  # r3.Rigids (N, 8)
+    })
+    return outputs
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/folding_multimer.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/folding_multimer.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,1160 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modules and utilities for the structure module in the multimer system."""
+
+import functools
+import numbers
+from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union
+
+from alphafold.common import residue_constants
+from alphafold.model import all_atom_multimer
+from alphafold.model import common_modules
+from alphafold.model import geometry
+from alphafold.model import modules
+from alphafold.model import prng
+from alphafold.model import utils
+from alphafold.model.geometry import utils as geometry_utils
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+
+
+EPSILON = 1e-8
+Float = Union[float, jnp.ndarray]
+
+
+def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  """Computes Squared difference between two arrays."""
+  return jnp.square(x - y)
+
+
+def make_backbone_affine(
+    positions: geometry.Vec3Array,
+    mask: jnp.ndarray,
+    aatype: jnp.ndarray,
+    ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]:
+  """Make backbone Rigid3Array and mask."""
+  del aatype
+  a = residue_constants.atom_order['N']
+  b = residue_constants.atom_order['CA']
+  c = residue_constants.atom_order['C']
+
+  rigid_mask = (mask[:, a] * mask[:, b] * mask[:, c]).astype(
+      jnp.float32)
+
+  rigid = all_atom_multimer.make_transform_from_reference(
+      a_xyz=positions[:, a], b_xyz=positions[:, b], c_xyz=positions[:, c])
+
+  return rigid, rigid_mask
+
+
+class QuatRigid(hk.Module):
+  """Module for projecting Rigids via a quaternion."""
+
+  def __init__(self,
+               global_config: ml_collections.ConfigDict,
+               rigid_shape: Union[int, Iterable[int]] = tuple(),
+               full_quat: bool = False,
+               init: str = 'zeros',
+               name: str = 'quat_rigid'):
+    """Module projecting a Rigid Object.
+
+    For this Module the Rotation is parametrized as a quaternion,
+    If 'full_quat' is True a 4 vector is produced for the rotation which is
+    normalized and treated as a quaternion.
+    When 'full_quat' is False a 3 vector is produced and the 1st component of
+    the quaternion is set to 1.
+
+    Args:
+      global_config: Global Config, used to set certain properties of underlying
+        Linear module, see common_modules.Linear for details.
+      rigid_shape: Shape of Rigids relative to shape of activations, e.g. when
+        activations have shape (n,) and this is (m,) output will be (n, m)
+      full_quat: Whether to parametrize rotation using full quaternion.
+      init: initializer to use, see common_modules.Linear for details
+      name: Name to use for module.
+    """
+    self.init = init
+    self.global_config = global_config
+    if isinstance(rigid_shape, int):
+      self.rigid_shape = (rigid_shape,)
+    else:
+      self.rigid_shape = tuple(rigid_shape)
+    self.full_quat = full_quat
+    super(QuatRigid, self).__init__(name=name)
+
+  def __call__(self, activations: jnp.ndarray) -> geometry.Rigid3Array:
+    """Executes Module.
+
+    This returns a set of rigid with the same shape as activations, projecting
+    the channel dimension, rigid_shape controls the trailing dimensions.
+    For example when activations is shape (12, 5) and rigid_shape is (3, 2)
+    then the shape of the output rigids will be (12, 3, 2).
+    This also supports passing in an
[...]
+          config.angle_norm_weight * angle_norm_loss)
+  return loss, sq_chi_loss, angle_norm_loss
+
+
+def l2_normalize(x: jnp.ndarray,
+                 axis: int = -1,
+                 epsilon: float = 1e-12
+                 ) -> jnp.ndarray:
+  return x / jnp.sqrt(
+      jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))
+
+
+def get_renamed_chi_angles(aatype: jnp.ndarray,
+                           chi_angles: jnp.ndarray,
+                           alt_is_better: jnp.ndarray
+                           ) -> jnp.ndarray:
+  """Return renamed chi angles."""
+  chi_angle_is_ambiguous = utils.batched_gather(
+      jnp.array(residue_constants.chi_pi_periodic, dtype=jnp.float32), aatype)
+  alt_chi_angles = chi_angles + np.pi * chi_angle_is_ambiguous
+  # Map back to [-pi, pi].
+  alt_chi_angles = alt_chi_angles - 2 * np.pi * (alt_chi_angles > np.pi).astype(
+      jnp.float32)
+  alt_is_better = alt_is_better[:, None]
+  return (1. - alt_is_better) * chi_angles + alt_is_better * alt_chi_angles
+
+
+class MultiRigidSidechain(hk.Module):
+  """Class to make side chain atoms."""
+
+  def __init__(self,
+               config: ml_collections.ConfigDict,
+               global_config: ml_collections.ConfigDict,
+               name: str = 'rigid_sidechain'):
+    super().__init__(name=name)
+    self.config = config
+    self.global_config = global_config
+
+  def __call__(self,
+               rigid: geometry.Rigid3Array,
+               representations_list: Iterable[jnp.ndarray],
+               aatype: jnp.ndarray
+               ) -> Dict[str, Any]:
+    """Predict sidechains using multi-rigid representations.
+
+    Args:
+      rigid: The Rigid's for each residue (translations in angstoms)
+      representations_list: A list of activations to predict sidechains from.
+      aatype: amino acid types.
+
+    Returns:
+      dict containing atom positions and frames (in angstrom)
+    """
+    act = [
+        common_modules.Linear(  # pylint: disable=g-complex-comprehension
+            self.config.num_channel,
+            name='input_projection')(jax.nn.relu(x))
+        for x in representations_list]
+    # Sum the activation list (equivalent to concat then Conv1D)
+    act = sum(act)
+
+    final_init = 'zeros' if self.global_config.zero_init else 'linear'
+
+    # Mapping with some residual blocks.
+    for _ in range(self.config.num_residual_block):
+      old_act = act
+      act = common_modules.Linear(
+          self.config.num_channel,
+          initializer='relu',
+          name='resblock1')(
+              jax.nn.relu(act))
+      act = common_modules.Linear(
+          self.config.num_channel,
+          initializer=final_init,
+          name='resblock2')(
+              jax.nn.relu(act))
+      act += old_act
+
+    # Map activations to torsion angles.
+    # [batch_size, num_res, 14]
+    num_res = act.shape[0]
+    unnormalized_angles = common_modules.Linear(
+        14, name='unnormalized_angles')(
+            jax.nn.relu(act))
+    unnormalized_angles = jnp.reshape(
+        unnormalized_angles, [num_res, 7, 2])
+    angles = l2_normalize(unnormalized_angles, axis=-1)
+
+    outputs = {
+        'angles_sin_cos': angles,  # jnp.ndarray (N, 7, 2)
+        'unnormalized_angles_sin_cos':
+            unnormalized_angles,  # jnp.ndarray (N, 7, 2)
+    }
+
+    # Map torsion angles to frames.
+    # geometry.Rigid3Array with shape (N, 8)
+    all_frames_to_global = all_atom_multimer.torsion_angles_to_frames(
+        aatype,
+        rigid,
+        angles)
+
+    # Use frames and literature positions to create the final atom coordinates.
+    # geometry.Vec3Array with shape (N, 14)
+    pred_positions = all_atom_multimer.frames_and_literature_positions_to_atom14_pos(
+        aatype, all_frames_to_global)
+
+    outputs.update({
+        'atom_pos': pred_positions,  # geometry.Vec3Array (N, 14)
+        'frames': all_frames_to_global,  # geometry.Rigid3Array (N, 8)
+    })
+    return outputs
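The wrap-around arithmetic in `get_renamed_chi_angles`, sketched standalone with made-up angles:

```python
import numpy as np

chi = np.array([2.5, -1.0])              # made-up chi angles (radians)
alt = chi + np.pi                        # 180-degree symmetric alternative
alt = alt - 2 * np.pi * (alt > np.pi)    # wrap back into [-pi, pi]
print(alt)  # [-0.6416..., 2.1416...]: only the first angle needed wrapping
```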
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/__init__.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/geometry/__init__.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,31 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Geometry Module."""
+
+from alphafold.model.geometry import rigid_matrix_vector
+from alphafold.model.geometry import rotation_matrix
+from alphafold.model.geometry import struct_of_array
+from alphafold.model.geometry import vector
+
+Rot3Array = rotation_matrix.Rot3Array
+Rigid3Array = rigid_matrix_vector.Rigid3Array
+
+StructOfArray = struct_of_array.StructOfArray
+
+Vec3Array = vector.Vec3Array
+square_euclidean_distance = vector.square_euclidean_distance
+euclidean_distance = vector.euclidean_distance
+dihedral_angle = vector.dihedral_angle
+dot = vector.dot
+cross = vector.cross
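A usage sketch of the re-exported geometry API (coordinates are made up; the sign of the dihedral depends on the library's convention, but a planar trans arrangement should give magnitude pi):

```python
import jax.numpy as jnp

from alphafold.model import geometry

a = geometry.Vec3Array(jnp.array(1.0), jnp.array(1.0), jnp.array(0.0))
b = geometry.Vec3Array(jnp.array(1.0), jnp.array(0.0), jnp.array(0.0))
c = geometry.Vec3Array(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0))
d = geometry.Vec3Array(jnp.array(0.0), jnp.array(-1.0), jnp.array(0.0))

print(geometry.dihedral_angle(a, b, c, d))  # ~ +/- pi (trans arrangement)
print(geometry.euclidean_distance(a, b))    # 1.0
```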
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/rigid_matrix_vector.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/geometry/rigid_matrix_vector.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,106 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rigid3Array Transformations represented by a Matrix and a Vector."""
+
+from __future__ import annotations
+from typing import Union
+
+from alphafold.model.geometry import rotation_matrix
+from alphafold.model.geometry import struct_of_array
+from alphafold.model.geometry import vector
+import jax
+import jax.numpy as jnp
+
+Float = Union[float, jnp.ndarray]
+
+VERSION = '0.1'
+
+
+@struct_of_array.StructOfArray(same_dtype=True)
+class Rigid3Array:
+  """Rigid Transformation, i.e. element of special euclidean group."""
+
+  rotation: rotation_matrix.Rot3Array
+  translation: vector.Vec3Array
+
+  def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
+    new_rotation = self.rotation @ other.rotation
+    new_translation = self.apply_to_point(other.translation)
+    return Rigid3Array(new_rotation, new_translation)
+
+  def inverse(self) -> Rigid3Array:
+    """Return Rigid3Array corresponding to inverse transform."""
+    inv_rotation = self.rotation.inverse()
+    inv_translation = inv_rotation.apply_to_point(-self.translation)
+    return Rigid3Array(inv_rotation, inv_translation)
+
+  def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+    """Apply Rigid3Array transform to point."""
+    return self.rotation.apply_to_point(point) + self.translation
+
+  def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+    """Apply inverse Rigid3Array transform to point."""
+    new_point = point - self.translation
+    return self.rotation.apply_inverse_to_point(new_point)
+
+  def compose_rotation(self, other_rotation):
+    rot = self.rotation @ other_rotation
+    trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape),
+                         self.translation)
+    return Rigid3Array(rot, trans)
+
+  @classmethod
+  def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array:
+    """Return identity Rigid3Array of given shape."""
+    return cls(
+        rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
+        vector.Vec3Array.zeros(shape, dtype=dtype))
+
+  def scale_translation(self, factor: Float) -> Rigid3Array:
+    """Scale translation in Rigid3Array by 'factor'."""
+    return Rigid3Array(self.rotation, self.translation * factor)
+
+  def to_array(self):
+    rot_array = self.rotation.to_array()
+    vec_array = self.translation.to_array()
+    return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)
+
+  @classmethod
+  def from_array(cls, array):
+    rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
+    vec = vector.Vec3Array.from_array(array[..., -1])
+    return cls(rot, vec)
+
+  @classmethod
+  def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
+    """Construct Rigid3Array from homogeneous 4x4 array."""
+    assert array.shape[-1] == 4
+    assert array.shape[-2] == 4
+    rotation = rotation_matrix.Rot3Array(
+        array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
+        array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
+        array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
+        )
+    translation = vector.Vec3Array(
+        array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
+    return cls(rotation, translation)
+
+  def __getstate__(self):
+    return (VERSION, (self.rotation, self.translation))
+
+  def __setstate__(self, state):
+    version, (rot, trans) = state
+    del version
+    object.__setattr__(self, 'rotation', rot)
+    object.__setattr__(self, 'translation', trans)
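A round-trip sketch for Rigid3Array (the 4x4 matrix below encodes a 90-degree rotation about x plus a made-up translation):

```python
import jax.numpy as jnp

from alphafold.model import geometry

rigid = geometry.Rigid3Array.from_array4x4(jnp.array([
    [1.0, 0.0,  0.0,  2.0],
    [0.0, 0.0, -1.0,  0.0],
    [0.0, 1.0,  0.0, -5.0],
    [0.0, 0.0,  0.0,  1.0],
]))
point = geometry.Vec3Array(jnp.array(1.0), jnp.array(2.0), jnp.array(3.0))
moved = rigid.apply_to_point(point)           # rotate, then translate
back = rigid.inverse().apply_to_point(moved)  # undo the transform
print(back.x, back.y, back.z)                 # ~ (1.0, 2.0, 3.0)
```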
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/rotation_matrix.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/geometry/rotation_matrix.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,157 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rot3Array Matrix Class.""" + +from __future__ import annotations +import dataclasses + +from alphafold.model.geometry import struct_of_array +from alphafold.model.geometry import utils +from alphafold.model.geometry import vector +import jax +import jax.numpy as jnp +import numpy as np + +COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + + xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) + xy: jnp.ndarray + xz: jnp.ndarray + yx: jnp.ndarray + yy: jnp.ndarray + yz: jnp.ndarray + zx: jnp.ndarray + zy: jnp.ndarray + zz: jnp.ndarray + + __array_ufunc__ = None + + def inverse(self) -> Rot3Array: + """Returns inverse of Rot3Array.""" + return Rot3Array(self.xx, self.yx, self.zx, + self.xy, self.yy, self.zy, + self.xz, self.yz, self.zz) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + return vector.Vec3Array( + self.xx * point.x + self.xy * point.y + self.xz * point.z, + self.yx * point.x + self.yy * point.y + self.yz * point.z, + self.zx * point.x + self.zy * point.y + self.zz * point.z) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + def __matmul__(self, other: Rot3Array) -> Rot3Array: + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + @classmethod + def identity(cls, shape, dtype=jnp.float32) -> Rot3Array: + """Returns identity of given shape.""" + ones = jnp.ones(shape, dtype=dtype) + zeros = jnp.zeros(shape, dtype=dtype) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + + @classmethod + def from_two_vectors(cls, e0: vector.Vec3Array, + e1: vector.Vec3Array) -> Rot3Array: + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - c * e0).normalized() + # Compute e2 as cross product of e0 and e1. + e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + @classmethod + def from_array(cls, array: jnp.ndarray) -> Rot3Array: + """Construct Rot3Array Matrix from array of shape. 
[..., 3, 3].""" + unstacked = utils.unstack(array, axis=-2) + unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], []) + return cls(*unstacked) + + def to_array(self) -> jnp.ndarray: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return jnp.stack( + [jnp.stack([self.xx, self.xy, self.xz], axis=-1), + jnp.stack([self.yx, self.yy, self.yz], axis=-1), + jnp.stack([self.zx, self.zy, self.zz], axis=-1)], + axis=-2) + + @classmethod + def from_quaternion(cls, + w: jnp.ndarray, + x: jnp.ndarray, + y: jnp.ndarray, + z: jnp.ndarray, + normalize: bool = True, + epsilon: float = 1e-6) -> Rot3Array: + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2)) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (jnp.square(y) + jnp.square(z)) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (jnp.square(x) + jnp.square(z)) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (jnp.square(x) + jnp.square(y)) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + + @classmethod + def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array: + """Samples uniform random Rot3Array according to Haar Measure.""" + quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype) + quats = utils.unstack(quat_array) + return cls.from_quaternion(*quats) + + def __getstate__(self): + return (VERSION, + [np.asarray(getattr(self, field)) for field in COMPONENTS]) + + def __setstate__(self, state): + version, state = state + del version + for i, field in enumerate(COMPONENTS): + object.__setattr__(self, field, state[i]) |
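A minimal usage sketch of the Rot3Array class above (illustrative only, not part of the changeset; Vec3Array comes from geometry/vector.py, added later in this commit): build a 90-degree rotation about z from a quaternion, apply it to a point, and compose rotations with `@`.

    import jax.numpy as jnp
    from alphafold.model.geometry import rotation_matrix, vector

    # Quaternion (w, x, y, z) of a 90-degree rotation about the z-axis.
    s = jnp.array(0.7071068)
    rot = rotation_matrix.Rot3Array.from_quaternion(s, jnp.array(0.), jnp.array(0.), s)

    p = vector.Vec3Array(jnp.array(1.), jnp.array(0.), jnp.array(0.))
    q = rot.apply_to_point(p)             # approximately (0, 1, 0)
    back = rot.apply_inverse_to_point(q)  # approximately (1, 0, 0)

    half_turn = rot @ rot                         # 180 degrees about z
    eye = rotation_matrix.Rot3Array.identity(())  # neutral element for @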
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/struct_of_array.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/geometry/struct_of_array.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,220 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Class decorator to represent (nested) struct of arrays.""" + +import dataclasses + +import jax + + +def get_item(instance, key): + sliced = {} + for field in get_array_fields(instance): + num_trailing_dims = field.metadata.get('num_trailing_dims', 0) + this_key = key + if isinstance(key, tuple) and Ellipsis in this_key: + this_key += (slice(None),) * num_trailing_dims + sliced[field.name] = getattr(instance, field.name)[this_key] + return dataclasses.replace(instance, **sliced) + + +@property +def get_shape(instance): + """Returns Shape for given instance of dataclass.""" + first_field = dataclasses.fields(instance)[0] + num_trailing_dims = first_field.metadata.get('num_trailing_dims', None) + value = getattr(instance, first_field.name) + if num_trailing_dims: + return value.shape[:-num_trailing_dims] + else: + return value.shape + + +def get_len(instance): + """Returns length for given instance of dataclass.""" + shape = instance.shape + if shape: + return shape[0] + else: + raise TypeError('len() of unsized object') # Match jax.numpy behavior. + + +@property +def get_dtype(instance): + """Returns Dtype for given instance of dataclass.""" + fields = dataclasses.fields(instance) + sets_dtype = [ + field.name for field in fields if field.metadata.get('sets_dtype', False) + ] + if sets_dtype: + assert len(sets_dtype) == 1, 'at most field can set dtype' + field_value = getattr(instance, sets_dtype[0]) + elif instance.same_dtype: + field_value = getattr(instance, fields[0].name) + else: + # Should this be Value Error? + raise AttributeError('Trying to access Dtype on Struct of Array without' + 'either "same_dtype" or field setting dtype') + + if hasattr(field_value, 'dtype'): + return field_value.dtype + else: + # Should this be Value Error? + raise AttributeError(f'field_value {field_value} does not have dtype') + + +def replace(instance, **kwargs): + return dataclasses.replace(instance, **kwargs) + + +def post_init(instance): + """Validate instance has same shapes & dtypes.""" + array_fields = get_array_fields(instance) + arrays = list(get_array_fields(instance, return_values=True).values()) + first_field = array_fields[0] + # These slightly weird constructions about checking whether the leaves are + # actual arrays is since e.g. vmap internally relies on being able to + # construct pytree's with object() as leaves, this would break the checking + # as such we are only validating the object when the entries in the dataclass + # Are arrays or other dataclasses of arrays. 
+ try: + dtype = instance.dtype + except AttributeError: + dtype = None + if dtype is not None: + first_shape = instance.shape + for array, field in zip(arrays, array_fields): + field_shape = array.shape + num_trailing_dims = field.metadata.get('num_trailing_dims', None) + if num_trailing_dims: + array_shape = array.shape + field_shape = array_shape[:-num_trailing_dims] + msg = (f'field {field} should have number of trailing dims' + ' {num_trailing_dims}') + assert len(array_shape) == len(first_shape) + num_trailing_dims, msg + else: + field_shape = array.shape + + shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't " + f"match shape {first_shape} of field {first_field}") + assert field_shape == first_shape, shape_msg + + field_dtype = array.dtype + + allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', []) + if allowed_metadata_dtypes: + msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}' + assert field_dtype in allowed_metadata_dtypes, msg + + if 'dtype' in field.metadata: + target_dtype = field.metadata['dtype'] + else: + target_dtype = dtype + + msg = f'Dtype is {field_dtype} but must be {target_dtype}' + assert field_dtype == target_dtype, msg + + +def flatten(instance): + """Flatten Struct of Array instance.""" + array_likes = list(get_array_fields(instance, return_values=True).values()) + flat_array_likes = [] + inner_treedefs = [] + num_arrays = [] + for array_like in array_likes: + flat_array_like, inner_treedef = jax.tree_flatten(array_like) + inner_treedefs.append(inner_treedef) + flat_array_likes += flat_array_like + num_arrays.append(len(flat_array_like)) + metadata = get_metadata_fields(instance, return_values=True) + metadata = type(instance).metadata_cls(**metadata) + return flat_array_likes, (inner_treedefs, metadata, num_arrays) + + +def make_metadata_class(cls): + metadata_fields = get_fields(cls, + lambda x: x.metadata.get('is_metadata', False)) + metadata_cls = dataclasses.make_dataclass( + cls_name='Meta' + cls.__name__, + fields=[(field.name, field.type, field) for field in metadata_fields], + frozen=True, + eq=True) + return metadata_cls + + +def get_fields(cls_or_instance, filterfn, return_values=False): + fields = dataclasses.fields(cls_or_instance) + fields = [field for field in fields if filterfn(field)] + if return_values: + return { + field.name: getattr(cls_or_instance, field.name) for field in fields + } + else: + return fields + + +def get_array_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: not x.metadata.get('is_metadata', False), + return_values=return_values) + + +def get_metadata_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: x.metadata.get('is_metadata', False), + return_values=return_values) + + +class StructOfArray: + """Class Decorator for Struct Of Arrays.""" + + def __init__(self, same_dtype=True): + self.same_dtype = same_dtype + + def __call__(self, cls): + cls.__array_ufunc__ = None + cls.replace = replace + cls.same_dtype = self.same_dtype + cls.dtype = get_dtype + cls.shape = get_shape + cls.__len__ = get_len + cls.__getitem__ = get_item + cls.__post_init__ = post_init + new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args + # pytree claims to require metadata to be hashable, not sure why, + # But making derived dataclass that can just hold metadata + new_cls.metadata_cls = make_metadata_class(new_cls) + + def unflatten(aux, data): + inner_treedefs, metadata, num_arrays = aux + array_fields = 
[field.name for field in get_array_fields(new_cls)] + value_dict = {} + array_start = 0 + for num_array, inner_treedef, array_field in zip(num_arrays, + inner_treedefs, + array_fields): + value_dict[array_field] = jax.tree_unflatten( + inner_treedef, data[array_start:array_start + num_array]) + array_start += num_array + metadata_fields = get_metadata_fields(new_cls) + for field in metadata_fields: + value_dict[field.name] = getattr(metadata, field.name) + + return new_cls(**value_dict) + + jax.tree_util.register_pytree_node( + nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten) + return new_cls |
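What the decorator buys in practice (an illustrative sketch, not part of the changeset): decorated classes such as Rot3Array above get `shape`, `len` and indexing for free, and are registered as JAX pytree nodes, so they pass through `jit`, `vmap` and `tree_map` unchanged.

    import jax
    import jax.numpy as jnp
    from alphafold.model.geometry import rotation_matrix

    rots = rotation_matrix.Rot3Array.identity((4,))  # a batch of 4 identity rotations

    assert rots.shape == (4,)  # shape/len/__getitem__ come from the decorator
    assert len(rots) == 4
    first = rots[0]

    composed = jax.jit(lambda r: r @ r)(rots)        # pytree in, pytree out
    scaled = jax.tree_map(lambda x: 2.0 * x, rots)   # maps over all 9 components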
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/test_utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/geometry/test_utils.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,98 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared utils for tests.""" + +import dataclasses + +from alphafold.model.geometry import rigid_matrix_vector +from alphafold.model.geometry import rotation_matrix +from alphafold.model.geometry import vector +import jax.numpy as jnp +import numpy as np + + +def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, + matrix2: rotation_matrix.Rot3Array): + for field in dataclasses.fields(rotation_matrix.Rot3Array): + field = field.name + np.testing.assert_array_equal( + getattr(matrix1, field), getattr(matrix2, field)) + + +def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, + mat2: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6) + + +def assert_array_equal_to_rotation_matrix(array: jnp.ndarray, + matrix: rotation_matrix.Rot3Array): + """Check that array and Matrix match.""" + np.testing.assert_array_equal(matrix.xx, array[..., 0, 0]) + np.testing.assert_array_equal(matrix.xy, array[..., 0, 1]) + np.testing.assert_array_equal(matrix.xz, array[..., 0, 2]) + np.testing.assert_array_equal(matrix.yx, array[..., 1, 0]) + np.testing.assert_array_equal(matrix.yy, array[..., 1, 1]) + np.testing.assert_array_equal(matrix.yz, array[..., 1, 2]) + np.testing.assert_array_equal(matrix.zx, array[..., 2, 0]) + np.testing.assert_array_equal(matrix.zy, array[..., 2, 1]) + np.testing.assert_array_equal(matrix.zz, array[..., 2, 2]) + + +def assert_array_close_to_rotation_matrix(array: jnp.ndarray, + matrix: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(matrix.to_array(), array, 6) + + +def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_array_equal(vec1.x, vec2.x) + np.testing.assert_array_equal(vec1.y, vec2.y) + np.testing.assert_array_equal(vec1.z, vec2.z) + + +def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) + + +def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): + np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.) 
+ + +def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): + np.testing.assert_array_equal(vec.to_array(), array) + + +def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_equal(rot, rigid.rotation) + assert_vectors_equal(trans, rigid.translation) + + +def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_close(rot, rigid.rotation) + assert_vectors_close(trans, rigid.translation) |
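For orientation, a sketch of how these helpers read in a test (illustrative, not from the commit): the `*_close` variants compare with an absolute tolerance of 1e-6, while the `*_equal` variants are exact.

    import jax.numpy as jnp
    from alphafold.model.geometry import test_utils, vector

    v1 = vector.Vec3Array(jnp.array(1.0), jnp.array(2.0), jnp.array(3.0))
    v2 = vector.Vec3Array(jnp.array(1.0), jnp.array(2.0), jnp.array(3.0 + 1e-8))

    test_utils.assert_vectors_equal(v1, v1)  # exact, element-wise
    test_utils.assert_vectors_close(v1, v2)  # passes: difference is below atol=1e-6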
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/geometry/utils.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,23 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for geometry library.""" + +from typing import List + +import jax.numpy as jnp + + +def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]: + return [jnp.squeeze(v, axis=axis) + for v in jnp.split(value, value.shape[axis], axis=axis)] |
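`unstack` is the small workhorse behind the `from_array`/`to_array` conversions in the classes above: it splits an array along one axis into a list of squeezed slices, i.e. the inverse of `jnp.stack`. A quick illustrative check (not part of the changeset):

    import jax.numpy as jnp
    from alphafold.model.geometry import utils

    arr = jnp.arange(6.).reshape(2, 3)
    cols = utils.unstack(arr, axis=-1)  # three arrays, each of shape (2,)
    assert len(cols) == 3 and cols[0].shape == (2,)
    assert bool((jnp.stack(cols, axis=-1) == arr).all())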
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/geometry/vector.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/geometry/vector.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,217 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vec3Array Class.""" + +from __future__ import annotations +import dataclasses +from typing import Union + +from alphafold.model.geometry import struct_of_array +from alphafold.model.geometry import utils +import jax +import jax.numpy as jnp +import numpy as np + +Float = Union[float, jnp.ndarray] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Vec3Array: + """Vec3Array in 3 dimensional Space implemented as struct of arrays. + + This is done in order to improve performance and precision. + On TPU small matrix multiplications are very suboptimal and will waste large + compute ressources, furthermore any matrix multiplication on tpu happen in + mixed bfloat16/float32 precision, which is often undesirable when handling + physical coordinates. + In most cases this will also be faster on cpu's/gpu's since it allows for + easier use of vector instructions. + """ + + x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) + y: jnp.ndarray + z: jnp.ndarray + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + assert self.x.dtype == self.y.dtype + assert self.x.dtype == self.z.dtype + assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) + assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) + + def __add__(self, other: Vec3Array) -> Vec3Array: + return jax.tree_multimap(lambda x, y: x + y, self, other) + + def __sub__(self, other: Vec3Array) -> Vec3Array: + return jax.tree_multimap(lambda x, y: x - y, self, other) + + def __mul__(self, other: Float) -> Vec3Array: + return jax.tree_map(lambda x: x * other, self) + + def __rmul__(self, other: Float) -> Vec3Array: + return self * other + + def __truediv__(self, other: Float) -> Vec3Array: + return jax.tree_map(lambda x: x / other, self) + + def __neg__(self) -> Vec3Array: + return jax.tree_map(lambda x: -x, self) + + def __pos__(self) -> Vec3Array: + return jax.tree_map(lambda x: x, self) + + def cross(self, other: Vec3Array) -> Vec3Array: + """Compute cross product between 'self' and 'other'.""" + new_x = self.y * other.z - self.z * other.y + new_y = self.z * other.x - self.x * other.z + new_z = self.x * other.y - self.y * other.x + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Vec3Array) -> Float: + """Compute dot product between 'self' and 'other'.""" + return self.x * other.x + self.y * other.y + self.z * other.z + + def norm(self, epsilon: float = 1e-6) -> Float: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = jnp.maximum(norm2, epsilon**2) + return jnp.sqrt(norm2) + + def norm2(self): + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Vec3Array: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + @classmethod + def zeros(cls, shape, dtype=jnp.float32): + 
"""Return Vec3Array corresponding to zeros of given shape.""" + return cls( + jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), + jnp.zeros(shape, dtype)) + + def to_array(self) -> jnp.ndarray: + return jnp.stack([self.x, self.y, self.z], axis=-1) + + @classmethod + def from_array(cls, array): + return cls(*utils.unstack(array)) + + def __getstate__(self): + return (VERSION, + [np.asarray(self.x), + np.asarray(self.y), + np.asarray(self.z)]) + + def __setstate__(self, state): + version, state = state + del version + for i, letter in enumerate('xyz'): + object.__setattr__(self, letter, state[i]) + + +def square_euclidean_distance(vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6) -> Float: + """Computes square of euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute distance to + vec2: Vec3Array to compute distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of square euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + difference = vec1 - vec2 + distance = difference.dot(difference) + if epsilon: + distance = jnp.maximum(distance, epsilon) + return distance + + +def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.dot(vector2) + + +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.cross(vector2) + + +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: + return vector.norm(epsilon) + + +def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: + return vector.normalized(epsilon) + + +def euclidean_distance(vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6) -> Float: + """Computes euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute euclidean distance to + vec2: Vec3Array to compute euclidean distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) + distance = jnp.sqrt(distance_sq) + return distance + + +def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, + d: Vec3Array) -> Float: + """Computes torsion angle for a quadruple of points. + + For points (a, b, c, d), this is the angle between the planes defined by + points (a, b, c) and (b, c, d). It is also known as the dihedral angle. + + Arguments: + a: A Vec3Array of coordinates. + b: A Vec3Array of coordinates. + c: A Vec3Array of coordinates. + d: A Vec3Array of coordinates. + + Returns: + A tensor of angles in radians: [-pi, pi]. + """ + v1 = a - b + v2 = b - c + v3 = d - c + + c1 = v1.cross(v2) + c2 = v3.cross(v2) + c3 = c2.cross(c1) + + v2_mag = v2.norm() + return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2)) + + +def random_gaussian_vector(shape, key, dtype=jnp.float32): + vec_array = jax.random.normal(key, shape + (3,), dtype) + return Vec3Array.from_array(vec_array) |
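A small sketch of the free functions above (illustrative, not part of the changeset). The four points are chosen by hand so that a and d sit on the same side of the b-c axis (cis), which makes the dihedral 0; placing d on the opposite side would give pi.

    import jax.numpy as jnp
    from alphafold.model.geometry import vector

    a = vector.Vec3Array.from_array(jnp.array([0., 0., 1.]))
    b = vector.Vec3Array.from_array(jnp.array([0., 0., 0.]))
    c = vector.Vec3Array.from_array(jnp.array([1., 0., 0.]))
    d = vector.Vec3Array.from_array(jnp.array([1., 0., 1.]))  # cis relative to a

    angle = vector.dihedral_angle(a, b, c, d)  # ~0.0 radians
    dist = vector.euclidean_distance(a, d)     # 1.0
    unit = vector.normalized(d - b)            # unit-length Vec3Array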
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/layer_stack.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/layer_stack.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,274 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Function to stack repeats of a layer function without shared parameters."""
+
+import collections
+import contextlib
+import functools
+import inspect
+from typing import Any, Callable, Optional, Tuple, Union
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
+LayerStackScanned = collections.namedtuple('LayerStackScanned',
+                                           ['i', 'args_ys'])
+
+# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the
+# exact same type. We cannot express this with `typing`. So we just use it
+# to inform the user. In reality, the typing below will accept anything.
+NestedArray = Any
+WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray]]]
+
+
+def _check_no_varargs(f):
+  if list(inspect.signature(
+      f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL:
+    raise ValueError(
+        'The function `f` should not have any `varargs` (that is *args) '
+        'argument. Instead, it should only use explicit positional'
+        'arguments.')
+
+
+@contextlib.contextmanager
+def nullcontext():
+  yield
+
+
+def maybe_with_rng(key):
+  if key is not None:
+    return hk.with_rng(key)
+  else:
+    return nullcontext()
+
+
+def maybe_fold_in(key, data):
+  if key is not None:
+    return jax.random.fold_in(key, data)
+  else:
+    return None
+
+
+class _LayerStack(hk.Module):
+  """Module to compose parameterized functions, implemented as a scan."""
+
+  def __init__(self,
+               count: int,
+               unroll: int,
+               name: Optional[str] = None):
+    """Iterate a function `f` `count` times, with non-shared parameters."""
+    super().__init__(name=name)
+    self._count = count
+    self._unroll = unroll
+
+  def __call__(self, x, *args_ys):
+    count = self._count
+    if hk.running_init():
+      # At initialization time, we run just one layer but add an extra first
+      # dimension to every initialized tensor, making sure to use different
+      # random keys for different slices.
+      def creator(next_creator, shape, dtype, init, context):
+        del context
+
+        def multi_init(shape, dtype):
+          assert shape[0] == count
+          key = hk.maybe_next_rng_key()
+
+          def rng_context_init(slice_idx):
+            slice_key = maybe_fold_in(key, slice_idx)
+            with maybe_with_rng(slice_key):
+              return init(shape[1:], dtype)
+
+          return jax.vmap(rng_context_init)(jnp.arange(count))
+
+        return next_creator((count,) + tuple(shape), dtype, multi_init)
+
+      def getter(next_getter, value, context):
+        trailing_dims = len(context.original_shape) + 1
+        sliced_value = jax.lax.index_in_dim(
+            value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
+        return next_getter(sliced_value)
+
+      with hk.experimental.custom_creator(
+          creator), hk.experimental.custom_getter(getter):
+        if len(args_ys) == 1 and args_ys[0] is None:
+          args0 = (None,)
+        else:
+          args0 = [
+              jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
+              for ys in args_ys
+          ]
+        x, z = self._call_wrapped(x, *args0)
+        if z is None:
+          return x, z
+
+        # Broadcast state [...](x=out_x, rng=rng), z
+
+      carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
+      scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
+                                  args_ys=args_ys)
+
+      carry, zs = hk.scan(
+          layer, carry, scanned, length=count, unroll=self._unroll)
+      return carry.x, zs
+
+  def _call_wrapped(self,
+                    x: jnp.ndarray,
+                    *args,
+                    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
+    raise NotImplementedError()
+
+
+class _LayerStackNoState(_LayerStack):
+  """_LayerStack impl with no per-layer state provided to the function."""
+
+  def __init__(self,
+               f: WrappedFn,
+               count: int,
+               unroll: int,
+               name: Optional[str] = None):
+    super().__init__(count=count, unroll=unroll, name=name)
+    _check_no_varargs(f)
+    self._f = f
+
+  @hk.transparent
+  def _call_wrapped(self, args, y):
+    del y
+    ret = self._f(*args)
+    if len(args) == 1:
+      # If the function takes a single argument, the wrapped function receives
+      # a tuple of length 1, and therefore it must return a tuple of length 1.
+      ret = (ret,)
+    return ret, None
+
+
+class _LayerStackWithState(_LayerStack):
+  """_LayerStack impl with per-layer state provided to the function."""
+
+  def __init__(self,
+               f: WrappedFn,
+               count: int,
+               unroll: int,
+               name: Optional[str] = None):
+    super().__init__(count=count, unroll=unroll, name=name)
+    self._f = f
+
+  @hk.transparent
+  def _call_wrapped(self, x, *args):
+    return self._f(x, *args)
+
+
+def layer_stack(num_layers: int,
+                with_state=False,
+                unroll: int = 1,
+                name: Optional[str] = None):
+  """Utility to wrap a Haiku function and recursively apply it to an input.
+
+  A function is valid if it uses only explicit position parameters, and
+  its return type matches its input type. The position parameters can be
+  arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
+  that kwargs are not supported, neither are functions with variable number
+  of parameters (specified by `*args`).
+
+  If `with_state=False` then the new, wrapped function can be understood as
+  performing the following:
+  ```
+  for i in range(num_layers):
+    x = f(x)
+  return x
+  ```
+
+  And if `with_state=True`, assuming `f` takes two arguments on top of `x`:
+  ```
+  for i in range(num_layers):
+    x, zs[i] = f(x, ys_0[i], ys_1[i])
+  return x, zs
+  ```
+  The code using `layer_stack` for the above function would be:
+  ```
+  def f(x, y_0, y_1):
+    ...
+    return new_x, z
+  x, zs = layer_stack.layer_stack(num_layers,
+                                  with_state=True)(f)(x, ys_0, ys_1)
+  ```
+
+  Crucially, any parameters created inside `f` will not be shared across
+  iterations.
+
+  Args:
+    num_layers: The number of times to iterate the wrapped function.
+    with_state: Whether or not to pass per-layer state to the wrapped function.
+    unroll: the unroll used by `scan`.
+    name: Name of the Haiku context.
+
+  Returns:
+    Callable that will produce a layer stack when called with a valid function.
+  """
+  def iterate(f):
+    if with_state:
+      @functools.wraps(f)
+      def wrapped(x, *args):
+        for ys in args:
+          assert ys.shape[0] == num_layers
+        return _LayerStackWithState(
+            f, num_layers, unroll=unroll, name=name)(x, *args)
+    else:
+      _check_no_varargs(f)
+      @functools.wraps(f)
+      def wrapped(*args):
+        ret = _LayerStackNoState(
+            f, num_layers, unroll=unroll, name=name)(args, None)[0]
+        if len(args) == 1:
+          # If the function takes a single argument, we must also return a
+          # single value, and not a tuple of length 1.
+          ret = ret[0]
+        return ret
+
+    return wrapped
+  return iterate |
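A runnable sketch of the public entry point (not part of the changeset): eight copies of a residual linear block applied in sequence, each copy with its own parameter slice, all driven by a single `hk.scan`.

    import haiku as hk
    import jax
    import jax.numpy as jnp
    from alphafold.model import layer_stack

    def block(x):
      # Residual update; the output shape must match the input shape.
      return x + hk.Linear(16)(x)

    @hk.transform
    def forward(x):
      return layer_stack.layer_stack(8)(block)(x)

    x = jnp.zeros([4, 16])
    params = forward.init(jax.random.PRNGKey(0), x)
    out = forward.apply(params, None, x)  # shape (4, 16)
    # Every parameter gains a leading dimension of size 8, one slice per layer;
    # for example the linear kernel has shape (8, 16, 16).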
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/layer_stack_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/layer_stack_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,335 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for layer_stack."""
+
+import functools
+from absl.testing import absltest
+from absl.testing import parameterized
+from alphafold.model import layer_stack
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import numpy as np
+import scipy
+
+
+# Suffixes applied by Haiku for repeated module names.
+suffixes = [''] + [f'_{i}' for i in range(1, 100)]
+
+
+def _slice_layers_params(layers_params):
+  sliced_layers_params = {}
+  for k, v in layers_params.items():
+    for inner_k in v:
+      for var_slice, suffix in zip(v[inner_k], suffixes):
+        k_new = k.split('/')[-1] + suffix
+        if k_new not in sliced_layers_params:
+          sliced_layers_params[k_new] = {}
+        sliced_layers_params[k_new][inner_k] = var_slice
+  return sliced_layers_params
+
+
+class LayerStackTest(parameterized.TestCase):
+
+  @parameterized.parameters([1, 2, 4])
+  def test_layer_stack(self, unroll):
+    """Compare layer_stack to the equivalent unrolled stack.
+
+    Tests that the layer_stack application of a Haiku layer function is
+    equivalent to repeatedly applying the layer function in an unrolled loop.
+
+    Args:
+      unroll: Number of unrolled layers.
+    """
+    num_layers = 20
+
+    def inner_fn(x):
+      x += hk.Linear(100, name='linear1')(x)
+      x += hk.Linear(100, name='linear2')(x)
+      return x
+
+    def outer_fn_unrolled(x):
+      for _ in range(num_layers):
+        x = inner_fn(x)
+      return x
+
+    def outer_fn_layer_stack(x):
+      stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
+      return stack(x)
+
+    unrolled_fn = hk.transform(outer_fn_unrolled)
+    layer_stack_fn = hk.transform(outer_fn_layer_stack)
+
+    x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
+
+    rng_init = jax.random.PRNGKey(42)
+
+    params = layer_stack_fn.init(rng_init, x)
+
+    sliced_params = _slice_layers_params(params)
+
+    unrolled_pred = unrolled_fn.apply(sliced_params, None, x)
+    layer_stack_pred = layer_stack_fn.apply(params, None, x)
+
+    np.testing.assert_allclose(unrolled_pred, layer_stack_pred)
+
+  def test_layer_stack_multi_args(self):
+    """Compare layer_stack to the equivalent unrolled stack.
+
+    Similar to `test_layer_stack`, but use a function that takes more than one
+    argument.
+    """
+    num_layers = 20
+
+    def inner_fn(x, y):
+      x_out = x + hk.Linear(100, name='linear1')(y)
+      y_out = y + hk.Linear(100, name='linear2')(x)
+      return x_out, y_out
+
+    def outer_fn_unrolled(x, y):
+      for _ in range(num_layers):
+        x, y = inner_fn(x, y)
+      return x, y
+
+    def outer_fn_layer_stack(x, y):
+      stack = layer_stack.layer_stack(num_layers)(inner_fn)
+      return stack(x, y)
+
+    unrolled_fn = hk.transform(outer_fn_unrolled)
+    layer_stack_fn = hk.transform(outer_fn_layer_stack)
+
+    x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
+    y = jax.random.uniform(jax.random.PRNGKey(1), [10, 256, 100])
+
+    rng_init = jax.random.PRNGKey(42)
+
+    params = layer_stack_fn.init(rng_init, x, y)
+
+    sliced_params = _slice_layers_params(params)
+
+    unrolled_x, unrolled_y = unrolled_fn.apply(sliced_params, None, x, y)
+    layer_stack_x, layer_stack_y = layer_stack_fn.apply(params, None, x, y)
+
+    np.testing.assert_allclose(unrolled_x, layer_stack_x)
+    np.testing [...]
+      x = x + jax.random.normal(hk.next_rng_key())
+      return x
+
+    # Evaluate a bunch of times
+    key, *keys = jax.random.split(jax.random.PRNGKey(7), 1024 + 1)
+    params = add_random.init(key, 0.)
+    apply_fn = jax.jit(add_random.apply)
+    values = [apply_fn(params, key, 0.) for key in keys]
+
+    # Should be roughly N(0, sqrt(n))
+    cdf = scipy.stats.norm(scale=np.sqrt(n)).cdf
+    _, p = scipy.stats.kstest(values, cdf)
+    self.assertLess(0.3, p)
+
+  def test_threading(self):
+    """Test @layer_stack when the function gets per-layer state."""
+    n = 5
+
+    @layer_stack.layer_stack(n, with_state=True)
+    def f(x, y):
+      x = x + y * jax.nn.one_hot(y, len(x)) / 10
+      return x, 2 * y
+
+    @hk.without_apply_rng
+    @hk.transform
+    def g(x, ys):
+      x, zs = f(x, ys)
+      # Check here to catch issues at init time
+      self.assertEqual(zs.shape, (n,))
+      return x, zs
+
+    rng = jax.random.PRNGKey(7)
+    x = np.zeros(n)
+    ys = np.arange(n).astype(np.float32)
+    params = g.init(rng, x, ys)
+    x, zs = g.apply(params, x, ys)
+    self.assertTrue(np.allclose(x, [0, .1, .2, .3, .4]))
+    self.assertTrue(np.all(zs == 2 * ys))
+
+  def test_nested_stacks(self):
+    def stack_fn(x):
+      def layer_fn(x):
+        return hk.Linear(100)(x)
+
+      outer_fn = layer_stack.layer_stack(10)(layer_fn)
+
+      layer_outer = layer_stack.layer_stack(20)(outer_fn)
+      return layer_outer(x)
+
+    hk_mod = hk.transform(stack_fn)
+    apply_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
+
+    params = hk_mod.init(init_rng, jnp.zeros([10, 100]))
+
+    hk_mod.apply(params, apply_rng, jnp.zeros([10, 100]))
+
+    p, = params.values()
+
+    assert p['w'].shape == (10, 20, 100, 100)
+    assert p['b'].shape == (10, 20, 100)
+
+  def test_with_state_multi_args(self):
+    """Test layer_stack with state with multiple arguments."""
+    width = 4
+    batch_size = 5
+    stack_height = 3
+
+    def f_with_multi_args(x, a, b):
+      return hk.Linear(
+          width, w_init=hk.initializers.Constant(
+              jnp.eye(width)))(x) * a + b, None
+
+    @hk.without_apply_rng
+    @hk.transform
+    def hk_fn(x):
+      return layer_stack.layer_stack(
+          stack_height,
+          with_state=True)(f_with_multi_args)(x, jnp.full([stack_height], 2.),
+                                              jnp.ones([stack_height]))
+
+    x = jnp.zeros([batch_size, width])
+    key_seq = hk.PRNGSequence(19)
+    params = hk_fn.init(next(key_seq), x)
+    output, z = hk_fn.apply(params, x)
+    self.assertIsNone(z)
+    self.assertEqual(output.shape, (batch_size, width))
+    np.testing.assert_equal(output, np.full([batch_size, width], 7.))
+
+  def test_with_container_state(self):
+    width = 2
+    batch_size = 2
+    stack_height = 3
+
+    def f_with_container_state(x):
+      hk_layer = hk.Linear(
+          width, w_init=hk.initializers.Constant(jnp.eye(width)))
+      layer_output = hk_layer(x)
+      layer_state = {
+          'raw_output': layer_output,
+          'output_projection': jnp.sum(layer_output)
+      }
+      return layer_output + jnp.ones_like(layer_output), layer_state
+
+    @hk.without_apply_rng
+    @hk.transform
+    def hk_fn(x):
+      return layer_stack.layer_stack(
+          stack_height,
+          with_state=True)(f_with_container_state)(x)
+
+    x = jnp.zeros([batch_size, width])
+    key_seq = hk.PRNGSequence(19)
+    params = hk_fn.init(next(key_seq), x)
+    output, z = hk_fn.apply(params, x)
+    self.assertEqual(z['raw_output'].shape, (stack_height, batch_size, width))
+    self.assertEqual(output.shape, (batch_size, width))
+    self.assertEqual(z['output_projection'].shape, (stack_height,))
+    np.testing.assert_equal(np.sum(z['output_projection']), np.array(12.))
+    np.testing.assert_equal(
+        np.all(z['raw_output'] == np.array([0., 1., 2.])[..., None, None]),
+        np.array(True))
+
+
+if __name__ == '__main__':
+  absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/lddt.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/lddt.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""lDDT protein distance score.""" +import jax.numpy as jnp + + +def lddt(predicted_points, + true_points, + true_points_mask, + cutoff=15., + per_residue=False): + """Measure (approximate) lDDT for a batch of coordinates. + + lDDT reference: + Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local + superposition-free score for comparing protein structures and models using + distance difference tests. Bioinformatics 29, 2722–2728 (2013). + + lDDT is a measure of the difference between the true distance matrix and the + distance matrix of the predicted points. The difference is computed only on + points closer than cutoff *in the true structure*. + + This function does not compute the exact lDDT value that the original paper + describes because it does not include terms for physical feasibility + (e.g. bond length violations). Therefore this is only an approximate + lDDT score. + + Args: + predicted_points: (batch, length, 3) array of predicted 3D points + true_points: (batch, length, 3) array of true 3D points + true_points_mask: (batch, length, 1) binary-valued float array. This mask + should be 1 for points that exist in the true points. + cutoff: Maximum distance for a pair of points to be included + per_residue: If true, return score for each residue. Note that the overall + lDDT is not exactly the mean of the per_residue lDDT's because some + residues have more contacts than others. + + Returns: + An (approximate, see above) lDDT score in the range 0-1. + """ + + assert len(predicted_points.shape) == 3 + assert predicted_points.shape[-1] == 3 + assert true_points_mask.shape[-1] == 1 + assert len(true_points_mask.shape) == 3 + + # Compute true and predicted distance matrices. + dmat_true = jnp.sqrt(1e-10 + jnp.sum( + (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) + + dmat_predicted = jnp.sqrt(1e-10 + jnp.sum( + (predicted_points[:, :, None] - + predicted_points[:, None, :])**2, axis=-1)) + + dists_to_score = ( + (dmat_true < cutoff).astype(jnp.float32) * true_points_mask * + jnp.transpose(true_points_mask, [0, 2, 1]) * + (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction. + ) + + # Shift unscored distances to be far away. + dist_l1 = jnp.abs(dmat_true - dmat_predicted) + + # True lDDT uses a number of fixed bins. + # We ignore the physical plausibility correction to lDDT, though. + score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) + + (dist_l1 < 1.0).astype(jnp.float32) + + (dist_l1 < 2.0).astype(jnp.float32) + + (dist_l1 < 4.0).astype(jnp.float32)) + + # Normalize over the appropriate axes. + reduce_axes = (-1,) if per_residue else (-2, -1) + norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes)) + score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes)) + + return score |
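A minimal sketch of calling `lddt` directly (not part of the changeset): a batch containing a single three-residue chain whose middle point is predicted 0.7 Å off, so the affected pairs land in the <1.0 but not the <0.5 bin.

    import jax.numpy as jnp
    from alphafold.model import lddt

    true_points = jnp.array([[[0., 0., 0.], [5., 0., 0.], [10., 0., 0.]]])   # (1, 3, 3)
    pred_points = jnp.array([[[0., 0., 0.], [5.7, 0., 0.], [10., 0., 0.]]])  # (1, 3, 3)
    mask = jnp.ones((1, 3, 1))  # all three residues resolved in the true structure

    overall = lddt.lddt(pred_points, true_points, mask)                    # shape (1,)
    per_res = lddt.lddt(pred_points, true_points, mask, per_residue=True)  # shape (1, 3)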
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/lddt_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/lddt_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,79 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for lddt.""" + +from absl.testing import absltest +from absl.testing import parameterized +from alphafold.model import lddt +import numpy as np + + +class LddtTest(parameterized.TestCase, absltest.TestCase): + + @parameterized.named_parameters( + ('same', + [[0, 0, 0], [5, 0, 0], [10, 0, 0]], + [[0, 0, 0], [5, 0, 0], [10, 0, 0]], + [1, 1, 1]), + ('all_shifted', + [[0, 0, 0], [5, 0, 0], [10, 0, 0]], + [[-1, 0, 0], [4, 0, 0], [9, 0, 0]], + [1, 1, 1]), + ('all_rotated', + [[0, 0, 0], [5, 0, 0], [10, 0, 0]], + [[0, 0, 0], [0, 5, 0], [0, 10, 0]], + [1, 1, 1]), + ('half_a_dist', + [[0, 0, 0], [5, 0, 0]], + [[0, 0, 0], [5.5-1e-5, 0, 0]], + [1, 1]), + ('one_a_dist', + [[0, 0, 0], [5, 0, 0]], + [[0, 0, 0], [6-1e-5, 0, 0]], + [0.75, 0.75]), + ('two_a_dist', + [[0, 0, 0], [5, 0, 0]], + [[0, 0, 0], [7-1e-5, 0, 0]], + [0.5, 0.5]), + ('four_a_dist', + [[0, 0, 0], [5, 0, 0]], + [[0, 0, 0], [9-1e-5, 0, 0]], + [0.25, 0.25],), + ('five_a_dist', + [[0, 0, 0], [16-1e-5, 0, 0]], + [[0, 0, 0], [11, 0, 0]], + [0, 0]), + ('no_pairs', + [[0, 0, 0], [20, 0, 0]], + [[0, 0, 0], [25-1e-5, 0, 0]], + [1, 1]), + ) + def test_lddt( + self, predicted_pos, true_pos, exp_lddt): + predicted_pos = np.array([predicted_pos], dtype=np.float32) + true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32) + true_pos = np.array([true_pos], dtype=np.float32) + cutoff = 15.0 + per_residue = True + + result = lddt.lddt( + predicted_pos, true_pos, true_points_mask, cutoff, + per_residue) + + np.testing.assert_almost_equal(result, [exp_lddt], decimal=4) + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/mapping.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/mapping.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,218 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Specialized mapping functions.""" + +import functools + +from typing import Any, Callable, Optional, Sequence, Union + +import haiku as hk +import jax +import jax.numpy as jnp + + +PYTREE = Any +PYTREE_JAX_ARRAY = Any + +partial = functools.partial +PROXY = object() + + +def _maybe_slice(array, i, slice_size, axis): + if axis is PROXY: + return array + else: + return jax.lax.dynamic_slice_in_dim( + array, i, slice_size=slice_size, axis=axis) + + +def _maybe_get_size(array, axis): + if axis == PROXY: + return -1 + else: + return array.shape[axis] + + +def _expand_axes(axes, values, name='sharded_apply'): + values_tree_def = jax.tree_flatten(values)[1] + flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) + # Replace None's with PROXY + flat_axes = [PROXY if x is None else x for x in flat_axes] + return jax.tree_unflatten(values_tree_def, flat_axes) + + +def sharded_map( + fun: Callable[..., PYTREE_JAX_ARRAY], + shard_size: Union[int, None] = 1, + in_axes: Union[int, PYTREE] = 0, + out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]: + """Sharded vmap. + + Maps `fun` over axes, in a way similar to vmap, but does so in shards of + `shard_size`. This allows a smooth trade-off between memory usage + (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: integer or pytree denoting to what axis in the output the mapped + over axis maps. + + Returns: + function with smap applied. + """ + vmapped_fun = hk.vmap(fun, in_axes, out_axes) + return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) + + +def sharded_apply( + fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic + shard_size: Union[int, None] = 1, + in_axes: Union[int, PYTREE] = 0, + out_axes: Union[int, PYTREE] = 0, + new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]: + """Sharded apply. + + Applies `fun` over shards to axes, in a way similar to vmap, + but does so in shards of `shard_size`. Shards are stacked after. + This allows a smooth trade-off between + memory usage (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: integer or pytree denoting to what axis in the output the mapped + over axis maps. + new_out_axes: whether to stack outputs on new axes. This assumes that the + output sizes for each shard (including the possible remainder shard) are + the same. + + Returns: + function with smap applied. + """ + docstr = ('Mapped version of {fun}. 
Takes similar arguments to {fun} ' + 'but with additional array axes over which {fun} is mapped.') + if new_out_axes: + raise NotImplementedError('New output axes not yet implemented.') + + # shard size None denotes no sharding + if shard_size is None: + return fun + + @jax.util.wraps(fun, docstr=docstr) + def mapped_fn(*args): + # Expand in axes and Determine Loop range + in_axes_ = _expand_axes(in_axes, args) + + in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_) + flat_sizes = jax.tree_flatten(in_sizes)[0] + in_size = max(flat_sizes) + assert all(i in {in_size, -1} for i in flat_sizes) + + num_extra_shards = (in_size - 1) // shard_size + + # Fix Up if necessary + last_shard_size = in_size % shard_size + last_shard_size = shard_size if last_shard_size == 0 else last_shard_size + + def apply_fun_to_slice(slice_start, slice_size): + input_slice = jax.tree_multimap( + lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis + ), args, in_axes_) + return fun(*input_slice) + + remainder_shape_dtype = hk.eval_shape( + partial(apply_fun_to_slice, 0, last_shard_size)) + out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype) + out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype) + out_axes_ = _expand_axes(out_axes, remainder_shape_dtype) + + if num_extra_shards > 0: + regular_shard_shape_dtype = hk.eval_shape( + partial(apply_fun_to_slice, 0, shard_size)) + shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype) + + def make_output_shape(axis, shard_shape, remainder_shape): + return shard_shape[:axis] + ( + shard_shape[axis] * num_extra_shards + + remainder_shape[axis],) + shard_shape[axis + 1:] + + out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes, + out_shapes) + + # Calls dynamic Update slice with different argument order + # This is here since tree_multimap only works with positional arguments + def dynamic_update_slice_in_dim(full_array, update, axis, i): + return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis) + + def compute_shard(outputs, slice_start, slice_size): + slice_out = apply_fun_to_slice(slice_start, slice_size) + update_slice = partial( + dynamic_update_slice_in_dim, i=slice_start) + return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_) + + def scan_iteration(outputs, i): + new_outputs = compute_shard(outputs, i, shard_size) + return new_outputs, () + + slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size) + + def allocate_buffer(dtype, shape): + return jnp.zeros(shape, dtype=dtype) + + outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes) + + if slice_starts.shape[0] > 0: + outputs, _ = hk.scan(scan_iteration, outputs, slice_starts) + + if last_shard_size != shard_size: + remainder_start = in_size - last_shard_size + outputs = compute_shard(outputs, remainder_start, last_shard_size) + + return outputs + + return mapped_fn + + +def inference_subbatch( + module: Callable[..., PYTREE_JAX_ARRAY], + subbatch_size: int, + batched_args: Sequence[PYTREE_JAX_ARRAY], + nonbatched_args: Sequence[PYTREE_JAX_ARRAY], + low_memory: bool = True, + input_subbatch_dim: int = 0, + output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY: + """Run through subbatches (like batch apply but with split and concat).""" + assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test + + if not low_memory: + args = list(batched_args) + list(nonbatched_args) + return module(*args) + + if output_subbatch_dim is None: + output_subbatch_dim = 
input_subbatch_dim + + def run_module(*batched_args): + args = list(batched_args) + list(nonbatched_args) + return module(*args) + sharded_module = sharded_apply(run_module, + shard_size=subbatch_size, + in_axes=input_subbatch_dim, + out_axes=output_subbatch_dim) + return sharded_module(*batched_args) |
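A minimal sketch of `inference_subbatch` (not part of the changeset): a 16-row batch pushed through an `hk.Linear` in chunks of 4. It has to run inside an `hk.transform`, since `sharded_apply` uses Haiku's `eval_shape` and `scan` internally.

    import haiku as hk
    import jax
    import jax.numpy as jnp
    from alphafold.model import mapping

    @hk.transform
    def forward(x):
      linear = hk.Linear(8)
      # Process 4 rows at a time instead of all 16 at once.
      return mapping.inference_subbatch(
          linear, subbatch_size=4, batched_args=[x], nonbatched_args=[],
          low_memory=True)

    x = jnp.ones([16, 8])
    params = forward.init(jax.random.PRNGKey(0), x)
    out = forward.apply(params, None, x)  # shape (16, 8), same result as linear(x)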
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/model.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/model.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,177 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Code for constructing the model.""" +from typing import Any, Mapping, Optional, Union + +from absl import logging +from alphafold.common import confidence +from alphafold.model import features +from alphafold.model import modules +from alphafold.model import modules_multimer +import haiku as hk +import jax +import ml_collections +import numpy as np +import tensorflow.compat.v1 as tf +import tree + + +def get_confidence_metrics( + prediction_result: Mapping[str, Any], + multimer_mode: bool) -> Mapping[str, Any]: + """Post processes prediction_result to get confidence metrics.""" + confidence_metrics = {} + confidence_metrics['plddt'] = confidence.compute_plddt( + prediction_result['predicted_lddt']['logits']) + if 'predicted_aligned_error' in prediction_result: + confidence_metrics.update(confidence.compute_predicted_aligned_error( + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'])) + confidence_metrics['ptm'] = confidence.predicted_tm_score( + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'], + asym_id=None) + if multimer_mode: + # Compute the ipTM only for the multimer model. + confidence_metrics['iptm'] = confidence.predicted_tm_score( + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'], + asym_id=prediction_result['predicted_aligned_error']['asym_id'], + interface=True) + confidence_metrics['ranking_confidence'] = ( + 0.8 * confidence_metrics['iptm'] + 0.2 * confidence_metrics['ptm']) + + if not multimer_mode: + # Monomer models use mean pLDDT for model ranking. + confidence_metrics['ranking_confidence'] = np.mean( + confidence_metrics['plddt']) + + return confidence_metrics + + +class RunModel: + """Container for JAX model.""" + + def __init__(self, + config: ml_collections.ConfigDict, + params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): + self.config = config + self.params = params + self.multimer_mode = config.model.global_config.multimer_mode + + if self.multimer_mode: + def _forward_fn(batch): + model = modules_multimer.AlphaFold(self.config.model) + return model( + batch, + is_training=False) + else: + def _forward_fn(batch): + model = modules.AlphaFold(self.config.model) + return model( + batch, + is_training=False, + compute_loss=False, + ensemble_representations=True) + + self.apply = jax.jit(hk.transform(_forward_fn).apply) + self.init = jax.jit(hk.transform(_forward_fn).init) + + def init_params(self, feat: features.FeatureDict, random_seed: int = 0): + """Initializes the model parameters. + + If none were provided when this class was instantiated then the parameters + are randomly initialized. + + Args: + feat: A dictionary of NumPy feature arrays as output by + RunModel.process_features. 
+ random_seed: A random seed to use to initialize the parameters if none + were set when this class was initialized. + """ + if not self.params: + # Init params randomly. + rng = jax.random.PRNGKey(random_seed) + self.params = hk.data_structures.to_mutable_dict( + self.init(rng, feat)) + logging.warning('Initialized parameters randomly') + + def process_features( + self, + raw_features: Union[tf.train.Example, features.FeatureDict], + random_seed: int) -> features.FeatureDict: + """Processes features to prepare for feeding them into the model. + + Args: + raw_features: The output of the data pipeline either as a dict of NumPy + arrays or as a tf.train.Example. + random_seed: The random seed to use when processing the features. + + Returns: + A dict of NumPy feature arrays suitable for feeding into the model. + """ + + if self.multimer_mode: + return raw_features + + # Single-chain mode. + if isinstance(raw_features, dict): + return features.np_example_to_features( + np_example=raw_features, + config=self.config, + random_seed=random_seed) + else: + return features.tf_example_to_features( + tf_example=raw_features, + config=self.config, + random_seed=random_seed) + + def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: + self.init_params(feat) + logging.info('Running eval_shape with shape(feat) = %s', + tree.map_structure(lambda x: x.shape, feat)) + shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat) + logging.info('Output shape was %s', shape) + return shape + + def predict(self, + feat: features.FeatureDict, + random_seed: int, + ) -> Mapping[str, Any]: + """Makes a prediction by inferencing the model on the provided features. + + Args: + feat: A dictionary of NumPy feature arrays as output by + RunModel.process_features. + random_seed: The random seed to use when running the model. In the + multimer model this controls the MSA sampling. + + Returns: + A dictionary of model outputs. + """ + self.init_params(feat) + logging.info('Running predict with shape(feat) = %s', + tree.map_structure(lambda x: x.shape, feat)) + result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat) + + # This block is to ensure benchmark timings are accurate. Some blocking is + # already happening when computing get_confidence_metrics, and this ensures + # all outputs are blocked on. + jax.tree_map(lambda x: x.block_until_ready(), result) + result.update( + get_confidence_metrics(result, multimer_mode=self.multimer_mode)) + logging.info('Output shape was %s', + tree.map_structure(lambda x: x.shape, result)) + return result |
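A sketch of the intended call sequence (not part of the changeset). `config.model_config` and `data.get_model_haiku_params` live in `alphafold/model/config.py` and `alphafold/model/data.py` elsewhere in this commit; the parameter directory and the pickled feature file below are placeholder paths.

    import pickle
    from alphafold.model import config, data, model

    model_name = 'model_1'
    cfg = config.model_config(model_name)
    params = data.get_model_haiku_params(model_name, '/path/to/params')
    runner = model.RunModel(cfg, params)

    # Features as produced by the AlphaFold data pipeline (placeholder path).
    with open('features.pkl', 'rb') as f:
      raw_features = pickle.load(f)

    processed = runner.process_features(raw_features, random_seed=0)
    result = runner.predict(processed, random_seed=0)
    print(result['plddt'].mean())  # mean per-residue confidence on a 0-100 scale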
b |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/modules.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/modules.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,2105 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modules and code used in the core part of AlphaFold.
+
+The structure generation code is in 'folding.py'.
+"""
+import functools
+from alphafold.common import residue_constants
+from alphafold.model import all_atom
+from alphafold.model import common_modules
+from alphafold.model import folding
+from alphafold.model import layer_stack
+from alphafold.model import lddt
+from alphafold.model import mapping
+from alphafold.model import prng
+from alphafold.model import quat_affine
+from alphafold.model import utils
+import haiku as hk
+import jax
+import jax.numpy as jnp
+
+
+def softmax_cross_entropy(logits, labels):
+  """Computes softmax cross entropy given logits and one-hot class labels."""
+  loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
+  return jnp.asarray(loss)
+
+
+def sigmoid_cross_entropy(logits, labels):
+  """Computes sigmoid cross entropy given logits and multiple class labels."""
+  log_p = jax.nn.log_sigmoid(logits)
+  # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable
+  log_not_p = jax.nn.log_sigmoid(-logits)
+  loss = -labels * log_p - (1. - labels) * log_not_p
+  return jnp.asarray(loss)
+
+
+def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None):
+  """Applies dropout to a tensor."""
+  if is_training and rate != 0.0:
+    shape = list(tensor.shape)
+    if broadcast_dim is not None:
+      shape[broadcast_dim] = 1
+    keep_rate = 1.0 - rate
+    keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape)
+    return keep * tensor / keep_rate
+  else:
+    return tensor
+
+
+def dropout_wrapper(module,
+                    input_act,
+                    mask,
+                    safe_key,
+                    global_config,
+                    output_act=None,
+                    is_training=True,
+                    **kwargs):
+  """Applies module + dropout + residual update."""
+  if output_act is None:
+    output_act = input_act
+
+  gc = global_config
+  residual = module(input_act, mask, is_training=is_training, **kwargs)
+  dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate
+
+  if module.config.shared_dropout:
+    if module.config.orientation == 'per_row':
+      broadcast_dim = 0
+    else:
+      broadcast_dim = 1
+  else:
+    broadcast_dim = None
+
+  residual = apply_dropout(tensor=residual,
+                           safe_key=safe_key,
+                           rate=dropout_rate,
+                           is_training=is_training,
+                           broadcast_dim=broadcast_dim)
+
+  new_act = output_act + residual
+
+  return new_act
+
+
+def create_extra_msa_feature(batch):
+  """Expand extra_msa into 1hot and concat with other extra msa features.
+
+  We do this as late as possible as the one_hot extra msa can be very large.
+
+  Arguments:
+    batch: a dictionary with the following keys:
+     * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster
+       centre. Note that this is not one-hot encoded.
+     * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to
+       the left of each position in the extra MSA.
+     * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to
+       the left of each position in the extra MSA.
+
+  Returns:
+    Concatenated tensor of extra MSA features.
+  """
+  # 2
[... middle of this diff truncated in the source capture ...]
...nit_vector]
+    template_mask_2d = template_mask_2d.astype(dtype)
+    if not self.config.use_template_unit_vector:
+      unit_vector = [jnp.zeros_like(x) for x in unit_vector]
+    to_concat.extend(unit_vector)
+
+    to_concat.append(template_mask_2d[..., None])
+
+    act = jnp.concatenate(to_concat, axis=-1)
+
+    # Mask out non-template regions so we don't get arbitrary values in the
+    # distogram for these regions.
+    act *= template_mask_2d[..., None]
+
+    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9
+    act = common_modules.Linear(
+        num_channels,
+        initializer='relu',
+        name='embedding2d')(
+            act)
+
+    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11
+    act = TemplatePairStack(
+        self.config.template_pair_stack, self.global_config)(
+            act, mask_2d, is_training)
+
+    act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act)
+    return act
+
+
+class TemplateEmbedding(hk.Module):
+  """Embeds a set of templates.
+
+  Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
+  Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
+  """
+
+  def __init__(self, config, global_config, name='template_embedding'):
+    super().__init__(name=name)
+    self.config = config
+    self.global_config = global_config
+
+  def __call__(self, query_embedding, template_batch, mask_2d, is_training):
+    """Build TemplateEmbedding module.
+
+    Arguments:
+      query_embedding: Query pair representation, shape [N_res, N_res, c_z].
+      template_batch: A batch of template features.
+      mask_2d: Padding mask (Note: this doesn't care if a template exists,
+        unlike the template_pseudo_beta_mask).
+      is_training: Whether the module is in training mode.
+
+    Returns:
+      A template embedding [N_res, N_res, c_z].
+    """
+
+    num_templates = template_batch['template_mask'].shape[0]
+    num_channels = (self.config.template_pair_stack
+                    .triangle_attention_ending_node.value_dim)
+    num_res = query_embedding.shape[0]
+
+    dtype = query_embedding.dtype
+    template_mask = template_batch['template_mask']
+    template_mask = template_mask.astype(dtype)
+
+    query_num_channels = query_embedding.shape[-1]
+
+    # Make sure the weights are shared across templates by constructing the
+    # embedder here.
+    # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12
+    template_embedder = SingleTemplateEmbedding(self.config, self.global_config)
+
+    def map_fn(batch):
+      return template_embedder(query_embedding, batch, mask_2d, is_training)
+
+    template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(
+        template_batch)
+
+    # Cross attend from the query to the templates along the residue
+    # dimension by flattening everything else into the batch dimension.
+    # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention"
+    flat_query = jnp.reshape(query_embedding,
+                             [num_res * num_res, 1, query_num_channels])
+
+    flat_templates = jnp.reshape(
+        jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
+        [num_res * num_res, num_templates, num_channels])
+
+    bias = (1e9 * (template_mask[None, None, None, :] - 1.))
+
+    template_pointwise_attention_module = Attention(
+        self.config.attention, self.global_config, query_num_channels)
+    nonbatched_args = [bias]
+    batched_args = [flat_query, flat_templates]
+
+    embedding = mapping.inference_subbatch(
+        template_pointwise_attention_module,
+        self.config.subbatch_size,
+        batched_args=batched_args,
+        nonbatched_args=nonbatched_args,
+        low_memory=not is_training)
+    embedding = jnp.reshape(embedding,
+                            [num_res, num_res, query_num_channels])
+
+    # No gradients if no templates.
+    embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype)
+
+    return embedding
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/modules_multimer.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/modules_multimer.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,1129 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Core modules, which have been refactored in AlphaFold-Multimer.
+
+The main difference is that the MSA sampling pipeline is moved inside the JAX
+model for easier implementation of recycling and ensembling.
+
+Lower-level modules up to EvoformerIteration are reused from modules.py.
+"""
+
+import functools
+from typing import Sequence
+
+from alphafold.common import residue_constants
+from alphafold.model import all_atom_multimer
+from alphafold.model import common_modules
+from alphafold.model import folding_multimer
+from alphafold.model import geometry
+from alphafold.model import layer_stack
+from alphafold.model import modules
+from alphafold.model import prng
+from alphafold.model import utils
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+
+def reduce_fn(x, mode):
+  if mode == 'none' or mode is None:
+    return jnp.asarray(x)
+  elif mode == 'sum':
+    return jnp.asarray(x).sum()
+  elif mode == 'mean':
+    return jnp.mean(jnp.asarray(x))
+  else:
+    raise ValueError('Unsupported reduction option.')
+
+
+def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
+  """Generate Gumbel noise of given shape.
+
+  This generates samples from Gumbel(0, 1).
+
+  Args:
+    key: Jax random number key.
+    shape: Shape of noise to return.
+
+  Returns:
+    Gumbel noise of given shape.
+  """
+  epsilon = 1e-6
+  uniform = utils.padding_consistent_rng(jax.random.uniform)
+  uniform_noise = uniform(
+      key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.)
+  gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon)
+  return gumbel
+
+
+def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
+  """Samples from a probability distribution given by 'logits'.
+
+  This uses the Gumbel-max trick to implement the sampling in an efficient
+  manner.
+
+  Args:
+    key: prng key.
+    logits: Logarithm of probabilities to sample from, probabilities can be
+      unnormalized.
+
+  Returns:
+    Sample from logprobs in one-hot form.
+  """
+  z = gumbel_noise(key, logits.shape)
+  return jax.nn.one_hot(
+      jnp.argmax(logits + z, axis=-1),
+      logits.shape[-1],
+      dtype=logits.dtype)
+
+
+def gumbel_argsort_sample_idx(key: jnp.ndarray,
+                              logits: jnp.ndarray) -> jnp.ndarray:
+  """Samples without replacement from a distribution given by 'logits'.
+
+  This uses the Gumbel trick to implement the sampling in an efficient manner.
+  For a distribution over k items this samples k times without replacement, so
+  this is effectively sampling a random permutation with probabilities over the
+  permutations derived from the logprobs.
+
+  Args:
+    key: prng key.
+    logits: Logarithm of probabilities to sample from, probabilities can be
+      unnormalized.
+
+  Returns:
+    Sample from logprobs in one-hot form.
+  """
+  z = gumbel_noise(key, logits.shape)
+  # This construction is equivalent to jnp.argsort, but using a non stable sort,
+  # since stable sorts aren't supported by jax2tf.
+  axis = len(logits.shape) - 1
+  iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis)
+  _, perm = jax.lax.sort_key_val(
+      logits + z, iota, dimension=-1, is_stable=False)
+  return perm[::-1]
+
+
+def make_masked_msa(batch, key, config, epsilon=1e-6):
+  """Create data for BERT on raw MSA."""
+  # Add a random
[... middle of this diff truncated in the source capture ...]
...the template embedder.
+
+    Args:
+      act: [num_res, num_res, num_channel] Input pairwise activations.
+      pair_mask: [num_res, num_res] padding mask.
+      is_training: Whether to run in training mode.
+      safe_key: Safe pseudo-random generator key.
+
+    Returns:
+      [num_res, num_res, num_channel] tensor of activations.
+    """
+    c = self.config
+    gc = self.global_config
+
+    if safe_key is None:
+      safe_key = prng.SafeKey(hk.next_rng_key())
+
+    dropout_wrapper_fn = functools.partial(
+        modules.dropout_wrapper,
+        is_training=is_training,
+        global_config=gc)
+
+    safe_key, *sub_keys = safe_key.split(20)
+    sub_keys = iter(sub_keys)
+
+    act = dropout_wrapper_fn(
+        modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
+                                       name='triangle_multiplication_outgoing'),
+        act,
+        pair_mask,
+        safe_key=next(sub_keys))
+
+    act = dropout_wrapper_fn(
+        modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc,
+                                       name='triangle_multiplication_incoming'),
+        act,
+        pair_mask,
+        safe_key=next(sub_keys))
+
+    act = dropout_wrapper_fn(
+        modules.TriangleAttention(c.triangle_attention_starting_node, gc,
+                                  name='triangle_attention_starting_node'),
+        act,
+        pair_mask,
+        safe_key=next(sub_keys))
+
+    act = dropout_wrapper_fn(
+        modules.TriangleAttention(c.triangle_attention_ending_node, gc,
+                                  name='triangle_attention_ending_node'),
+        act,
+        pair_mask,
+        safe_key=next(sub_keys))
+
+    act = dropout_wrapper_fn(
+        modules.Transition(c.pair_transition, gc,
+                           name='pair_transition'),
+        act,
+        pair_mask,
+        safe_key=next(sub_keys))
+
+    return act
+
+
+def template_embedding_1d(batch, num_channel):
+  """Embed templates into an (num_res, num_templates, num_channels) embedding.
+
+  Args:
+    batch: A batch containing:
+      template_aatype, (num_templates, num_res) aatype for the templates.
+      template_all_atom_positions, (num_templates, num_residues, 37, 3) atom
+        positions for the templates.
+      template_all_atom_mask, (num_templates, num_residues, 37) atom mask for
+        each template.
+    num_channel: The number of channels in the output.
+
+  Returns:
+    An embedding of shape (num_templates, num_res, num_channels) and a mask of
+    shape (num_templates, num_res).
+  """
+
+  # Embed the templates aatypes.
+  aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1)
+
+  num_templates = batch['template_aatype'].shape[0]
+  all_chi_angles = []
+  all_chi_masks = []
+  for i in range(num_templates):
+    atom_pos = geometry.Vec3Array.from_array(
+        batch['template_all_atom_positions'][i, :, :, :])
+    template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles(
+        atom_pos,
+        batch['template_all_atom_mask'][i, :, :],
+        batch['template_aatype'][i, :])
+    all_chi_angles.append(template_chi_angles)
+    all_chi_masks.append(template_chi_mask)
+  chi_angles = jnp.stack(all_chi_angles, axis=0)
+  chi_mask = jnp.stack(all_chi_masks, axis=0)
+
+  template_features = jnp.concatenate([
+      aatype_one_hot,
+      jnp.sin(chi_angles) * chi_mask,
+      jnp.cos(chi_angles) * chi_mask,
+      chi_mask], axis=-1)
+
+  template_mask = chi_mask[:, :, 0]
+
+  template_activations = common_modules.Linear(
+      num_channel,
+      initializer='relu',
+      name='template_single_embedding')(
+          template_features)
+  template_activations = jax.nn.relu(template_activations)
+  template_activations = common_modules.Linear(
+      num_channel,
+      initializer='relu',
+      name='template_projection')(
+          template_activations)
+  return template_activations, template_mask
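gumbel_max_sample above relies on the Gumbel-max trick: adding Gumbel(0, 1) noise to the logits and taking the argmax draws an exact sample from softmax(logits). A self-contained numerical check of that claim:

import jax
import jax.numpy as jnp

def gumbel_max(key, logits):
  u = jax.random.uniform(key, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))                       # Gumbel(0, 1) noise
  return jnp.argmax(logits + g, axis=-1)

logits = jnp.log(jnp.array([0.1, 0.2, 0.7]))
keys = jax.random.split(jax.random.PRNGKey(0), 10000)
samples = jax.vmap(lambda k: gumbel_max(k, logits))(keys)
# Empirical frequencies approach [0.1, 0.2, 0.7].
print(jnp.bincount(samples, length=3) / samples.shape[0])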
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/prng.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/prng.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,69 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of utilities surrounding PRNG usage in protein folding.""" + +import haiku as hk +import jax + + +def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training): + if is_training and rate != 0.0 and not is_deterministic: + return hk.dropout(safe_key.get(), rate, tensor) + else: + return tensor + + +class SafeKey: + """Safety wrapper for PRNG keys.""" + + def __init__(self, key): + self._key = key + self._used = False + + def _assert_not_used(self): + if self._used: + raise RuntimeError('Random key has been used previously.') + + def get(self): + self._assert_not_used() + self._used = True + return self._key + + def split(self, num_keys=2): + self._assert_not_used() + self._used = True + new_keys = jax.random.split(self._key, num_keys) + return jax.tree_map(SafeKey, tuple(new_keys)) + + def duplicate(self, num_keys=2): + self._assert_not_used() + self._used = True + return tuple(SafeKey(self._key) for _ in range(num_keys)) + + +def _safe_key_flatten(safe_key): + # Flatten transfers "ownership" to the tree + return (safe_key._key,), safe_key._used # pylint: disable=protected-access + + +def _safe_key_unflatten(aux_data, children): + ret = SafeKey(children[0]) + ret._used = aux_data # pylint: disable=protected-access + return ret + + +jax.tree_util.register_pytree_node( + SafeKey, _safe_key_flatten, _safe_key_unflatten) + |
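SafeKey makes accidental PRNG key reuse an error instead of a silent correlation bug: every consumption either splits the key or hands it over via get(), and a consumed key raises on any further use. The intended usage pattern, as a short sketch:

import jax
from alphafold.model import prng

safe_key = prng.SafeKey(jax.random.PRNGKey(42))
safe_key, subkey = safe_key.split()          # both results are fresh SafeKeys
noise = jax.random.normal(subkey.get(), (3,))
# Calling subkey.get() or subkey.split() again now raises RuntimeError.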
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/prng_test.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/prng_test.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,46 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for prng.""" + +from absl.testing import absltest +from alphafold.model import prng +import jax + + +class PrngTest(absltest.TestCase): + + def test_key_reuse(self): + + init_key = jax.random.PRNGKey(42) + safe_key = prng.SafeKey(init_key) + _, safe_key = safe_key.split() + + raw_key = safe_key.get() + + self.assertNotEqual(raw_key[0], init_key[0]) + self.assertNotEqual(raw_key[1], init_key[1]) + + with self.assertRaises(RuntimeError): + safe_key.get() + + with self.assertRaises(RuntimeError): + safe_key.split() + + with self.assertRaises(RuntimeError): + safe_key.duplicate() + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/quat_affine.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/quat_affine.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,459 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Quaternion geometry modules.
+
+This introduces a representation of coordinate frames that is based around a
+'QuatAffine' object. This object describes an array of coordinate frames.
+It consists of vectors corresponding to the
+origin of the frames as well as orientations which are stored in two
+ways, as unit quaternions as well as rotation matrices.
+The rotation matrices are derived from the unit quaternions and the two are kept
+in sync.
+For an explanation of the relation between unit quaternions and rotations see
+https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
+
+This representation is used in the model for the backbone frames.
+
+One important thing to note here is that while we update both representations
+the jit compiler is going to ensure that only the parts that are
+actually used are executed.
+"""
+
+
+import functools
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+# pylint: disable=bad-whitespace
+QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
+
+QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]]  # rr
+QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]]  # ii
+QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]]  # jj
+QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]]  # kk
+
+QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]]  # ij
+QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]]  # ik
+QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]]  # jk
+
+QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]]  # ir
+QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]]  # jr
+QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]]  # kr
+
+QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32)
+QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
+                          [ 0,-1, 0, 0],
+                          [ 0, 0,-1, 0],
+                          [ 0, 0, 0,-1]]
+
+QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
+                          [ 1, 0, 0, 0],
+                          [ 0, 0, 0, 1],
+                          [ 0, 0,-1, 0]]
+
+QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
+                          [ 0, 0, 0,-1],
+                          [ 1, 0, 0, 0],
+                          [ 0, 1, 0, 0]]
+
+QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
+                          [ 0, 0, 1, 0],
+                          [ 0,-1, 0, 0],
+                          [ 1, 0, 0, 0]]
+
+QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :]
+# pylint: enable=bad-whitespace
+
+
+def rot_to_quat(rot, unstack_inputs=False):
+  """Convert rotation matrix to quaternion.
+
+  Note that this function calls self_adjoint_eig which is extremely expensive on
+  the GPU. If at all possible, this function should run on the CPU.
+
+  Args:
+    rot: rotation matrix (see below for format).
+    unstack_inputs: If true, rotation matrix should be shape (..., 3, 3)
+      otherwise the rotation matrix should be a list of lists of tensors.
+
+  Returns:
+    Quaternion as (..., 4) tensor.
+  """
+  if unstack_inputs:
+    rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)]
+
+  [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
+
+  # pylint: disable=bad-whitespace
+  k = [[ xx + yy + zz,      zy - yz,      xz - zx,      yx - xy,],
+       [      zy - yz, xx - yy - zz,      xy + yx,      xz + zx,],
+       [      xz -
[... middle of this diff truncated in the source capture ...]
...yz coordinates.
+    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
+
+  Returns:
+    A tuple (translation, rotation) where:
+      translation is an array of shape [batch, 3] defining the translation.
+      rotation is an array of shape [batch, 3, 3] defining the rotation.
+    After applying the translation and rotation to all atoms in a residue:
+      * All atoms will be shifted so that CA is at the origin,
+      * All atoms will be rotated so that C is at the x-axis,
+      * All atoms will be shifted so that N is in the xy plane.
+  """
+  assert len(n_xyz.shape) == 2, n_xyz.shape
+  assert n_xyz.shape[-1] == 3, n_xyz.shape
+  assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, (
+      n_xyz.shape, ca_xyz.shape, c_xyz.shape)
+
+  # Place CA at the origin.
+  translation = -ca_xyz
+  n_xyz = n_xyz + translation
+  c_xyz = c_xyz + translation
+
+  # Place C on the x-axis.
+  c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)]
+  # Rotate by angle c1 in the x-y plane (around the z-axis).
+  sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
+  cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2)
+  zeros = jnp.zeros_like(sin_c1)
+  ones = jnp.ones_like(sin_c1)
+  # pylint: disable=bad-whitespace
+  c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]),
+                             jnp.array([sin_c1,  cos_c1, zeros]),
+                             jnp.array([zeros,    zeros,  ones])])
+
+  # Rotate by angle c2 in the x-z plane (around the y-axis).
+  sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2)
+  cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt(
+      1e-20 + c_x**2 + c_y**2 + c_z**2)
+  c2_rot_matrix = jnp.stack([jnp.array([cos_c2,  zeros, sin_c2]),
+                             jnp.array([zeros,    ones,  zeros]),
+                             jnp.array([-sin_c2, zeros, cos_c2])])
+
+  c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix)
+  n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T
+
+  # Place N in the x-y plane.
+  _, n_y, n_z = [n_xyz[:, i] for i in range(3)]
+  # Rotate by angle alpha in the y-z plane (around the x-axis).
+  sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
+  cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2)
+  n_rot_matrix = jnp.stack([jnp.array([ones,  zeros,  zeros]),
+                            jnp.array([zeros, cos_n, -sin_n]),
+                            jnp.array([zeros, sin_n,  cos_n])])
+  # pylint: enable=bad-whitespace
+
+  return (translation,
+          jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1]))
+
+
+def make_transform_from_reference(
+    n_xyz: jnp.ndarray,
+    ca_xyz: jnp.ndarray,
+    c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Returns rotation and translation matrices to convert from reference.
+
+  Note that this method does not take care of symmetries. If you provide the
+  atom positions in the non-standard way, the N atom will end up not at
+  [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You
+  need to take care of such cases in your code.
+
+  Args:
+    n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates.
+    ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates.
+    c_xyz: An array of shape [batch, 3] of carbon xyz coordinates.
+
+  Returns:
+    A tuple (rotation, translation) where:
+      rotation is an array of shape [batch, 3, 3] defining the rotation.
+      translation is an array of shape [batch, 3] defining the translation.
+    After applying the translation and rotation to the reference backbone,
+    the coordinates will approximately equal to the input coordinates.
+
+    The order of translation and rotation differs from make_canonical_transform
+    because the rotation from this function should be applied before the
+    translation, unlike make_canonical_transform.
+  """
+  translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz)
+  return np.transpose(rotation, (0, 2, 1)), -translation
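The QUAT_TO_ROT table at the top of this file encodes the rotation matrix as a bilinear form in the quaternion components: for a unit quaternion q, R equals the sum over a, b of q[a] * q[b] * QUAT_TO_ROT[a, b]. The quat_to_rot function itself falls in the elided middle of this diff and may organise the computation differently, so treat this as an illustration of the table, not of that function:

import jax.numpy as jnp
from alphafold.model import quat_affine

def quat_to_rot_demo(quat):
  quat = quat / jnp.linalg.norm(quat)      # the table assumes a unit quaternion
  table = jnp.asarray(quat_affine.QUAT_TO_ROT)
  return jnp.einsum('a,b,abij->ij', quat, quat, table)

print(quat_to_rot_demo(jnp.array([1.0, 0.0, 0.0, 0.0])))  # 3x3 identity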
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/quat_affine_test.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/quat_affine_test.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,150 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for quat_affine.""" + +from absl import logging +from absl.testing import absltest +from alphafold.model import quat_affine +import jax +import jax.numpy as jnp +import numpy as np + +VERBOSE = False +np.set_printoptions(precision=3, suppress=True) + +r2t = quat_affine.rot_list_to_tensor +v2t = quat_affine.vec_list_to_tensor + +q2r = lambda q: r2t(quat_affine.quat_to_rot(q)) + + +class QuatAffineTest(absltest.TestCase): + + def _assert_check(self, to_check, tol=1e-5): + for k, (correct, generated) in to_check.items(): + if VERBOSE: + logging.info(k) + logging.info('Correct %s', correct) + logging.info('Predicted %s', generated) + self.assertLess(np.max(np.abs(correct - generated)), tol) + + def test_conversion(self): + quat = jnp.array([-2., 5., -1., 4.]) + + rotation = jnp.array([ + [0.26087, 0.130435, 0.956522], + [-0.565217, -0.782609, 0.26087], + [0.782609, -0.608696, -0.130435]]) + + translation = jnp.array([1., -3., 4.]) + point = jnp.array([0.7, 3.2, -2.9]) + + a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True) + true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation + + self._assert_check({ + 'rot': (rotation, r2t(a.rotation)), + 'trans': (translation, v2t(a.translation)), + 'point': (true_new_point, + v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))), + # Because of the double cover, we must be careful and compare rotations + 'quat': (q2r(a.quaternion), + q2r(quat_affine.rot_to_quat(a.rotation))), + + }) + + def test_double_cover(self): + """Test that -q is the same rotation as q.""" + rng = jax.random.PRNGKey(42) + keys = jax.random.split(rng) + q = jax.random.normal(keys[0], (2, 4)) + trans = jax.random.normal(keys[1], (2, 3)) + a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True) + a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True) + + self._assert_check({ + 'rot': (r2t(a1.rotation), + r2t(a2.rotation)), + 'trans': (v2t(a1.translation), + v2t(a2.translation)), + }) + + def test_homomorphism(self): + rng = jax.random.PRNGKey(42) + keys = jax.random.split(rng, 4) + vec_q1 = jax.random.normal(keys[0], (2, 3)) + + q1 = jnp.concatenate([ + jnp.ones_like(vec_q1)[:, :1], + vec_q1], axis=-1) + + q2 = jax.random.normal(keys[1], (2, 4)) + t1 = jax.random.normal(keys[2], (2, 3)) + t2 = jax.random.normal(keys[3], (2, 3)) + + a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True) + a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True) + a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1)) + + rng, key = jax.random.split(rng) + x = jax.random.normal(key, (2, 3)) + new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0)) + new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0))) + + self._assert_check({ + 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)), + q2r(a21.quaternion)), + 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)), + r2t(a21.rotation)), + 'point': 
(v2t(new_x_apply2), + v2t(new_x)), + 'inverse': (x, v2t(a21.invert_point(new_x))), + }) + + def test_batching(self): + """Test that affine applies batchwise.""" + rng = jax.random.PRNGKey(42) + keys = jax.random.split(rng, 3) + q = jax.random.uniform(keys[0], (5, 2, 4)) + t = jax.random.uniform(keys[1], (2, 3)) + x = jax.random.uniform(keys[2], (5, 1, 3)) + + a = quat_affine.QuatAffine(q, t, unstack_inputs=True) + y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0))) + + y_list = [] + for i in range(5): + for j in range(2): + a_local = quat_affine.QuatAffine(q[i, j], t[j], + unstack_inputs=True) + y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0))) + y_list.append(y_local) + y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3)) + + self._assert_check({ + 'batch': (y_combine, y), + 'quat': (q2r(a.quaternion), + q2r(quat_affine.rot_to_quat(a.rotation))), + }) + + def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06): + self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol)) + + def assertAllEqual(self, a, b): + self.assertTrue(np.all(np.array(a) == np.array(b))) + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/r3.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/r3.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,320 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformations for 3D coordinates.
+
+This Module contains objects for representing Vectors (Vecs), Rotation Matrices
+(Rots) and proper Rigid transformation (Rigids). These are represented as
+named tuples with arrays for each entry, for example a set of
+[N, M] points would be represented as a Vecs object with arrays of shape [N, M]
+for x, y and z.
+
+This is being done to improve readability by making it very clear what objects
+are geometric objects rather than relying on comments and array shapes.
+Another reason for this is to avoid using matrix
+multiplication primitives like matmul or einsum, on modern accelerator hardware
+these can end up on specialized cores such as tensor cores on GPU or the MXU on
+cloud TPUs, this often involves lower computational precision which can be
+problematic for coordinate geometry. Also these cores are typically optimized
+for larger matrices than 3 dimensional, this code is written to avoid any
+unintended use of these cores on both GPUs and TPUs.
+"""
+
+import collections
+from typing import List
+from alphafold.model import quat_affine
+import jax.numpy as jnp
+import tree
+
+# Array of 3-component vectors, stored as individual array for
+# each component.
+Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z'])
+
+# Array of 3x3 rotation matrices, stored as individual array for
+# each component.
+Rots = collections.namedtuple('Rots', ['xx', 'xy', 'xz',
+                                       'yx', 'yy', 'yz',
+                                       'zx', 'zy', 'zz'])
+# Array of rigid 3D transformations, stored as array of rotations and
+# array of translations.
+Rigids = collections.namedtuple('Rigids', ['rot', 'trans'])
+
+
+def squared_difference(x, y):
+  return jnp.square(x - y)
+
+
+def invert_rigids(r: Rigids) -> Rigids:
+  """Computes group inverse of rigid transformations 'r'."""
+  inv_rots = invert_rots(r.rot)
+  t = rots_mul_vecs(inv_rots, r.trans)
+  inv_trans = Vecs(-t.x, -t.y, -t.z)
+  return Rigids(inv_rots, inv_trans)
+
+
+def invert_rots(m: Rots) -> Rots:
+  """Computes inverse of rotations 'm'."""
+  return Rots(m.xx, m.yx, m.zx,
+              m.xy, m.yy, m.zy,
+              m.xz, m.yz, m.zz)
+
+
+def rigids_from_3_points(
+    point_on_neg_x_axis: Vecs,  # shape (...)
+    origin: Vecs,  # shape (...)
+    point_on_xy_plane: Vecs,  # shape (...)
+) -> Rigids:  # shape (...)
+  """Create Rigids from 3 points.
+
+  Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points"
+  This creates a set of rigid transformations from 3 points by Gram Schmidt
+  orthogonalization.
+
+  Args:
+    point_on_neg_x_axis: Vecs corresponding to points on the negative x axis
+    origin: Origin of resulting rigid transformations
+    point_on_xy_plane: Vecs corresponding to points in the xy plane
+  Returns:
+    Rigid transformations from global frame to local frames derived from
+    the input points.
+  """
+  m = rots_from_two_vecs(
+      e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis),
+      e1_unnormalized=vecs_sub(point_on_xy_plane, origin))
+
+  return Rigids(rot=m, trans=origin)
+
+
+def rigids_from_list(l: List[jnp.ndarray]) -> Rigids:
+  """Converts flat list of arrays to rigid transformations."""
+  assert len(l) == 12
+  return Rigids(Rots(*(l[:9])), Vecs(*(l[9:])))
+
+
+def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids:
+  """
[... middle of this diff truncated in the source capture ...]
...(..., 3, 3)
+) -> Rots:  # shape (...)
+  """Convert rotations represented as (3, 3) array to Rots."""
+  assert m.shape[-1] == 3
+  assert m.shape[-2] == 3
+  return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
+              m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
+              m[..., 2, 0], m[..., 2, 1], m[..., 2, 2])
+
+
+def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
+  """Create rotation matrices from unnormalized vectors for the x and y-axes.
+
+  This creates a rotation matrix from two vectors using Gram-Schmidt
+  orthogonalization.
+
+  Args:
+    e0_unnormalized: vectors lying along x-axis of resulting rotation
+    e1_unnormalized: vectors lying in xy-plane of resulting rotation
+  Returns:
+    Rotations resulting from Gram-Schmidt procedure.
+  """
+  # Normalize the unit vector for the x-axis, e0.
+  e0 = vecs_robust_normalize(e0_unnormalized)
+
+  # make e1 perpendicular to e0.
+  c = vecs_dot_vecs(e1_unnormalized, e0)
+  e1 = Vecs(e1_unnormalized.x - c * e0.x,
+            e1_unnormalized.y - c * e0.y,
+            e1_unnormalized.z - c * e0.z)
+  e1 = vecs_robust_normalize(e1)
+
+  # Compute e2 as cross product of e0 and e1.
+  e2 = vecs_cross_vecs(e0, e1)
+
+  return Rots(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
+
+
+def rots_mul_rots(a: Rots, b: Rots) -> Rots:
+  """Composition of rotations 'a' and 'b'."""
+  c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx))
+  c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy))
+  c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz))
+  return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
+
+
+def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs:
+  """Apply rotations 'm' to vectors 'v'."""
+  return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z,
+              m.yx * v.x + m.yy * v.y + m.yz * v.z,
+              m.zx * v.x + m.zy * v.y + m.zz * v.z)
+
+
+def vecs_add(v1: Vecs, v2: Vecs) -> Vecs:
+  """Add two vectors 'v1' and 'v2'."""
+  return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z)
+
+
+def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray:
+  """Dot product of vectors 'v1' and 'v2'."""
+  return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z
+
+
+def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs:
+  """Cross product of vectors 'v1' and 'v2'."""
+  return Vecs(v1.y * v2.z - v1.z * v2.y,
+              v1.z * v2.x - v1.x * v2.z,
+              v1.x * v2.y - v1.y * v2.x)
+
+
+def vecs_from_tensor(x: jnp.ndarray  # shape (..., 3)
+                     ) -> Vecs:  # shape (...)
+  """Converts from tensor of shape (3,) to Vecs."""
+  num_components = x.shape[-1]
+  assert num_components == 3
+  return Vecs(x[..., 0], x[..., 1], x[..., 2])
+
+
+def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs:
+  """Normalizes vectors 'v'.
+
+  Args:
+    v: vectors to be normalized.
+    epsilon: small regularizer added to squared norm before taking square root.
+  Returns:
+    normalized vectors
+  """
+  norms = vecs_robust_norm(v, epsilon)
+  return Vecs(v.x / norms, v.y / norms, v.z / norms)
+
+
+def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
+  """Computes norm of vectors 'v'.
+
+  Args:
+    v: vectors to be normalized.
+    epsilon: small regularizer added to squared norm before taking square root.
+  Returns:
+    norm of 'v'
+  """
+  return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z) + epsilon)
+
+
+def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:
+  """Computes v1 - v2."""
+  return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z)
+
+
+def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray:
+  """Computes squared euclidean difference between 'v1' and 'v2'."""
+  return (squared_difference(v1.x, v2.x) +
+          squared_difference(v1.y, v2.y) +
+          squared_difference(v1.z, v2.z))
+
+
+def vecs_to_tensor(v: Vecs  # shape (...)
+                   ) -> jnp.ndarray:  # shape(..., 3)
+  """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'."""
+  return jnp.stack([v.x, v.y, v.z], axis=-1)
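rigids_from_3_points is how backbone frames are built from atom positions. A short sketch that constructs one frame and maps a point into its local coordinates; it assumes the module's rigids_mul_vecs helper, which falls in the elided middle of this diff:

import jax.numpy as jnp
from alphafold.model import r3

p_neg_x = r3.vecs_from_tensor(jnp.array([[-1.0, 0.0, 0.0]]))
origin = r3.vecs_from_tensor(jnp.array([[0.0, 0.0, 0.0]]))
p_xy = r3.vecs_from_tensor(jnp.array([[0.0, 1.0, 0.0]]))

frames = r3.rigids_from_3_points(p_neg_x, origin, p_xy)
# Apply the inverse frame to express a global point in local coordinates.
local = r3.rigids_mul_vecs(r3.invert_rigids(frames), p_xy)
print(r3.vecs_to_tensor(local))    # -> [[0., 1., 0.]] for this trivial frame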
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/__init__.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/tf/__init__.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Alphafold model TensorFlow code.""" |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/data_transforms.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/tf/data_transforms.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,625 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data for AlphaFold."""
+
+from alphafold.common import residue_constants
+from alphafold.model.tf import shape_helpers
+from alphafold.model.tf import shape_placeholders
+from alphafold.model.tf import utils
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+# Pylint gets confused by the curry1 decorator because it changes the number
+# of arguments to the function.
+# pylint:disable=no-value-for-parameter
+
+
+NUM_RES = shape_placeholders.NUM_RES
+NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
+NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
+NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
+
+
+def cast_64bit_ints(protein):
+
+  for k, v in protein.items():
+    if v.dtype == tf.int64:
+      protein[k] = tf.cast(v, tf.int32)
+  return protein
+
+
+_MSA_FEATURE_NAMES = [
+    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask',
+    'true_msa'
+]
+
+
+def make_seq_mask(protein):
+  protein['seq_mask'] = tf.ones(
+      shape_helpers.shape_list(protein['aatype']), dtype=tf.float32)
+  return protein
+
+
+def make_template_mask(protein):
+  protein['template_mask'] = tf.ones(
+      shape_helpers.shape_list(protein['template_domain_names']),
+      dtype=tf.float32)
+  return protein
+
+
+def curry1(f):
+  """Supply all arguments but the first."""
+
+  def fc(*args, **kwargs):
+    return lambda x: f(x, *args, **kwargs)
+
+  return fc
+
+
+@curry1
+def add_distillation_flag(protein, distillation):
+  protein['is_distillation'] = tf.constant(float(distillation),
+                                           shape=[],
+                                           dtype=tf.float32)
+  return protein
+
+
+def make_all_atom_aatype(protein):
+  protein['all_atom_aatype'] = protein['aatype']
+  return protein
+
+
+def fix_templates_aatype(protein):
+  """Fixes aatype encoding of templates."""
+  # Map one-hot to indices.
+  protein['template_aatype'] = tf.argmax(
+      protein['template_aatype'], output_type=tf.int32, axis=-1)
+  # Map hhsearch-aatype to our aatype.
+  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+  new_order = tf.constant(new_order_list, dtype=tf.int32)
+  protein['template_aatype'] = tf.gather(params=new_order,
+                                         indices=protein['template_aatype'])
+  return protein
+
+
+def correct_msa_restypes(protein):
+  """Correct MSA restype to have the same order as residue_constants."""
+  new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
+  new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype)
+  protein['msa'] = tf.gather(new_order, protein['msa'], axis=0)
+
+  perm_matrix = np.zeros((22, 22), dtype=np.float32)
+  perm_matrix[range(len(new_order_list)), new_order_list] = 1.
+
+  for k in protein:
+    if 'profile' in k:  # Include both hhblits and psiblast profiles
+      num_dim = protein[k].shape.as_list()[-1]
+      assert num_dim in [20, 21, 22], (
+          'num_dim for %s out of expected range: %s' % (k, num_dim))
+      protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1)
+  return protein
+
+
+def squeeze_features(protein):
+  """Remove singleton and repeated dimensions in protein features."""
+  protein['aatype'] = tf.argmax(
+      protein['aatype'], axis=-1, output_type=tf.int32)
+  for k in [
+      'domain_name', 'msa', 'num_alignments', 'seq_length', 'seq
[... middle of this diff truncated in the source capture ...]
...andom.stateless_uniform(
+      shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1,
+      dtype=tf.int32, seed=seed_maker())
+
+  templates_select_indices = tf.argsort(tf.random.stateless_uniform(
+      [num_templates], seed=seed_maker()))
+
+  for k, v in protein.items():
+    if k not in shape_schema or (
+        'template' not in k and NUM_RES not in shape_schema[k]):
+      continue
+
+    # randomly permute the templates before cropping them.
+    if k.startswith('template') and subsample_templates:
+      v = tf.gather(v, templates_select_indices)
+
+    crop_sizes = []
+    crop_starts = []
+    for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
+                                            shape_helpers.shape_list(v))):
+      is_num_res = (dim_size == NUM_RES)
+      if i == 0 and k.startswith('template'):
+        crop_size = num_templates_crop_size
+        crop_start = templates_crop_start
+      else:
+        crop_start = num_res_crop_start if is_num_res else 0
+        crop_size = (num_res_crop_size if is_num_res else
+                     (-1 if dim is None else dim))
+      crop_sizes.append(crop_size)
+      crop_starts.append(crop_start)
+    protein[k] = tf.slice(v, crop_starts, crop_sizes)
+
+  protein['seq_length'] = num_res_crop_size
+  return protein
+
+
+def make_atom14_masks(protein):
+  """Construct denser atom positions (14 dimensions instead of 37)."""
+  restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
+  restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
+  restype_atom14_mask = []
+
+  for rt in residue_constants.restypes:
+    atom_names = residue_constants.restype_name_to_atom14_names[
+        residue_constants.restype_1to3[rt]]
+
+    restype_atom14_to_atom37.append([
+        (residue_constants.atom_order[name] if name else 0)
+        for name in atom_names
+    ])
+
+    atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
+    restype_atom37_to_atom14.append([
+        (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
+        for name in residue_constants.atom_types
+    ])
+
+    restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
+
+  # Add dummy mapping for restype 'UNK'
+  restype_atom14_to_atom37.append([0] * 14)
+  restype_atom37_to_atom14.append([0] * 37)
+  restype_atom14_mask.append([0.] * 14)
+
+  restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
+  restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
+  restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
+
+  # create the mapping for (residx, atom14) --> atom37, i.e. an array
+  # with shape (num_res, 14) containing the atom37 indices for this protein
+  residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37,
+                                      protein['aatype'])
+  residx_atom14_mask = tf.gather(restype_atom14_mask,
+                                 protein['aatype'])
+
+  protein['atom14_atom_exists'] = residx_atom14_mask
+  protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37
+
+  # create the gather indices for mapping back
+  residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14,
+                                      protein['aatype'])
+  protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14
+
+  # create the corresponding mask
+  restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+  for restype, restype_letter in enumerate(residue_constants.restypes):
+    restype_name = residue_constants.restype_1to3[restype_letter]
+    atom_names = residue_constants.residue_atoms[restype_name]
+    for atom_name in atom_names:
+      atom_type = residue_constants.atom_order[atom_name]
+      restype_atom37_mask[restype, atom_type] = 1
+
+  residx_atom37_mask = tf.gather(restype_atom37_mask,
+                                 protein['aatype'])
+  protein['atom37_atom_exists'] = residx_atom37_mask
+
+  return protein
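curry1 is the pattern this whole file is built on: every transform takes the protein dict as its first argument, and curry1 lets the pipeline bind all remaining arguments up front and apply the configured transform later. A standalone illustration with a toy transform (scale_field is invented for the demo):

def curry1(f):
  """Supply all arguments but the first."""
  def fc(*args, **kwargs):
    return lambda x: f(x, *args, **kwargs)
  return fc

@curry1
def scale_field(protein, key, factor):
  protein[key] = protein[key] * factor
  return protein

transform = scale_field('msa_weight', 2.0)   # configured, not yet applied
print(transform({'msa_weight': 3.0}))        # {'msa_weight': 6.0}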
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/input_pipeline.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/tf/input_pipeline.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,166 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Feature pre-processing input pipeline for AlphaFold.""" + +from alphafold.model.tf import data_transforms +from alphafold.model.tf import shape_placeholders +import tensorflow.compat.v1 as tf +import tree + +# Pylint gets confused by the curry1 decorator because it changes the number +# of arguments to the function. +# pylint:disable=no-value-for-parameter + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + + +def nonensembled_map_fns(data_config): + """Input pipeline functions which are not ensembled.""" + common_cfg = data_config.common + + map_fns = [ + data_transforms.correct_msa_restypes, + data_transforms.add_distillation_flag(False), + data_transforms.cast_64bit_ints, + data_transforms.squeeze_features, + # Keep to not disrupt RNG. + data_transforms.randomly_replace_msa_with_unknown(0.0), + data_transforms.make_seq_mask, + data_transforms.make_msa_mask, + # Compute the HHblits profile if it's not set. This has to be run before + # sampling the MSA. + data_transforms.make_hhblits_profile, + data_transforms.make_random_crop_to_size_seed, + ] + if common_cfg.use_templates: + map_fns.extend([ + data_transforms.fix_templates_aatype, + data_transforms.make_template_mask, + data_transforms.make_pseudo_beta('template_') + ]) + map_fns.extend([ + data_transforms.make_atom14_masks, + ]) + + return map_fns + + +def ensembled_map_fns(data_config): + """Input pipeline functions that can be ensembled and averaged.""" + common_cfg = data_config.common + eval_cfg = data_config.eval + + map_fns = [] + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates + else: + pad_msa_clusters = eval_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + map_fns.append( + data_transforms.sample_msa( + max_msa_clusters, + keep_extra=True)) + + if 'masked_msa' in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + map_fns.append( + data_transforms.make_masked_msa(common_cfg.masked_msa, + eval_cfg.masked_msa_replace_fraction)) + + if common_cfg.msa_cluster_features: + map_fns.append(data_transforms.nearest_neighbor_clusters()) + map_fns.append(data_transforms.summarize_clusters()) + + # Crop after creating the cluster profiles. 
+ if max_extra_msa: + map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) + else: + map_fns.append(data_transforms.delete_extra_msa) + + map_fns.append(data_transforms.make_msa_feat()) + + crop_feats = dict(eval_cfg.feat) + + if eval_cfg.fixed_size: + map_fns.append(data_transforms.select_feat(list(crop_feats))) + map_fns.append(data_transforms.random_crop_to_size( + eval_cfg.crop_size, + eval_cfg.max_templates, + crop_feats, + eval_cfg.subsample_templates)) + map_fns.append(data_transforms.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + eval_cfg.crop_size, + eval_cfg.max_templates)) + else: + map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) + + return map_fns + + +def process_tensors_from_config(tensors, data_config): + """Apply filters and maps to an existing dataset, based on the config.""" + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_map_fns(data_config) + fn = compose(fns) + d['ensemble_index'] = i + return fn(d) + + eval_cfg = data_config.eval + tensors = compose( + nonensembled_map_fns( + data_config))( + tensors) + + tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0)) + num_ensemble = eval_cfg.num_ensemble + if data_config.common.resample_msa_in_recycling: + # Separate batch per ensembling & recycling step. + num_ensemble *= data_config.common.num_recycle + 1 + + if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1: + fn_output_signature = tree.map_structure( + tf.TensorSpec.from_tensor, tensors_0) + tensors = tf.map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + tf.range(num_ensemble), + parallel_iterations=1, + fn_output_signature=fn_output_signature) + else: + tensors = tree.map_structure(lambda x: x[None], + tensors_0) + return tensors + + +@data_transforms.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x |
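process_tensors_from_config splits the transforms into a non-ensembled prefix (run once) and an ensembled suffix (re-run per ensemble index, above via tf.map_fn). A toy, uncurried illustration of that split, with plain Python functions standing in for the real data_transforms:

def compose_demo(fs):
  def run(x):
    for f in fs:
      x = f(x)
    return x
  return run

nonensembled = [lambda d: {**d, 'profile': d['msa_sum'] / d['num_seq']}]
ensembled = [lambda d: {**d, 'sample': d['profile'] + d['ensemble_index']}]

base = compose_demo(nonensembled)({'msa_sum': 6.0, 'num_seq': 3})
batch = [compose_demo(ensembled)({**base, 'ensemble_index': i})
         for i in range(2)]
# Two ensemble members that share the once-computed 'profile'.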
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/protein_features.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/tf/protein_features.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,129 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains descriptions of various protein features.""" +import enum +from typing import Dict, Optional, Sequence, Tuple, Union +from alphafold.common import residue_constants +import tensorflow.compat.v1 as tf + +# Type aliases. +FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]] + + +class FeatureType(enum.Enum): + ZERO_DIM = 0 # Shape [x] + ONE_DIM = 1 # Shape [num_res, x] + TWO_DIM = 2 # Shape [num_res, num_res, x] + MSA = 3 # Shape [msa_length, num_res, x] + + +# Placeholder values that will be replaced with their true value at runtime. +NUM_RES = "num residues placeholder" +NUM_SEQ = "length msa placeholder" +NUM_TEMPLATES = "num templates placeholder" +# Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders +# to be replaced with the number of residues and the number of sequences in the +# multiple sequence alignment, respectively. + + +FEATURES = { + #### Static features of a protein sequence #### + "aatype": (tf.float32, [NUM_RES, 21]), + "between_segment_residues": (tf.int64, [NUM_RES, 1]), + "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]), + "domain_name": (tf.string, [1]), + "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]), + "num_alignments": (tf.int64, [NUM_RES, 1]), + "residue_index": (tf.int64, [NUM_RES, 1]), + "seq_length": (tf.int64, [NUM_RES, 1]), + "sequence": (tf.string, [1]), + "all_atom_positions": (tf.float32, + [NUM_RES, residue_constants.atom_type_num, 3]), + "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]), + "resolution": (tf.float32, [1]), + "template_domain_names": (tf.string, [NUM_TEMPLATES]), + "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]), + "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]), + "template_all_atom_positions": (tf.float32, [ + NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3 + ]), + "template_all_atom_masks": (tf.float32, [ + NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1 + ]), +} + +FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()} +FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()} + + +def register_feature(name: str, + type_: tf.dtypes.DType, + shape_: Tuple[Union[str, int]]): + """Register extra features used in custom datasets.""" + FEATURES[name] = (type_, shape_) + FEATURE_TYPES[name] = type_ + FEATURE_SIZES[name] = shape_ + + +def shape(feature_name: str, + num_residues: int, + msa_length: int, + num_templates: Optional[int] = None, + features: Optional[FeaturesMetadata] = None): + """Get the shape for the given feature name. + + This is near identical to _get_tf_shape_no_placeholders() but with 2 + differences: + * This method does not calculate a single placeholder from the total number of + elements (eg given <NUM_RES, 3> and size := 12, this won't deduce NUM_RES + must be 4) + * This method will work with tensors + + Args: + feature_name: String identifier for the feature. 
If the feature name ends + with "_unnormalized", this suffix is stripped off. + num_residues: The number of residues in the current domain - some elements + of the shape can be dynamic and will be replaced by this value. + msa_length: The number of sequences in the multiple sequence alignment, some + elements of the shape can be dynamic and will be replaced by this value. + If the number of alignments is unknown / not read, please pass None for + msa_length. + num_templates (optional): The number of templates in this tfexample. + features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES. + + Returns: + List of ints representation the tensor size. + + Raises: + ValueError: If a feature is requested but no concrete placeholder value is + given. + """ + features = features or FEATURES + if feature_name.endswith("_unnormalized"): + feature_name = feature_name[:-13] + + unused_dtype, raw_sizes = features[feature_name] + replacements = {NUM_RES: num_residues, + NUM_SEQ: msa_length} + + if num_templates is not None: + replacements[NUM_TEMPLATES] = num_templates + + sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] + for dimension in sizes: + if isinstance(dimension, str): + raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( + feature_name, raw_sizes, replacements)) + return sizes |
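As an illustrative aside (not part of the changeset), the `shape()` helper above resolves the NUM_RES/NUM_SEQ/NUM_TEMPLATES placeholders against concrete values; a minimal sketch, assuming the `alphafold` package is importable:

```python
# Illustrative sketch: resolving placeholder dimensions with
# protein_features.shape(). The concrete values below are arbitrary.
from alphafold.model.tf import protein_features

# "msa" is declared as (tf.int64, [NUM_SEQ, NUM_RES, 1]) in FEATURES,
# so the placeholders resolve to [24, 12, 1].
sizes = protein_features.shape(
    feature_name='msa', num_residues=12, msa_length=24, num_templates=3)
print(sizes)  # [24, 12, 1]
```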
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/protein_features_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/protein_features_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,51 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for protein_features.""" +import uuid + +from absl.testing import absltest +from absl.testing import parameterized +from alphafold.model.tf import protein_features +import tensorflow.compat.v1 as tf + + +def _random_bytes(): + return str(uuid.uuid4()).encode('utf-8') + + +class FeaturesTest(parameterized.TestCase, tf.test.TestCase): + + def testFeatureNames(self): + self.assertEqual(len(protein_features.FEATURE_SIZES), + len(protein_features.FEATURE_TYPES)) + sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys()) + sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys()) + for i, size_name in enumerate(sorted_size_names): + self.assertEqual(size_name, sorted_type_names[i]) + + def testReplacement(self): + for name in protein_features.FEATURE_SIZES.keys(): + sizes = protein_features.shape(name, + num_residues=12, + msa_length=24, + num_templates=3) + for x in sizes: + self.assertEqual(type(x), int) + self.assertGreater(x, 0) + + +if __name__ == '__main__': + tf.disable_v2_behavior() + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/proteins_dataset.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/proteins_dataset.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,166 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Datasets consisting of proteins.""" +from typing import Dict, Mapping, Optional, Sequence +from alphafold.model.tf import protein_features +import numpy as np +import tensorflow.compat.v1 as tf + +TensorDict = Dict[str, tf.Tensor] + + +def parse_tfexample( + raw_data: bytes, + features: protein_features.FeaturesMetadata, + key: Optional[str] = None) -> Dict[str, tf.train.Feature]: + """Read a single TF Example proto and return a subset of its features. + + Args: + raw_data: A serialized tf.Example proto. + features: A dictionary of features, mapping string feature names to a tuple + (dtype, shape). This dictionary should be a subset of + protein_features.FEATURES (or the dictionary itself for all features). + key: Optional string with the SSTable key of that tf.Example. This will be + added into features as a 'key' but only if requested in features. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + feature_map = { + k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) + for k, v in features.items() + } + parsed_features = tf.io.parse_single_example(raw_data, feature_map) + reshaped_features = parse_reshape_logic(parsed_features, features, key=key) + + return reshaped_features + + +def _first(tensor: tf.Tensor) -> tf.Tensor: + """Returns the 1st element - the input can be a tensor or a scalar.""" + return tf.reshape(tensor, shape=(-1,))[0] + + +def parse_reshape_logic( + parsed_features: TensorDict, + features: protein_features.FeaturesMetadata, + key: Optional[str] = None) -> TensorDict: + """Transforms parsed serial features to the correct shape.""" + # Find out what is the number of sequences and the number of alignments. + num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) + + if "num_alignments" in parsed_features: + num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) + else: + num_msa = 0 + + if "template_domain_names" in parsed_features: + num_templates = tf.cast( + tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) + else: + num_templates = 0 + + if key is not None and "key" in features: + parsed_features["key"] = [key] # Expand dims from () to (1,). + + # Reshape the tensors according to the sequence length and num alignments. 
+ for k, v in parsed_features.items(): + new_shape = protein_features.shape( + feature_name=k, + num_residues=num_residues, + msa_length=num_msa, + num_templates=num_templates, + features=features) + new_shape_size = tf.constant(1, dtype=tf.int32) + for dim in new_shape: + new_shape_size *= tf.cast(dim, tf.int32) + + assert_equal = tf.assert_equal( + tf.size(v), new_shape_size, + name="assert_%s_shape_correct" % k, + message="The size of feature %s (%s) could not be reshaped " + "into %s" % (k, tf.size(v), new_shape)) + if "template" not in k: + # Make sure the feature we are reshaping is not empty. + assert_non_empty = tf.assert_greater( + tf.size(v), 0, name="assert_%s_non_empty" % k, + message="The feature %s is not set in the tf.Example. Either do not " + "request the feature or use a tf.Example that has the " + "feature set." % k) + with tf.control_dependencies([assert_non_empty, assert_equal]): + parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) + else: + with tf.control_dependencies([assert_equal]): + parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) + + return parsed_features + + +def _make_features_metadata( + feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: + """Makes a feature name to type and shape mapping from a list of names.""" + # Make sure these features are always read. + required_features = ["aatype", "sequence", "seq_length"] + feature_names = list(set(feature_names) | set(required_features)) + + features_metadata = {name: protein_features.FEATURES[name] + for name in feature_names} + return features_metadata + + +def create_tensor_dict( + raw_data: bytes, + features: Sequence[str], + key: Optional[str] = None, + ) -> TensorDict: + """Creates a dictionary of tensor features. + + Args: + raw_data: A serialized tf.Example proto. + features: A list of strings of feature names to be returned in the dataset. + key: Optional string with the SSTable key of that tf.Example. This will be + added into features as a 'key' but only if requested in features. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + features_metadata = _make_features_metadata(features) + return parse_tfexample(raw_data, features_metadata, key) + + +def np_to_tensor_dict( + np_example: Mapping[str, np.ndarray], + features: Sequence[str], + ) -> TensorDict: + """Creates dict of tensors from a dict of NumPy arrays. + + Args: + np_example: A dict of NumPy feature arrays. + features: A list of strings of feature names to be returned in the dataset. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + features_metadata = _make_features_metadata(features) + tensor_dict = {k: tf.constant(v) for k, v in np_example.items() + if k in features_metadata} + + # Ensures shapes are as expected. Needed for setting size of empty features + # e.g. when no template hits were found. + tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) + return tensor_dict |
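A rough usage sketch for the module above (not from the changeset; the dummy arrays simply follow the shapes declared in `protein_features.FEATURES`):

```python
# Hedged sketch: converting NumPy features into a reshaped tensor dict
# with np_to_tensor_dict(). Values are dummy data for illustration only.
import numpy as np
from alphafold.model.tf import proteins_dataset

num_res = 4
np_example = {
    'aatype': np.zeros((num_res, 21), dtype=np.float32),
    'sequence': np.array([b'MKTA'], dtype=np.object_),
    'seq_length': np.full((num_res, 1), num_res, dtype=np.int64),
}
tensor_dict = proteins_dataset.np_to_tensor_dict(
    np_example=np_example, features=('aatype', 'sequence', 'seq_length'))
# tensor_dict['aatype'] now has shape [4, 21]; placeholder dims resolved.
```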
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/shape_helpers.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/shape_helpers.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,47 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with shapes of TensorFlow tensors.""" +import tensorflow.compat.v1 as tf + + +def shape_list(x): + """Return list of dimensions of a tensor, statically where possible. + + Like `x.shape.as_list()` but with tensors instead of `None`s. + + Args: + x: A tensor. + Returns: + A list with length equal to the rank of the tensor. The n-th element of the + list is an integer when that dimension is statically known otherwise it is + the n-th element of `tf.shape(x)`. + """ + x = tf.convert_to_tensor(x) + + # If unknown rank, return dynamic shape + if x.get_shape().dims is None: + return tf.shape(x) + + static = x.get_shape().as_list() + shape = tf.shape(x) + + ret = [] + for i in range(len(static)): + dim = static[i] + if dim is None: + dim = shape[i] + ret.append(dim) + return ret + |
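A small sketch of what `shape_list` returns for a partially known shape (illustrative, and assumes TF1 graph mode, i.e. `tf.disable_v2_behavior()` has been called):

```python
# Illustrative only: shape_list() mixes static Python ints with dynamic
# shape tensors, unlike x.shape.as_list() which would return None.
import tensorflow.compat.v1 as tf
from alphafold.model.tf import shape_helpers

tf.disable_v2_behavior()
x = tf.placeholder(tf.float32, shape=[None, 8])
dims = shape_helpers.shape_list(x)
# dims[0] is a scalar Tensor (dynamic batch size); dims[1] is the int 8.
```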
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/shape_helpers_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/shape_helpers_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,39 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for shape_helpers.""" + +from alphafold.model.tf import shape_helpers +import numpy as np +import tensorflow.compat.v1 as tf + + +class ShapeTest(tf.test.TestCase): + + def test_shape_list(self): + """Test that shape_list can allow for reshaping to dynamic shapes.""" + a = tf.zeros([10, 4, 4, 2]) + p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4]) + shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4] + + b = tf.reshape(a, shape_dyn) + with self.session() as sess: + out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))}) + + self.assertAllEqual(out.shape, (20, 1, 4, 4)) + + +if __name__ == '__main__': + tf.disable_v2_behavior() + tf.test.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/shape_placeholders.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/shape_placeholders.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,20 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Placeholder values for run-time varying dimension sizes.""" + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/tf/utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/utils.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,47 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for various components.""" +import tensorflow.compat.v1 as tf + + +def tf_combine_mask(*masks): + """Take the intersection of float-valued masks.""" + ret = 1 + for m in masks: + ret *= m + return ret + + +class SeedMaker(object): + """Return unique seeds.""" + + def __init__(self, initial_seed=0): + self.next_seed = initial_seed + + def __call__(self): + i = self.next_seed + self.next_seed += 1 + return i + +seed_maker = SeedMaker() + + +def make_random_seed(): + return tf.random.uniform([2], + tf.int32.min, + tf.int32.max, + tf.int32, + seed=seed_maker()) + |
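Both helpers are simple to exercise; a hedged example (values arbitrary, not part of the changeset):

```python
# Illustrative: intersecting two float masks and drawing a random seed.
import tensorflow.compat.v1 as tf
from alphafold.model.tf import utils

combined = utils.tf_combine_mask(
    tf.constant([1., 1., 0.]), tf.constant([1., 0., 1.]))
# combined evaluates to [1., 0., 0.] (element-wise product of the masks).

seed = utils.make_random_seed()  # int32 tensor of shape [2]; each call
                                 # advances the module-level seed_maker.
```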
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/model/utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/utils.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,131 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of JAX utility functions for use in protein folding.""" + +import collections +import functools +import numbers +from typing import Mapping + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +def final_init(config): + if config.zero_init: + return 'zeros' + else: + return 'linear' + + +def batched_gather(params, indices, axis=0, batch_dims=0): + """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" + take_fn = lambda p, i: jnp.take(p, i, axis=axis) + for _ in range(batch_dims): + take_fn = jax.vmap(take_fn) + return take_fn(params, indices) + + +def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): + """Masked mean.""" + if drop_mask_channel: + mask = mask[..., 0] + + mask_shape = mask.shape + value_shape = value.shape + + assert len(mask_shape) == len(value_shape) + + if isinstance(axis, numbers.Integral): + axis = [axis] + elif axis is None: + axis = list(range(len(mask_shape))) + assert isinstance(axis, collections.Iterable), ( + 'axis needs to be either an iterable, integer or "None"') + + broadcast_factor = 1. + for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + assert mask_size == value_size + + return (jnp.sum(mask * value, axis=axis) / + (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) + + +def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params: + """Convert a dictionary of NumPy arrays to Haiku parameters.""" + hk_params = {} + for path, array in params.items(): + scope, name = path.split('//') + if scope not in hk_params: + hk_params[scope] = {} + hk_params[scope][name] = jnp.array(array) + + return hk_params + + +def padding_consistent_rng(f): + """Modify any element-wise random function to be consistent with padding. + + Normally if you take a function like jax.random.normal and generate an array, + say of size (10,10), you will get a different set of random numbers to if you + add padding and take the first (10,10) sub-array. + + This function makes a random function that is consistent regardless of the + amount of padding added. + + Note: The padding-consistent function is likely to be slower to compile and + run than the function it is wrapping, but these slowdowns are likely to be + negligible in a large network. + + Args: + f: Any element-wise function that takes (PRNG key, shape) as the first 2 + arguments. + + Returns: + An equivalent function to f, that is now consistent for different amounts of + padding. + """ + def grid_keys(key, shape): + """Generate a grid of rng keys that is consistent with different padding. + + Generate random keys such that the keys will be identical, regardless of + how much padding is added to any dimension. + + Args: + key: A PRNG key. + shape: The shape of the output array of keys that will be generated. 
+ + Returns: + An array of shape `shape` consisting of random keys. + """ + if not shape: + return key + new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))( + jnp.arange(shape[0])) + return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys) + + def inner(key, shape, **kwargs): + return jnp.vectorize( + lambda key: f(key, shape=(), **kwargs), + signature='(2)->()')( + grid_keys(key, shape)) + return inner |
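For instance, `mask_mean` above averages only over unmasked positions; a quick sketch (not from the changeset, and assuming a Python/JAX version where `collections.Iterable` still resolves, as pinned by the Dockerfile era):

```python
# Illustrative: the masked mean ignores positions where mask == 0.
import jax.numpy as jnp
from alphafold.model import utils

mask = jnp.array([1., 1., 0.])
value = jnp.array([2., 4., 100.])
print(utils.mask_mean(mask, value))  # ~3.0; the masked-out 100.0 is ignored
```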
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/notebooks/__init__.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/notebooks/__init__.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AlphaFold Colab notebook.""" |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/notebooks/notebook_utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/notebooks/notebook_utils.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,182 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper methods for the AlphaFold Colab notebook.""" +import enum +import json +from typing import Any, Mapping, Optional, Sequence, Tuple + +from alphafold.common import residue_constants +from alphafold.data import parsers +from matplotlib import pyplot as plt +import numpy as np + + +@enum.unique +class ModelType(enum.Enum): + MONOMER = 0 + MULTIMER = 1 + + +def clean_and_validate_sequence( + input_sequence: str, min_length: int, max_length: int) -> str: + """Checks that the input sequence is ok and returns a clean version of it.""" + # Remove all whitespaces, tabs and end lines; upper-case. + clean_sequence = input_sequence.translate( + str.maketrans('', '', ' \n\t')).upper() + aatypes = set(residue_constants.restypes) # 20 standard aatypes. + if not set(clean_sequence).issubset(aatypes): + raise ValueError( + f'Input sequence contains non-amino acid letters: ' + f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' + 'amino acids as inputs.') + if len(clean_sequence) < min_length: + raise ValueError( + f'Input sequence is too short: {len(clean_sequence)} amino acids, ' + f'while the minimum is {min_length}') + if len(clean_sequence) > max_length: + raise ValueError( + f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' + f'the maximum is {max_length}. You may be able to run it with the full ' + f'AlphaFold system depending on your resources (system memory, ' + f'GPU memory).') + return clean_sequence + + +def validate_input( + input_sequences: Sequence[str], + min_length: int, + max_length: int, + max_multimer_length: int) -> Tuple[Sequence[str], ModelType]: + """Validates and cleans input sequences and determines which model to use.""" + sequences = [] + + for input_sequence in input_sequences: + if input_sequence.strip(): + input_sequence = clean_and_validate_sequence( + input_sequence=input_sequence, + min_length=min_length, + max_length=max_length) + sequences.append(input_sequence) + + if len(sequences) == 1: + print('Using the single-chain model.') + return sequences, ModelType.MONOMER + + elif len(sequences) > 1: + total_multimer_length = sum([len(seq) for seq in sequences]) + if total_multimer_length > max_multimer_length: + raise ValueError(f'The total length of multimer sequences is too long: ' + f'{total_multimer_length}, while the maximum is ' + f'{max_multimer_length}. 
Please use the full AlphaFold ' + f'system for long multimers.') + elif total_multimer_length > 1536: + print('WARNING: The accuracy of the system has not been fully validated ' + 'above 1536 residues, and you may experience long running times or ' + f'run out of memory for your complex with {total_multimer_length} ' + 'residues.') + print(f'Using the multimer model with {len(sequences)} sequences.') + return sequences, ModelType.MULTIMER + + else: + raise ValueError('No input amino acid sequence provided, please provide at ' + 'least one sequence.') + + +def merge_chunked_msa( + results: Sequence[Mapping[str, Any]], + max_hits: Optional[int] = None + ) -> parsers.Msa: + """Merges chunked database hits together into hits for the full database.""" + unsorted_results = [] + for chunk_index, chunk in enumerate(results): + msa = parsers.parse_stockholm(chunk['sto']) + e_values_dict = parsers.parse_e_values_from_tblout(chunk['tbl']) + # Jackhmmer lists sequences as <sequence name>/<residue from>-<residue to>. + e_values = [e_values_dict[t.partition('/')[0]] for t in msa.descriptions] + chunk_results = zip( + msa.sequences, msa.deletion_matrix, msa.descriptions, e_values) + if chunk_index != 0: + next(chunk_results) # Only take query (first hit) from the first chunk. + unsorted_results.extend(chunk_results) + + sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[-1]) + merged_sequences, merged_deletion_matrix, merged_descriptions, _ = zip( + *sorted_by_evalue) + merged_msa = parsers.Msa(sequences=merged_sequences, + deletion_matrix=merged_deletion_matrix, + descriptions=merged_descriptions) + if max_hits is not None: + merged_msa = merged_msa.truncate(max_seqs=max_hits) + + return merged_msa + + +def show_msa_info( + single_chain_msas: Sequence[parsers.Msa], + sequence_index: int): + """Prints info and shows a plot of the deduplicated single chain MSA.""" + full_single_chain_msa = [] + for single_chain_msa in single_chain_msas: + full_single_chain_msa.extend(single_chain_msa.sequences) + + # Deduplicate but preserve order (hence can't use set). 
+ deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa)) + total_msa_size = len(deduped_full_single_chain_msa) + print(f'\n{total_msa_size} unique sequences found in total for sequence ' + f'{sequence_index}\n') + + aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')} + msa_arr = np.array( + [[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa]) + + plt.figure(figsize=(12, 3)) + plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence ' + f'{sequence_index}') + plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black') + plt.ylabel('Non-Gap Count') + plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3)))) + plt.show() + + +def empty_placeholder_template_features( + num_templates: int, num_res: int) -> Mapping[str, np.ndarray]: + return { + 'template_aatype': np.zeros( + (num_templates, num_res, + len(residue_constants.restypes_with_x_and_gap)), dtype=np.float32), + 'template_all_atom_masks': np.zeros( + (num_templates, num_res, residue_constants.atom_type_num), + dtype=np.float32), + 'template_all_atom_positions': np.zeros( + (num_templates, num_res, residue_constants.atom_type_num, 3), + dtype=np.float32), + 'template_domain_names': np.zeros([num_templates], dtype=np.object), + 'template_sequence': np.zeros([num_templates], dtype=np.object), + 'template_sum_probs': np.zeros([num_templates], dtype=np.float32), + } + + +def get_pae_json(pae: np.ndarray, max_pae: float) -> str: + """Returns the PAE in the same format as is used in the AFDB.""" + rounded_errors = np.round(pae.astype(np.float64), decimals=1) + indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1 + indices_1 = indices[0].flatten().tolist() + indices_2 = indices[1].flatten().tolist() + return json.dumps( + [{'residue1': indices_1, + 'residue2': indices_2, + 'distance': rounded_errors.flatten().tolist(), + 'max_predicted_aligned_error': max_pae}], + indent=None, separators=(',', ':')) |
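A minimal sketch of the validation helpers above (the sequence and the limits are arbitrary, not values from the changeset):

```python
# Illustrative: whitespace is stripped, case is normalised, and length
# limits are enforced with descriptive ValueErrors.
from alphafold.notebooks import notebook_utils

clean = notebook_utils.clean_and_validate_sequence(
    'mkta ilsv\n', min_length=4, max_length=100)
print(clean)  # 'MKTAILSV'

sequences, model_type = notebook_utils.validate_input(
    input_sequences=[clean], min_length=4, max_length=100,
    max_multimer_length=200)
# A single sequence selects ModelType.MONOMER.
```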
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/notebooks/notebook_utils_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/notebooks/notebook_utils_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,203 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for notebook_utils.""" +import io + +from absl.testing import absltest +from absl.testing import parameterized +from alphafold.data import parsers +from alphafold.data import templates +from alphafold.notebooks import notebook_utils + +import mock +import numpy as np + +
+ONLY_QUERY_HIT = { + 'sto': ( + '# STOCKHOLM 1.0\n' + '#=GF ID query-i1\n' + 'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH\n' + '//\n'), + 'tbl': '', + 'stderr': b'', + 'n_iter': 1, + 'e_value': 0.0001} + +# pylint: disable=line-too-long +MULTI_SEQUENCE_HIT_1 = { + 'sto': ( + '# STOCKHOLM 1.0\n' + '#=GF ID query-i1\n' + '#=GS ERR1700680_4602609/41-109 DE [subseq from] ERR1700680_4602609\n' + '#=GS ERR1019366_5760491/40-105 DE [subseq from] ERR1019366_5760491\n' + '#=GS SRR5580704_12853319/61-125 DE [subseq from] SRR5580704_12853319\n' + 'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n' + 'ERR1700680_4602609/41-109 --INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHTEK--\n' + 'ERR1019366_5760491/40-105 ---RSGAQHHDAAAQHYEEAARHHRMAAKQYQASHHEKAAHYAQLAYAHHMYAEQHAAEAAK-AHAKNHG----\n' + 'SRR5580704_12853319/61-125 ----PAADHHMKAAEHHEEAAKHHRAAAEHHTAGDHQKAGHHAHVANGHHVNAVHHAEEASK-HHATDHS----\n' + '//\n'), + 'tbl': ( + 'ERR1700680_4602609 - query - 7.7e-09 47.7 33.8 1.1e-08 47.2 33.8 1.2 1 0 0 1 1 1 1 -\n' + 'ERR1019366_5760491 - query - 1.7e-08 46.6 33.1 2.5e-08 46.1 33.1 1.3 1 0 0 1 1 1 1 -\n' + 'SRR5580704_12853319 - query - 1.1e-07 44.0 41.6 2e-07 43.1 41.6 1.4 1 0 0 1 1 1 1 -\n'), + 'stderr': b'', + 'n_iter': 1, + 'e_value': 0.0001} + +MULTI_SEQUENCE_HIT_2 = { + 'sto': ( + '# STOCKHOLM 1.0\n' + '#=GF ID query-i1\n' + '#=GS ERR1700719_3476944/70-137 DE [subseq from] ERR1700719_3476944\n' + '#=GS ERR1700761_4254522/72-138 DE [subseq from] ERR1700761_4254522\n' + '#=GS SRR5438477_9761204/64-132 DE [subseq from] SRR5438477_9761204\n' + 'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n' + 'ERR1700719_3476944/70-137 ---KQAAEHHHQAAEHHEHAARHHREAAKHHEAGDHESAAHHAHTAQGHLHQATHHASEAAKLHVEHHGQK--\n' + 'ERR1700761_4254522/72-138 ----QASEHHNLAAEHHEHAARHHRDAAKHHKAGDHEKAAHHAHVAHGHHLHATHHATEAAKHHVEAHGEK--\n' + 'SRR5438477_9761204/64-132 MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE----\n' + '//\n'), + 'tbl': ( + 'ERR1700719_3476944 - query - 2e-07 43.2 47.5 3.5e-07 42.4 47.5 1.4 1 0 0 1 1 1 1 -\n' + 'ERR1700761_4254522 - query - 6.1e-07 41.6 48.1 8.1e-07 41.3 48.1 1.2 1 0 0 1 1 1 1 -\n' + 'SRR5438477_9761204 - query - 1.8e-06 40.2 46.9 2.3e-06 39.8 46.9 1.2 1 0 0 1 1 1 1 -\n'), + 'stderr': b'', + [...].MULTIMER)) + def test_validate_input_ok( + self, input_sequences, exp_sequences, exp_model_type): + sequences, model_type = notebook_utils.validate_input( + input_sequences=input_sequences, + min_length=1, max_length=100, max_multimer_length=100) + self.assertSequenceEqual(sequences, exp_sequences) + self.assertEqual(model_type, exp_model_type) + + @parameterized.named_parameters( + ('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'), + ('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'), + ('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer')) + def test_validate_input_bad(self, input_sequences, exp_error): + with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'): + notebook_utils.validate_input( + input_sequences=input_sequences, + min_length=4, max_length=8, max_multimer_length=6) + + def test_merge_chunked_msa_no_hits(self): + results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT] + merged_msa = notebook_utils.merge_chunked_msa( + results=results) + self.assertSequenceEqual( + merged_msa.sequences, + ('MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH',)) + self.assertSequenceEqual(merged_msa.deletion_matrix, ([0] * 56,)) + + def test_merge_chunked_msa(self): + results = [MULTI_SEQUENCE_HIT_1, MULTI_SEQUENCE_HIT_2] + merged_msa = notebook_utils.merge_chunked_msa( + results=results) + self.assertLen(merged_msa.sequences, 7) + # The 1st one is the query. + self.assertEqual( + merged_msa.sequences[0], + 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAP' 'KPH') + # The 2nd one is the one with the lowest e-value: ERR1700680_4602609. + self.assertEqual( + merged_msa.sequences[1], + '--INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHT' 'EK-') + # The last one is the one with the largest e-value: SRR5438477_9761204. + self.assertEqual( + merged_msa.sequences[-1], + 'MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE-' '---') + self.assertLen(merged_msa.deletion_matrix, 7) + + @mock.patch('sys.stdout', new_callable=io.StringIO) + def test_show_msa_info(self, mocked_stdout): + single_chain_msas = [ + parsers.Msa(sequences=['A', 'B', 'C', 'C'], + deletion_matrix=[None] * 4, + descriptions=[''] * 4), + parsers.Msa(sequences=['A', 'A', 'A', 'D'], + deletion_matrix=[None] * 4, + descriptions=[''] * 4) + ] + notebook_utils.show_msa_info( + single_chain_msas=single_chain_msas, sequence_index=1) + self.assertEqual(mocked_stdout.getvalue(), + '\n4 unique sequences found in total for sequence 1\n\n') + + @parameterized.named_parameters( + ('some_templates', 4), ('no_templates', 0)) + def test_empty_placeholder_template_features(self, num_templates): + template_features = notebook_utils.empty_placeholder_template_features( + num_templates=num_templates, num_res=16) + self.assertCountEqual(template_features.keys(), + templates.TEMPLATE_FEATURES.keys()) + self.assertSameElements( + [v.shape[0] for v in template_features.values()], [num_templates]) + self.assertSequenceEqual( + [t.dtype for t in template_features.values()], + [np.array([], dtype=templates.TEMPLATE_FEATURES[feat_name]).dtype + for feat_name in template_features]) + + def test_get_pae_json(self): + pae = np.array([[0.01, 13.12345], [20.0987, 0.0]]) + pae_json = notebook_utils.get_pae_json(pae=pae, max_pae=31.75) + self.assertEqual( + pae_json, + '[{"residue1":[1,1,2,2],"residue2":[1,2,1,2],"distance":' '[0.0,13.1,20.1,0.0],"max_predicted_aligned_error":31.75}]') + + +if __name__ == '__main__': + absltest.main()
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/__init__.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/__init__.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Amber relaxation.""" |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/amber_minimize.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/amber_minimize.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,543 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Restrained Amber Minimization of a structure.""" + +import io +import time +from typing import Collection, Optional, Sequence + +from absl import logging +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.model import folding +from alphafold.relax import cleanup +from alphafold.relax import utils +import ml_collections +import numpy as np +from simtk import openmm +from simtk import unit +from simtk.openmm import app as openmm_app +from simtk.openmm.app.internal.pdbstructure import PdbStructure + + +ENERGY = unit.kilocalories_per_mole +LENGTH = unit.angstroms + +
+def will_restrain(atom: openmm_app.Atom, rset: str) -> bool: + """Returns True if the atom will be restrained by the given restraint set.""" + + if rset == "non_hydrogen": + return atom.element.name != "hydrogen" + elif rset == "c_alpha": + return atom.name == "CA" + + +def _add_restraints( + system: openmm.System, + reference_pdb: openmm_app.PDBFile, + stiffness: unit.Unit, + rset: str, + exclude_residues: Sequence[int]): + """Adds a harmonic potential that restrains the system to a structure.""" + assert rset in ["non_hydrogen", "c_alpha"] + + force = openmm.CustomExternalForce( + "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)") + force.addGlobalParameter("k", stiffness) + for p in ["x0", "y0", "z0"]: + force.addPerParticleParameter(p) + + for i, atom in enumerate(reference_pdb.topology.atoms()): + if atom.residue.index in exclude_residues: + continue + if will_restrain(atom, rset): + force.addParticle(i, reference_pdb.positions[i]) + logging.info("Restraining %d / %d particles.", + force.getNumParticles(), system.getNumParticles()) + system.addForce(force) + + +def _openmm_minimize( + pdb_str: str, + max_iterations: int, + tolerance: unit.Unit, + stiffness: unit.Unit, + restraint_set: str, + exclude_residues: Sequence[int]): + """Minimize energy via openmm.""" + + pdb_file = io.StringIO(pdb_str) + pdb = openmm_app.PDBFile(pdb_file) + + force_field = openmm_app.ForceField("amber99sb.xml") + constraints = openmm_app.HBonds + system = force_field.createSystem( + pdb.topology, constraints=constraints) + if stiffness > 0 * ENERGY / (LENGTH**2): + _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues) + + integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) + platform = openmm.Platform.getPlatformByName("CPU") + simulation = openmm_app.Simulation( + pdb.topology, system, integrator, platform) + simulation.context.setPositions(pdb.positions) + + ret = {} + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + simulation.minimizeEnergy(maxIterations=max_iterations, + tolerance=tolerance) + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions()) + return ret + + +def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity): + """Returns a pdb s[...]rance of L-BFGS. + The default value is the OpenMM default. + restraint_set: The set of atoms to restrain. + max_attempts: The maximum number of minimization attempts per iteration. + checks: Whether to perform cleaning checks. + exclude_residues: An optional list of zero-indexed residues to exclude from + restraints. + + Returns: + out: A dictionary of output values. + """ + + # `protein.to_pdb` will strip any poorly-defined residues so we need to + # perform this check before `clean_protein`. + _check_residues_are_well_defined(prot) + pdb_string = clean_protein(prot, checks=checks) + + exclude_residues = exclude_residues or [] + exclude_residues = set(exclude_residues) + violations = np.inf + iteration = 0 + + while violations > 0 and iteration < max_outer_iterations: + ret = _run_one_iteration( + pdb_string=pdb_string, + exclude_residues=exclude_residues, + max_iterations=max_iterations, + tolerance=tolerance, + stiffness=stiffness, + restraint_set=restraint_set, + max_attempts=max_attempts) + prot = protein.from_pdb_string(ret["min_pdb"]) + if place_hydrogens_every_iteration: + pdb_string = clean_protein(prot, checks=True) + else: + pdb_string = ret["min_pdb"] + ret.update(get_violation_metrics(prot)) + ret.update({ + "num_exclusions": len(exclude_residues), + "iteration": iteration, + }) + violations = ret["violations_per_residue"] + exclude_residues = exclude_residues.union(ret["residue_violations"]) + + logging.info("Iteration completed: Einit %.2f Efinal %.2f Time %.2f s " + "num residue violations %d num residue exclusions %d ", + ret["einit"], ret["efinal"], ret["opt_time"], + ret["num_residue_violations"], ret["num_exclusions"]) + iteration += 1 + return ret + + +def get_initial_energies(pdb_strs: Sequence[str], + stiffness: float = 0.0, + restraint_set: str = "non_hydrogen", + exclude_residues: Optional[Sequence[int]] = None): + """Returns initial potential energies for a sequence of PDBs. + + Assumes the input PDBs are ready for minimization, and all have the same + topology. + Allows time to be saved by not pdbfixing / rebuilding the system. + + Args: + pdb_strs: List of PDB strings. + stiffness: kcal/mol A**2, spring constant of heavy atom restraining + potential. + restraint_set: Which atom types to restrain. + exclude_residues: An optional list of zero-indexed residues to exclude from + restraints. + + Returns: + A list of initial energies in the same order as pdb_strs. + """ + exclude_residues = exclude_residues or [] + + openmm_pdbs = [openmm_app.PDBFile(PdbStructure(io.StringIO(p))) + for p in pdb_strs] + force_field = openmm_app.ForceField("amber99sb.xml") + system = force_field.createSystem(openmm_pdbs[0].topology, + constraints=openmm_app.HBonds) + stiffness = stiffness * ENERGY / (LENGTH**2) + if stiffness > 0 * ENERGY / (LENGTH**2): + _add_restraints(system, openmm_pdbs[0], stiffness, restraint_set, + exclude_residues) + simulation = openmm_app.Simulation(openmm_pdbs[0].topology, + system, + openmm.LangevinIntegrator(0, 0.01, 0.0), + openmm.Platform.getPlatformByName("CPU")) + energies = [] + for pdb in openmm_pdbs: + try: + simulation.context.setPositions(pdb.positions) + state = simulation.context.getState(getEnergy=True) + energies.append(state.getPotentialEnergy().value_in_unit(ENERGY)) + except Exception as e: # pylint: disable=broad-except + logging.error("Error getting initial energy, returning large value %s", e) + energies.append(unit.Quantity(1e20, ENERGY)) + return energies
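A hedged driver sketch for the pipeline above (the input path is hypothetical, and the keyword values mirror the unit tests below rather than recommended production settings; parts of `run_pipeline`'s signature are truncated in this changeset view):

```python
# Illustrative only: run restrained Amber minimisation on a prediction.
from alphafold.common import protein
from alphafold.relax import amber_minimize

with open('prediction.pdb') as f:  # hypothetical input PDB
    prot = protein.from_pdb_string(f.read())
out = amber_minimize.run_pipeline(
    prot, stiffness=10., max_iterations=10, max_attempts=1)
print(out['einit'], out['efinal'])  # potential energy before / after
```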
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/amber_minimize_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/amber_minimize_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,130 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for amber_minimize.""" +import os + +from absl.testing import absltest +from alphafold.common import protein +from alphafold.relax import amber_minimize +import numpy as np +# Internal import (7716). + + +def _load_test_protein(data_path): + pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path) + with open(pdb_path, 'r') as f: + return protein.from_pdb_string(f.read()) + + +class AmberMinimizeTest(absltest.TestCase): + + def test_multiple_disulfides_target(self): + prot = _load_test_protein( + 'alphafold/relax/testdata/multiple_disulfides_target.pdb' + ) + ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1, + stiffness=10.) + self.assertIn('opt_time', ret) + self.assertIn('min_attempts', ret) + + def test_raises_invalid_protein_assertion(self): + prot = _load_test_protein( + 'alphafold/relax/testdata/multiple_disulfides_target.pdb' + ) + prot.atom_mask[4, :] = 0 + with self.assertRaisesRegex( + ValueError, + 'Amber minimization can only be performed on proteins with well-defined' + ' residues. This protein contains at least one residue with no atoms.'): + amber_minimize.run_pipeline(prot, max_iterations=10, + stiffness=1., + max_attempts=1) + + def test_iterative_relax(self): + prot = _load_test_protein( + 'alphafold/relax/testdata/with_violations.pdb' + ) + violations = amber_minimize.get_violation_metrics(prot) + self.assertGreater(violations['num_residue_violations'], 0) + out = amber_minimize.run_pipeline( + prot=prot, max_outer_iterations=10, stiffness=10.) 
+ self.assertLess(out['efinal'], out['einit']) + self.assertEqual(0, out['num_residue_violations']) + + def test_find_violations(self): + prot = _load_test_protein( + 'alphafold/relax/testdata/multiple_disulfides_target.pdb' + ) + viols, _ = amber_minimize.find_violations(prot) + + expected_between_residues_connection_mask = np.zeros((191,), np.float32) + for residue in (42, 43, 59, 60, 135, 136): + expected_between_residues_connection_mask[residue] = 1.0 + + expected_clash_indices = np.array([ + [8, 4], + [8, 5], + [13, 3], + [14, 1], + [14, 4], + [26, 4], + [26, 5], + [31, 8], + [31, 10], + [39, 0], + [39, 1], + [39, 2], + [39, 3], + [39, 4], + [42, 5], + [42, 6], + [42, 7], + [42, 8], + [47, 7], + [47, 8], + [47, 9], + [47, 10], + [64, 4], + [85, 5], + [102, 4], + [102, 5], + [109, 13], + [111, 5], + [118, 6], + [118, 7], + [118, 8], + [124, 4], + [124, 5], + [131, 5], + [139, 7], + [147, 4], + [152, 7]], dtype=np.int32) + expected_between_residues_clash_mask = np.zeros([191, 14]) + expected_between_residues_clash_mask[expected_clash_indices[:, 0], + expected_clash_indices[:, 1]] += 1 + expected_per_atom_violations = np.zeros([191, 14]) + np.testing.assert_array_equal( + viols['between_residues']['connections_per_residue_violation_mask'], + expected_between_residues_connection_mask) + np.testing.assert_array_equal( + viols['between_residues']['clashes_per_atom_clash_mask'], + expected_between_residues_clash_mask) + np.testing.assert_array_equal( + viols['within_residues']['per_atom_violations'], + expected_per_atom_violations) + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/cleanup.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/cleanup.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,127 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. + +fix_pdb uses a third-party tool. We also support fixing some additional edge +cases like removing chains of length one (see clean_structure). +""" +import io + +import pdbfixer +from simtk.openmm import app +from simtk.openmm.app import element + + +def fix_pdb(pdbfile, alterations_info): + """Apply pdbfixer to the contents of a PDB file; return a PDB string result. + + 1) Replaces nonstandard residues. + 2) Removes heterogens (non protein residues) including water. + 3) Adds missing residues and missing atoms within existing residues. + 4) Adds hydrogens assuming pH=7.0. + 5) KeepIds is currently true, so the fixer must keep the existing chain and + residue identifiers. This will fail for some files in wider PDB that have + invalid IDs. + + Args: + pdbfile: Input PDB file handle. + alterations_info: A dict that will store details of changes made. + + Returns: + A PDB string representing the fixed structure. + """ + fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) + fixer.findNonstandardResidues() + alterations_info['nonstandard_residues'] = fixer.nonstandardResidues + fixer.replaceNonstandardResidues() + _remove_heterogens(fixer, alterations_info, keep_water=False) + fixer.findMissingResidues() + alterations_info['missing_residues'] = fixer.missingResidues + fixer.findMissingAtoms() + alterations_info['missing_heavy_atoms'] = fixer.missingAtoms + alterations_info['missing_terminals'] = fixer.missingTerminals + fixer.addMissingAtoms(seed=0) + fixer.addMissingHydrogens() + out_handle = io.StringIO() + app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, + keepIds=True) + return out_handle.getvalue() + + +def clean_structure(pdb_structure, alterations_info): + """Applies additional fixes to an OpenMM structure, to handle edge cases. + + Args: + pdb_structure: An OpenMM structure to modify and fix. + alterations_info: A dict that will store details of changes made. + """ + _replace_met_se(pdb_structure, alterations_info) + _remove_chains_of_length_one(pdb_structure, alterations_info) + + +def _remove_heterogens(fixer, alterations_info, keep_water): + """Removes the residues that Pdbfixer considers to be heterogens. + + Args: + fixer: A Pdbfixer instance. + alterations_info: A dict that will store details of changes made. + keep_water: If True, water (HOH) is not considered to be a heterogen. 
+ """ + initial_resnames = set() + for chain in fixer.topology.chains(): + for residue in chain.residues(): + initial_resnames.add(residue.name) + fixer.removeHeterogens(keepWater=keep_water) + final_resnames = set() + for chain in fixer.topology.chains(): + for residue in chain.residues(): + final_resnames.add(residue.name) + alterations_info['removed_heterogens'] = ( + initial_resnames.difference(final_resnames)) + + +def _replace_met_se(pdb_structure, alterations_info): + """Replace the Se in any MET residues that were not marked as modified.""" + modified_met_residues = [] + for res in pdb_structure.iter_residues(): + name = res.get_name_with_spaces().strip() + if name == 'MET': + s_atom = res.get_atom('SD') + if s_atom.element_symbol == 'Se': + s_atom.element_symbol = 'S' + s_atom.element = element.get_by_symbol('S') + modified_met_residues.append(s_atom.residue_number) + alterations_info['Se_in_MET'] = modified_met_residues + + +def _remove_chains_of_length_one(pdb_structure, alterations_info): + """Removes chains that correspond to a single amino acid. + + A single amino acid in a chain is both N and C terminus. There is no force + template for this case. + + Args: + pdb_structure: An OpenMM pdb_structure to modify and fix. + alterations_info: A dict that will store details of changes made. + """ + removed_chains = {} + for model in pdb_structure.iter_models(): + valid_chains = [c for c in model.iter_chains() if len(c) > 1] + invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1] + model.chains = valid_chains + for chain_id in invalid_chain_ids: + model.chains_by_id.pop(chain_id) + removed_chains[model.number] = invalid_chain_ids + alterations_info['removed_chains'] = removed_chains |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/cleanup_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/cleanup_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,137 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for relax.cleanup.""" +import io + +from absl.testing import absltest +from alphafold.relax import cleanup +from simtk.openmm.app.internal import pdbstructure + + +def _pdb_to_structure(pdb_str): + handle = io.StringIO(pdb_str) + return pdbstructure.PdbStructure(handle) + + +def _lines_to_structure(pdb_lines): + return _pdb_to_structure('\n'.join(pdb_lines)) + + +class CleanupTest(absltest.TestCase): + + def test_missing_residues(self): + pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU', + 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 ' + '19.08 N', + 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 ' + '17.23 C', + 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 ' + '15.38 C', + 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 ' + '16.04 O', + 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 ' + '14.75 N', + 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 ' + '16.81 C', + 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 ' + '16.95 C', + 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 ' + '16.97 O'] + input_handle = io.StringIO('\n'.join(pdb_lines)) + alterations = {} + result = cleanup.fix_pdb(input_handle, alterations) + structure = _pdb_to_structure(result) + residue_names = [r.get_name() for r in structure.iter_residues()] + self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU']) + self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']]) + + def test_missing_atoms(self): + pdb_lines = ['SEQRES 1 A 1 PRO', + 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 ' + ' 0.00 C'] + input_handle = io.StringIO('\n'.join(pdb_lines)) + alterations = {} + result = cleanup.fix_pdb(input_handle, alterations) + structure = _pdb_to_structure(result) + atom_names = [a.get_name() for a in structure.iter_atoms()] + self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2', + 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA', + 'C', 'O', 'H2', 'H3', 'OXT']) + missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values()) + self.assertLen(missing_atoms_by_residue, 1) + atoms_added = [a.name for a in missing_atoms_by_residue[0]] + self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O']) + missing_terminals_by_residue = alterations['missing_terminals'] + self.assertLen(missing_terminals_by_residue, 1) + has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()] + self.assertCountEqual(has_missing_terminal, ['PRO']) + self.assertCountEqual([t for t in missing_terminals_by_residue.values()], + [['OXT']]) + + def test_remove_heterogens(self): + pdb_lines = ['SEQRES 1 A 1 GLY', + 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' + ' 0.00 C', + 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 ' + ' 0.00 O'] + input_handle = io.StringIO('\n'.join(pdb_lines)) + alterations = {} + result = cleanup.fix_pdb(input_handle, alterations) + structure = _pdb_to_structure(result) + self.assertCountEqual([res.get_name() for res in structure.iter_residues()], + ['GLY']) + 
self.assertEqual(alterations['removed_heterogens'], set(['HOH'])) + + def test_fix_nonstandard_residues(self): + pdb_lines = ['SEQRES 1 A 1 DAL', + 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 ' + ' 0.00 C'] + input_handle = io.StringIO('\n'.join(pdb_lines)) + alterations = {} + result = cleanup.fix_pdb(input_handle, alterations) + structure = _pdb_to_structure(result) + residue_names = [res.get_name() for res in structure.iter_residues()] + self.assertCountEqual(residue_names, ['ALA']) + self.assertLen(alterations['nonstandard_residues'], 1) + original_res, new_name = alterations['nonstandard_residues'][0] + self.assertEqual(original_res.id, '1') + self.assertEqual(new_name, 'ALA') + + def test_replace_met_se(self): + pdb_lines = ['SEQRES 1 A 1 MET', + 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 ' + ' 0.00 Se'] + structure = _lines_to_structure(pdb_lines) + alterations = {} + cleanup._replace_met_se(structure, alterations) + sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD'] + self.assertLen(sd, 1) + self.assertEqual(sd[0].element_symbol, 'S') + self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number]) + + def test_remove_chains_of_length_one(self): + pdb_lines = ['SEQRES 1 A 1 GLY', + 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' + ' 0.00 C'] + structure = _lines_to_structure(pdb_lines) + alterations = {} + cleanup._remove_chains_of_length_one(structure, alterations) + chains = list(structure.iter_chains()) + self.assertEmpty(chains) + self.assertCountEqual(alterations['removed_chains'].values(), [['A']]) + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/relax.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/relax.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,80 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Amber relaxation.""" +from typing import Any, Dict, Sequence, Tuple +from alphafold.common import protein +from alphafold.relax import amber_minimize +from alphafold.relax import utils +import numpy as np + + +class AmberRelaxation(object): + """Amber relaxation.""" + + def __init__(self, + *, + max_iterations: int, + tolerance: float, + stiffness: float, + exclude_residues: Sequence[int], + max_outer_iterations: int): + """Initialize Amber Relaxer. + + Args: + max_iterations: Maximum number of L-BFGS iterations. 0 means no max. + tolerance: kcal/mol, the energy tolerance of L-BFGS. + stiffness: kcal/mol A**2, spring constant of heavy atom restraining + potential. + exclude_residues: Residues to exclude from per-atom restraining. + Zero-indexed. + max_outer_iterations: Maximum number of violation-informed relax + iterations. A value of 1 will run the non-iterative procedure used in + CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes + as soon as there are no violations, hence in most cases this causes no + slowdown. In the worst case we do 20 outer iterations. + """ + + self._max_iterations = max_iterations + self._tolerance = tolerance + self._stiffness = stiffness + self._exclude_residues = exclude_residues + self._max_outer_iterations = max_outer_iterations + + def process(self, *, + prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]: + """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" + out = amber_minimize.run_pipeline( + prot=prot, max_iterations=self._max_iterations, + tolerance=self._tolerance, stiffness=self._stiffness, + exclude_residues=self._exclude_residues, + max_outer_iterations=self._max_outer_iterations) + min_pos = out['pos'] + start_pos = out['posinit'] + rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0]) + debug_data = { + 'initial_energy': out['einit'], + 'final_energy': out['efinal'], + 'attempts': out['min_attempts'], + 'rmsd': rmsd + } + pdb_str = amber_minimize.clean_protein(prot) + min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) + min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) + utils.assert_equal_nonterminal_atom_types( + protein.from_pdb_string(min_pdb).atom_mask, + prot.atom_mask) + violations = out['structural_violations'][ + 'total_per_residue_violations_mask'] + return min_pdb, debug_data, violations |
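AmberRelaxation bundles the restrained Amber/OpenMM minimization into a single process call: it returns the minimized PDB string, a debug dict, and a per-residue violation mask. Note that the reported 'rmsd' divides the total squared displacement by the number of atoms before taking the square root, i.e. an all-atom RMSD without superposition. A usage sketch (not part of the changeset; the constructor values mirror the test config in relax_test.py below, and 'model_output.pdb' stands in for any predicted-structure file):

from alphafold.common import protein
from alphafold.relax import relax

amber_relaxer = relax.AmberRelaxation(
    max_iterations=1,        # L-BFGS cap; 0 would mean unlimited
    tolerance=2.39,          # kcal/mol energy tolerance
    stiffness=10.0,          # kcal/mol A**2 restraint spring constant
    exclude_residues=[],     # zero-indexed residues left unrestrained
    max_outer_iterations=1)  # 1 = the non-iterative CASP14 procedure
with open('model_output.pdb') as f:
    prot = protein.from_pdb_string(f.read())
relaxed_pdb, debug_data, violations = amber_relaxer.process(prot=prot)
print(debug_data['final_energy'], int(violations.sum()))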
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/relax_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/relax_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for relax.""" +import os + +from absl.testing import absltest +from alphafold.common import protein +from alphafold.relax import relax +import numpy as np +# Internal import (7716). + + +class RunAmberRelaxTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.test_dir = os.path.join( + absltest.get_default_test_srcdir(), + 'alphafold/relax/testdata/') + self.test_config = { + 'max_iterations': 1, + 'tolerance': 2.39, + 'stiffness': 10.0, + 'exclude_residues': [], + 'max_outer_iterations': 1} + + def test_process(self): + amber_relax = relax.AmberRelaxation(**self.test_config) + + with open(os.path.join(self.test_dir, 'model_output.pdb')) as f: + test_prot = protein.from_pdb_string(f.read()) + pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot) + + self.assertCountEqual(debug_info.keys(), + set({'initial_energy', 'final_energy', + 'attempts', 'rmsd'})) + self.assertLess(debug_info['final_energy'], debug_info['initial_energy']) + self.assertGreater(debug_info['rmsd'], 0) + + prot_min = protein.from_pdb_string(pdb_min) + # Most protein properties should be unchanged. + np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype) + np.testing.assert_almost_equal(test_prot.residue_index, + prot_min.residue_index) + # Atom mask and bfactors identical except for terminal OXT of last residue. + np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :], + prot_min.atom_mask[:-1, :]) + np.testing.assert_almost_equal(test_prot.b_factors[:-1, :], + prot_min.b_factors[:-1, :]) + np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1], + prot_min.atom_mask[:, :-1]) + np.testing.assert_almost_equal(test_prot.b_factors[:, :-1], + prot_min.b_factors[:, :-1]) + # There are no residues with violations. + np.testing.assert_equal(num_violations, np.zeros_like(num_violations)) + + def test_unresolved_violations(self): + amber_relax = relax.AmberRelaxation(**self.test_config) + with open(os.path.join(self.test_dir, + 'with_violations_casp14.pdb')) as f: + test_prot = protein.from_pdb_string(f.read()) + _, _, num_violations = amber_relax.process(prot=test_prot) + exp_num_violations = np.array( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 0, 0, 0]) + # Check no violations were added. Can't check exactly due to stochasticity. + self.assertTrue(np.all(num_violations <= exp_num_violations)) + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/testdata/model_output.pdb --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/testdata/model_output.pdb Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,98 @@ +ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C +ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C +ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C +ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C +ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C +ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H +ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H +ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H +ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H +ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H +ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H +ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H +ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H +ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H +ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H +ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H +ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N +ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O +ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S +ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C +ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C +ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C +ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C +ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C +ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C +ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H +ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H +ATOM 28 HB2 LYS A 2 1.223 -47.785 4.167 1.00 2.92 H +ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H +ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H +ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H +ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H +ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H +ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H +ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H +ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H +ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H +ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H +ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N +ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N +ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O +ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C +ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C +ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C +ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C +ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C +ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C +ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C +ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C +ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C +ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H +ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H +ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H +ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H +ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H +ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H +ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H +ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H +ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H +ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N +ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O +ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C +ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C +ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C +ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C +ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C +ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C +ATOM 68 H 
LEU A 4 0.163 -43.538 7.292 1.00 4.39 H +ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H +ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H +ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H +ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H +ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H +ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H +ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H +ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H +ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H +ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H +ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N +ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O +ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C +ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C +ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C +ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C +ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C +ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H +ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H +ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H +ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H +ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 1.00 4.39 H +ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H +ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H +ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H +ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H +ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N +ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O +TER 96 VAL A 5 +END |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/testdata/multiple_disulfides_target.pdb --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/testdata/multiple_disulfides_target.pdb Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,1478 @@ +MODEL 0 +ATOM 1 N MET A 1 19.164 -28.457 26.130 1.00 0.00 N +ATOM 2 CA MET A 1 19.746 -27.299 25.456 1.00 0.00 C [... remaining ATOM coordinate records (chain A, MET 1 through HIS 191) elided; the source rendered this new test-data PDB as a mangled, truncated byte string ...] +TER 1475 HIS A 191 +ENDMDL +END |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/testdata/with_violations.pdb --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/testdata/with_violations.pdb Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,1193 @@ +MODEL 0 +ATOM 1 N SER A 1 23.291 1.505 0.613 1.00 6.08 N +ATOM 2 CA SER A 1 22.518 0.883 -0.457 1.00 6.08 C [... remaining ATOM coordinate records (chain A, SER 1 through ASP 148) elided; the source rendered this new test-data PDB as a mangled, truncated byte string ...] +TER 1190 ASP A 148 +ENDMDL +END |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/testdata/with_violations_casp14.pdb --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/testdata/with_violations_casp14.pdb Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,1193 @@ +MODEL 0 +ATOM 1 N SER A 1 27.311 -3.395 37.375 1.00 8.64 N +ATOM 2 CA SER A 1 26.072 -4.109 37.084 1.00 8.64 C [... remaining ATOM coordinate records (chain A, SER 1 through ASP 148) elided; the source rendered this new test-data PDB as a mangled, truncated byte string ...] +TER 1190 ASP A 148 +ENDMDL +END |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/utils.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/utils.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,80 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for minimization.""" +import io +from alphafold.common import residue_constants +from Bio import PDB +import numpy as np +from simtk.openmm import app as openmm_app +from simtk.openmm.app.internal.pdbstructure import PdbStructure + + +def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: + pdb_file = io.StringIO(pdb_str) + structure = PdbStructure(pdb_file) + topology = openmm_app.PDBFile(structure).getTopology() + with io.StringIO() as f: + openmm_app.PDBFile.writeFile(topology, pos, f) + return f.getvalue() + + +def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: + """Overwrites the B-factors in pdb_str with contents of bfactors array. + + Args: + pdb_str: An input PDB string. + bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the + B-factors are per residue; i.e. that the nonzero entries are identical in + [0, i, :]. + + Returns: + A new PDB string with the B-factors replaced. + """ + if bfactors.shape[-1] != residue_constants.atom_type_num: + raise ValueError( + f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.') + + parser = PDB.PDBParser(QUIET=True) + handle = io.StringIO(pdb_str) + structure = parser.get_structure('', handle) + + curr_resid = ('', '', '') + idx = -1 + for atom in structure.get_atoms(): + atom_resid = atom.parent.get_id() + if atom_resid != curr_resid: + idx += 1 + if idx >= bfactors.shape[0]: + raise ValueError('Index into bfactors exceeds number of residues. ' + 'B-factors shape: {shape}, idx: {idx}.') + curr_resid = atom_resid + atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']] + + new_pdb = io.StringIO() + pdb_io = PDB.PDBIO() + pdb_io.set_structure(structure) + pdb_io.save(new_pdb) + return new_pdb.getvalue() + + +def assert_equal_nonterminal_atom_types( + atom_mask: np.ndarray, ref_atom_mask: np.ndarray): + """Checks that pre- and post-minimized proteins have same atom set.""" + # Ignore any terminal OXT atoms which may have been added by minimization. + oxt = residue_constants.atom_order['OXT'] + no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool) + no_oxt_mask[..., oxt] = False + np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask], + atom_mask[no_oxt_mask]) |
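overwrite_b_factors expects a [n_residues, 37] array (37 = residue_constants.atom_type_num) and stamps each residue's CA-column value onto every atom of that residue, which is the convention AlphaFold uses to store per-residue pLDDT in the B-factor field. A sketch (not part of the changeset; the one-residue PDB string and the confidence value are illustrative):

import numpy as np

from alphafold.relax import utils

pdb_str = ('ATOM      1  CA  GLY A   1       0.000   0.000   0.000'
           '  1.00  0.00           C\n')
plddt = np.array([87.3])                     # one confidence value per residue
bfactors = np.tile(plddt[:, None], (1, 37))  # replicate across all atom types
print(utils.overwrite_b_factors(pdb_str, bfactors))  # B-factor column now 87.30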
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/alphafold/relax/utils_test.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/relax/utils_test.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,55 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for utils.""" + +import os + +from absl.testing import absltest +from alphafold.common import protein +from alphafold.relax import utils +import numpy as np +# Internal import (7716). + + +class UtilsTest(absltest.TestCase): + + def test_overwrite_b_factors(self): + testdir = os.path.join( + absltest.get_default_test_srcdir(), + 'alphafold/relax/testdata/' + 'multiple_disulfides_target.pdb') + with open(testdir) as f: + test_pdb = f.read() + n_residues = 191 + bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1) + + output_pdb = utils.overwrite_b_factors(test_pdb, bfactors) + + # Check that the atom lines are unchanged apart from the B-factors. + atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == ('ATOM')] + atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == ('ATOM')] + for line_original, line_new in zip(atom_lines_original, atom_lines_new): + self.assertEqual(line_original[:60].strip(), line_new[:60].strip()) + self.assertEqual(line_original[66:].strip(), line_new[66:].strip()) + + # Check B-factors are correctly set for all atoms present. + as_protein = protein.from_pdb_string(output_pdb) + np.testing.assert_almost_equal( + np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0), + np.where(as_protein.atom_mask > 0, bfactors, 0)) + + +if __name__ == '__main__': + absltest.main() |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/docker/Dockerfile --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/docker/Dockerfile Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,85 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG CUDA=11.0 +FROM nvidia/cuda:${CUDA}-cudnn8-runtime-ubuntu18.04 +# FROM directive resets ARGS, so we specify again (the value is retained if +# previously set). +ARG CUDA + +# Use bash to support string substitution. +SHELL ["/bin/bash", "-c"] + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + build-essential \ + cmake \ + cuda-command-line-tools-${CUDA/./-} \ + git \ + hmmer \ + kalign \ + tzdata \ + wget \ + && rm -rf /var/lib/apt/lists/* + +# Compile HHsuite from source. +RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \ + && mkdir /tmp/hh-suite/build \ + && pushd /tmp/hh-suite/build \ + && cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \ + && make -j 4 && make install \ + && ln -s /opt/hhsuite/bin/* /usr/bin \ + && popd \ + && rm -rf /tmp/hh-suite + +# Install Miniconda package manager. +RUN wget -q -P /tmp \ + https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ + && rm /tmp/Miniconda3-latest-Linux-x86_64.sh + +# Install conda packages. +ENV PATH="/opt/conda/bin:$PATH" +RUN conda update -qy conda \ + && conda install -y -c conda-forge \ + openmm=7.5.1 \ + cudatoolkit==${CUDA_VERSION} \ + pdbfixer \ + pip \ + python=3.7 + +COPY . /app/alphafold +RUN wget -q -P /app/alphafold/alphafold/common/ \ + https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt + +# Install pip packages. +RUN pip3 install --upgrade pip \ + && pip3 install -r /app/alphafold/requirements.txt \ + && pip3 install --upgrade jax jaxlib==0.1.69+cuda${CUDA/./} -f \ + https://storage.googleapis.com/jax-releases/jax_releases.html + +# Apply OpenMM patch. +WORKDIR /opt/conda/lib/python3.7/site-packages +RUN patch -p0 < /app/alphafold/docker/openmm.patch + +# We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk +# with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for +# details. +# ENTRYPOINT does not support easily running multiple commands, so instead we +# write a shell script to wrap them up. +WORKDIR /app/alphafold +RUN echo $'#!/bin/bash\n\ +ldconfig\n\ +python /app/alphafold/run_alphafold.py "$@"' > /app/run_alphafold.sh \ + && chmod +x /app/run_alphafold.sh +ENTRYPOINT ["/app/run_alphafold.sh"] |
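The image compiles HH-suite v3.3.0 from source, installs OpenMM 7.5.1 and pdbfixer via conda, applies the OpenMM patch below, and wraps run_alphafold.py behind an ldconfig call so GPUs are visible at runtime. Building it from the repository root with something like docker build -f docker/Dockerfile -t alphafold . (command assumed, not part of this changeset) yields the image name that docker/run_docker.py expects by default (--docker_image_name=alphafold).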
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/docker/openmm.patch --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/docker/openmm.patch Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,42 @@ +Index: simtk/openmm/app/topology.py +=================================================================== +--- simtk.orig/openmm/app/topology.py ++++ simtk/openmm/app/topology.py +@@ -356,19 +356,35 @@ + def isCyx(res): + names = [atom.name for atom in res._atoms] + return 'SG' in names and 'HG' not in names ++ # This function is used to prevent multiple di-sulfide bonds from being ++ # assigned to a given atom. This is a DeepMind modification. ++ def isDisulfideBonded(atom): ++ for b in self._bonds: ++ if (atom in b and b[0].name == 'SG' and ++ b[1].name == 'SG'): ++ return True ++ ++ return False + + cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)] + atomNames = [[atom.name for atom in res._atoms] for res in cyx] + for i in range(len(cyx)): + sg1 = cyx[i]._atoms[atomNames[i].index('SG')] + pos1 = positions[sg1.index] ++ candidate_distance, candidate_atom = 0.3*nanometers, None + for j in range(i): + sg2 = cyx[j]._atoms[atomNames[j].index('SG')] + pos2 = positions[sg2.index] + delta = [x-y for (x,y) in zip(pos1, pos2)] + distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2]) +- if distance < 0.3*nanometers: +- self.addBond(sg1, sg2) ++ if distance < candidate_distance and not isDisulfideBonded(sg2): ++ candidate_distance = distance ++ candidate_atom = sg2 ++ # Assign bond to closest pair. ++ if candidate_atom: ++ self.addBond(sg1, candidate_atom) ++ ++ + + class Chain(object): + """A Chain object represents a chain within a Topology.""" |
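The patch changes the disulfide detection in OpenMM's Topology: the stock code bonded every pair of cysteine SG atoms within 0.3 nm, which can attach two disulfide bonds to one atom when three cysteines cluster, whereas the DeepMind version bonds each SG only to its closest not-yet-bonded partner. A distilled, standalone sketch of that selection rule (illustrative Python, not the patched OpenMM internals; assign_disulfides and its inputs are hypothetical):

import numpy as np

def assign_disulfides(sg_positions_nm: np.ndarray) -> list:
    # sg_positions_nm: [n_cys, 3] SG coordinates in nanometers.
    bonds, bonded = [], set()
    for i in range(len(sg_positions_nm)):
        best_j, best_d = None, 0.3  # 0.3 nm cutoff, as in the patch
        for j in range(i):
            if j in bonded:
                continue  # never give an SG a second disulfide bond
            d = np.linalg.norm(sg_positions_nm[i] - sg_positions_nm[j])
            if d < best_d:
                best_j, best_d = j, d
        if best_j is not None:  # bond only the closest eligible pair
            bonds.append((i, best_j))
            bonded.update((i, best_j))
    return bonds

# Atom 0 is within 0.3 nm of both 1 and 2; only the closest pair is bonded.
print(assign_disulfides(np.array([[0.0, 0.0, 0.0],
                                  [0.2, 0.0, 0.0],
                                  [0.0, 0.25, 0.0]])))  # -> [(1, 0)]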
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/docker/requirements.txt --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/docker/requirements.txt Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,3 @@ +# Dependencies necessary to execute run_docker.py +absl-py==0.13.0 +docker==5.0.0 |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/docker/run_docker.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/docker/run_docker.py Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,231 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Docker launch script for Alphafold docker image."""
+
+import os
+import pathlib
+import signal
+from typing import Tuple
+
+from absl import app
+from absl import flags
+from absl import logging
+import docker
+from docker import types
+
+
+flags.DEFINE_bool(
+    'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
+flags.DEFINE_string(
+    'gpu_devices', 'all',
+    'Comma separated list of devices to pass to NVIDIA_VISIBLE_DEVICES.')
+flags.DEFINE_list(
+    'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
+    'target that will be folded one after another. If a FASTA file contains '
+    'multiple sequences, then it will be folded as a multimer. Paths should be '
+    'separated by commas. All FASTA paths must have a unique basename as the '
+    'basename is used to name the output directories for each prediction.')
+flags.DEFINE_list(
+    'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
+    'single chain system. This list should contain a boolean for each fasta '
+    'specifying true where the target complex is from a prokaryote, and false '
+    'where it is not, or where the origin is unknown. These values determine '
+    'the pairing method for the MSA.')
+flags.DEFINE_string(
+    'output_dir', '/tmp/alphafold',
+    'Path to a directory that will store the results.')
+flags.DEFINE_string(
+    'data_dir', None,
+    'Path to directory with supporting data: AlphaFold parameters and genetic '
+    'and template databases. Set to the target of download_all_databases.sh.')
+flags.DEFINE_string(
+    'docker_image_name', 'alphafold', 'Name of the AlphaFold Docker image.')
+flags.DEFINE_string(
+    'max_template_date', None,
+    'Maximum template release date to consider (ISO-8601 format: YYYY-MM-DD). '
+    'Important if folding historical test sets.')
+flags.DEFINE_enum(
+    'db_preset', 'full_dbs', ['full_dbs', 'reduced_dbs'],
+    'Choose preset MSA database configuration - smaller genetic database '
+    'config (reduced_dbs) or full genetic database config (full_dbs)')
+flags.DEFINE_enum(
+    'model_preset', 'monomer',
+    ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
+    'Choose preset model configuration - the monomer model, the monomer model '
+    'with extra ensembling, monomer model with pTM head, or multimer model')
+flags.DEFINE_boolean(
+    'benchmark', False,
+    'Run multiple JAX model evaluations to obtain a timing that excludes the '
+    'compilation time, which should be more indicative of the time required '
+    'for inferencing many proteins.')
+flags.DEFINE_boolean(
+    'use_precomputed_msas', False,
+    'Whether to read MSAs that have been written to disk. WARNING: This will '
+    'not check if the sequence, database or configuration have changed.')
+
+FLAGS = flags.FLAGS
+
+_ROOT_MOUNT_DIRECTORY = '/mnt/'
+
+
+def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]:
+  path = os.path.abspath(path)
+  source_path = os.path.dirname(path)
+  target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name)
+  logging.info('Mounting %s -> %s', source_path, target_path)
+  mount = types.Mount(target_path, source_path, type='bind', read_only=True)
+  return mount, os.path.join(target_path, os.path.basename(path))
+
+
+def main(argv):
+  if len(argv) > 1:
+    raise app.UsageError(
[... middle of run_docker.py elided in the source extraction; it resumes mid-comment below ...]
+  # [...] for use by hmmsearch.
+  pdb_seqres_database_path = os.path.join(
+      FLAGS.data_dir, 'pdb_seqres', 'pdb_seqres.txt')
+
+  # Path to a directory with template mmCIF structures, each named <pdb_id>.cif.
+  template_mmcif_dir = os.path.join(FLAGS.data_dir, 'pdb_mmcif', 'mmcif_files')
+
+  # Path to a file mapping obsolete PDB IDs to their replacements.
+  obsolete_pdbs_path = os.path.join(FLAGS.data_dir, 'pdb_mmcif', 'obsolete.dat')
+
+  alphafold_path = pathlib.Path(__file__).parent.parent
+  data_dir_path = pathlib.Path(FLAGS.data_dir)
+  if alphafold_path == data_dir_path or alphafold_path in data_dir_path.parents:
+    raise app.UsageError(
+        f'The download directory {FLAGS.data_dir} should not be a subdirectory '
+        f'in the AlphaFold repository directory. If it is, the Docker build is '
+        f'slow since the large databases are copied during the image creation.')
+
+  mounts = []
+  command_args = []
+
+  # Mount each fasta path as a unique target directory.
+  target_fasta_paths = []
+  for i, fasta_path in enumerate(FLAGS.fasta_paths):
+    mount, target_path = _create_mount(f'fasta_path_{i}', fasta_path)
+    mounts.append(mount)
+    target_fasta_paths.append(target_path)
+  command_args.append(f'--fasta_paths={",".join(target_fasta_paths)}')
+
+  database_paths = [
+      ('uniref90_database_path', uniref90_database_path),
+      ('mgnify_database_path', mgnify_database_path),
+      ('data_dir', FLAGS.data_dir),
+      ('template_mmcif_dir', template_mmcif_dir),
+      ('obsolete_pdbs_path', obsolete_pdbs_path),
+  ]
+
+  if FLAGS.model_preset == 'multimer':
+    database_paths.append(('uniprot_database_path', uniprot_database_path))
+    database_paths.append(('pdb_seqres_database_path',
+                           pdb_seqres_database_path))
+  else:
+    database_paths.append(('pdb70_database_path', pdb70_database_path))
+
+  if FLAGS.db_preset == 'reduced_dbs':
+    database_paths.append(('small_bfd_database_path', small_bfd_database_path))
+  else:
+    database_paths.extend([
+        ('uniclust30_database_path', uniclust30_database_path),
+        ('bfd_database_path', bfd_database_path),
+    ])
+  for name, path in database_paths:
+    if path:
+      mount, target_path = _create_mount(name, path)
+      mounts.append(mount)
+      command_args.append(f'--{name}={target_path}')
+
+  output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output')
+  mounts.append(types.Mount(output_target_path, FLAGS.output_dir, type='bind'))
+
+  command_args.extend([
+      f'--output_dir={output_target_path}',
+      f'--max_template_date={FLAGS.max_template_date}',
+      f'--db_preset={FLAGS.db_preset}',
+      f'--model_preset={FLAGS.model_preset}',
+      f'--benchmark={FLAGS.benchmark}',
+      f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
+      '--logtostderr',
+  ])
+
+  if FLAGS.is_prokaryote_list:
+    command_args.append(
+        f'--is_prokaryote_list={",".join(FLAGS.is_prokaryote_list)}')
+
+  client = docker.from_env()
+  container = client.containers.run(
+      image=FLAGS.docker_image_name,
+      command=command_args,
+      runtime='nvidia' if FLAGS.use_gpu else None,
+      remove=True,
+      detach=True,
+      mounts=mounts,
+      environment={
+          'NVIDIA_VISIBLE_DEVICES': FLAGS.gpu_devices,
+          # The following flags allow us to make predictions on proteins that
+          # would typically be too long to fit into GPU memory.
+          'TF_FORCE_UNIFIED_MEMORY': '1',
+          'XLA_PYTHON_CLIENT_MEM_FRACTION': '4.0',
+      })
+
+  # Add signal handler to ensure CTRL+C also stops the running container.
+  signal.signal(signal.SIGINT,
+                lambda unused_sig, unused_frame: container.kill())
+
+  for line in container.logs(stream=True):
+    logging.info(line.strip().decode('utf-8'))
+
+
+if __name__ == '__main__':
+  flags.mark_flags_as_required([
+      'data_dir',
+      'fasta_paths',
+      'max_template_date',
+  ])
+  app.run(main) |
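run_docker.py mounts each FASTA file and database directory read-only under /mnt inside the container, rewrites them into the corresponding --*_database_path flags for run_alphafold.py, and then streams the container logs; data_dir, fasta_paths and max_template_date are marked required. After installing its two dependencies (pip3 install -r docker/requirements.txt, per the requirements file above), a typical invocation would look something like python3 docker/run_docker.py --fasta_paths=/path/to/query.fasta --max_template_date=2022-01-01 --data_dir=/path/to/databases (paths hypothetical).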
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/imgs/casp14_predictions.gif |
Binary file docker/alphafold/imgs/casp14_predictions.gif has changed |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/imgs/header.jpg |
Binary file docker/alphafold/imgs/header.jpg has changed |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/notebooks/AlphaFold.ipynb --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/notebooks/AlphaFold.ipynb Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,795 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pc5-mbsX9PZC"
+ },
+ "source": [
+ "# AlphaFold Colab\n",
+ "\n",
+ "This Colab notebook allows you to easily predict the structure of a protein using a slightly simplified version of [AlphaFold v2.1.0](https://doi.org/10.1038/s41586-021-03819-2). \n",
+ "\n",
+ "**Differences to AlphaFold v2.1.0**\n",
+ "\n",
+ "In comparison to AlphaFold v2.1.0, this Colab notebook uses **no templates (homologous structures)** and a selected portion of the [BFD database](https://bfd.mmseqs.com/). We have validated these changes on several thousand recent PDB structures. While accuracy will be near-identical to the full AlphaFold system on many targets, a small fraction have a large drop in accuracy due to the smaller MSA and lack of templates. For best reliability, we recommend instead using the [full open source AlphaFold](https://github.com/deepmind/alphafold/), or the [AlphaFold Protein Structure Database](https://alphafold.ebi.ac.uk/).\n",
+ "\n",
+ "**This Colab has a small drop in average accuracy for multimers compared to a local AlphaFold installation; for full multimer accuracy it is highly recommended to run [AlphaFold locally](https://github.com/deepmind/alphafold#running-alphafold).** Moreover, AlphaFold-Multimer requires searching for an MSA for every unique sequence in the complex, hence it is substantially slower. If your notebook times out due to slow multimer MSA search, we recommend either using Colab Pro or running AlphaFold locally.\n",
+ "\n",
+ "Please note that this Colab notebook is provided as an early-access prototype and is not a finished product. It is provided for theoretical modelling only and caution should be exercised in its use. \n",
+ "\n",
+ "**Citing this work**\n",
+ "\n",
+ "Any publication that discloses findings arising from using this notebook should [cite](https://github.com/deepmind/alphafold/#citing-this-work) the [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2).\n",
+ "\n",
+ "**Licenses**\n",
+ "\n",
+ "This Colab uses the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license) and its outputs are thus for non-commercial use only, under the Creative Commons Attribution-NonCommercial 4.0 International ([CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n",
+ "\n",
+ "**More information**\n",
+ "\n",
+ "You can find more information about how AlphaFold works in the following papers:\n",
+ "\n",
+ "* [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2)\n",
+ "* [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1)\n",
+ "* [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1)\n",
+ "\n",
+ "FAQ on how to interpret AlphaFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "woIxeCPygt7K"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Install third-party software\n",
+ "\n",
+ "#@markdown Please execute this cell by pressing the _Play_ button \n",
+ "#@markdown on the left to download and import third-party software \n",
+ "#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in our readme.)\n",
+ "\n",
+ "#@markdown **Note**: This installs the software on the Co[...]
[... notebook cells elided in the source extraction; it resumes mid-string below ...]
+ "[...]ignment creation.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YfPhvYgKC81B"
+ },
+ "source": [
+ "# License and Disclaimer\n",
+ "\n",
+ "This is not an officially-supported Google product.\n",
+ "\n",
+ "This Colab notebook and other information provided is for theoretical modelling only, caution should be exercised in its use. It is provided ‘as-is’ without any warranty of any kind, whether expressed or implied. Information is not intended to be a substitute for professional medical advice, diagnosis, or treatment, and does not constitute medical or other professional advice.\n",
+ "\n",
+ "Copyright 2021 DeepMind Technologies Limited.\n",
+ "\n",
+ "\n",
+ "## AlphaFold Code License\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.\n",
+ "\n",
+ "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
+ "\n",
+ "## Model Parameters License\n",
+ "\n",
+ "The AlphaFold parameters are made available for non-commercial use only, under the terms of the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. You can find details at: https://creativecommons.org/licenses/by-nc/4.0/legalcode\n",
+ "\n",
+ "\n",
+ "## Third-party software\n",
+ "\n",
+ "Use of the third-party software, libraries or code referred to in the [Acknowledgements section](https://github.com/deepmind/alphafold/#acknowledgements) in the AlphaFold README may be governed by separate terms and conditions or license provisions. Your use of the third-party software, libraries or code is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use.\n",
+ "\n",
+ "\n",
+ "## Mirrored Databases\n",
+ "\n",
+ "The following databases have been mirrored by DeepMind, and are available with reference to the following:\n",
+ "* UniProt: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n",
+ "* UniRef90: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n",
+ "* MGnify: v2019\\_05 (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).\n",
+ "* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "AlphaFold.ipynb",
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/requirements.txt --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/requirements.txt Tue Mar 01 02:53:05 2022 +0000 |
@@ -0,0 +1,13 @@ +absl-py==0.13.0 +biopython==1.79 +chex==0.0.7 +dm-haiku==0.0.4 +dm-tree==0.1.6 +docker==5.0.0 +immutabledict==2.0.0 +jax==0.2.14 +ml-collections==0.1.0 +numpy==1.19.5 +pandas==1.3.4 +scipy==1.7.0 +tensorflow-cpu==2.5.0 |
b |
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/run_alphafold.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/run_alphafold.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,427 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Full AlphaFold protein structure prediction script."""
+import json
+import os
+import pathlib
+import pickle
+import random
+import shutil
+import sys
+import time
+from typing import Dict, Union, Optional
+
+from absl import app
+from absl import flags
+from absl import logging
+from alphafold.common import protein
+from alphafold.common import residue_constants
+from alphafold.data import pipeline
+from alphafold.data import pipeline_multimer
+from alphafold.data import templates
+from alphafold.data.tools import hhsearch
+from alphafold.data.tools import hmmsearch
+from alphafold.model import config
+from alphafold.model import model
+from alphafold.relax import relax
+import numpy as np
+
+from alphafold.model import data
+# Internal import (7716).
+
+logging.set_verbosity(logging.INFO)
+
+flags.DEFINE_list(
+    'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
+    'target that will be folded one after another. If a FASTA file contains '
+    'multiple sequences, then it will be folded as a multimer. Paths should be '
+    'separated by commas. All FASTA paths must have a unique basename as the '
+    'basename is used to name the output directories for each prediction.')
+flags.DEFINE_list(
+    'is_prokaryote_list', None, 'Optional for multimer system, not used by the '
+    'single chain system. This list should contain a boolean for each fasta '
+    'specifying true where the target complex is from a prokaryote, and false '
+    'where it is not, or where the origin is unknown. These values determine '
+    'the pairing method for the MSA.')
+
+flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
+flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
+                    'store the results.')
+flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'),
+                    'Path to the JackHMMER executable.')
+flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'),
+                    'Path to the HHblits executable.')
+flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'),
+                    'Path to the HHsearch executable.')
+flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'),
+                    'Path to the hmmsearch executable.')
+flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'),
+                    'Path to the hmmbuild executable.')
+flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'),
+                    'Path to the Kalign executable.')
+flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 '
+                    'database for use by JackHMMER.')
+flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify '
+                    'database for use by JackHMMER.')
+flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD '
+                    'database for use by HHblits.')
+flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small '
+                    'version of BFD used with the "reduced_dbs" preset.')
+flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 '
+                    'database for use by HHblits.')
+flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot '
+                    'database for use by JackHMMer.')
+flag[...]

[...]

+    template_searcher = hmmsearch.Hmmsearch(
+        binary_path=FLAGS.hmmsearch_binary_path,
+        hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
+        database_path=FLAGS.pdb_seqres_database_path)
+    template_featurizer = templates.HmmsearchHitFeaturizer(
+        mmcif_dir=FLAGS.template_mmcif_dir,
+        max_template_date=FLAGS.max_template_date,
+        max_hits=MAX_TEMPLATE_HITS,
+        kalign_binary_path=FLAGS.kalign_binary_path,
+        release_dates_path=None,
+        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
+  else:
+    template_searcher = hhsearch.HHSearch(
+        binary_path=FLAGS.hhsearch_binary_path,
+        databases=[FLAGS.pdb70_database_path])
+    template_featurizer = templates.HhsearchHitFeaturizer(
+        mmcif_dir=FLAGS.template_mmcif_dir,
+        max_template_date=FLAGS.max_template_date,
+        max_hits=MAX_TEMPLATE_HITS,
+        kalign_binary_path=FLAGS.kalign_binary_path,
+        release_dates_path=None,
+        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
+
+  monomer_data_pipeline = pipeline.DataPipeline(
+      jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
+      hhblits_binary_path=FLAGS.hhblits_binary_path,
+      uniref90_database_path=FLAGS.uniref90_database_path,
+      mgnify_database_path=FLAGS.mgnify_database_path,
+      bfd_database_path=FLAGS.bfd_database_path,
+      uniclust30_database_path=FLAGS.uniclust30_database_path,
+      small_bfd_database_path=FLAGS.small_bfd_database_path,
+      template_searcher=template_searcher,
+      template_featurizer=template_featurizer,
+      use_small_bfd=use_small_bfd,
+      use_precomputed_msas=FLAGS.use_precomputed_msas)
+
+  if run_multimer_system:
+    data_pipeline = pipeline_multimer.DataPipeline(
+        monomer_data_pipeline=monomer_data_pipeline,
+        jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
+        uniprot_database_path=FLAGS.uniprot_database_path,
+        use_precomputed_msas=FLAGS.use_precomputed_msas)
+  else:
+    data_pipeline = monomer_data_pipeline
+
+  model_runners = {}
+  model_names = config.MODEL_PRESETS[FLAGS.model_preset]
+  for model_name in model_names:
+    model_config = config.model_config(model_name)
+    if run_multimer_system:
+      model_config.model.num_ensemble_eval = num_ensemble
+    else:
+      model_config.data.eval.num_ensemble = num_ensemble
+    model_params = data.get_model_haiku_params(
+        model_name=model_name, data_dir=FLAGS.data_dir)
+    model_runner = model.RunModel(model_config, model_params)
+    model_runners[model_name] = model_runner
+
+  logging.info('Have %d models: %s', len(model_runners),
+               list(model_runners.keys()))
+
+  amber_relaxer = relax.AmberRelaxation(
+      max_iterations=RELAX_MAX_ITERATIONS,
+      tolerance=RELAX_ENERGY_TOLERANCE,
+      stiffness=RELAX_STIFFNESS,
+      exclude_residues=RELAX_EXCLUDE_RESIDUES,
+      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)
+
+  random_seed = FLAGS.random_seed
+  if random_seed is None:
+    random_seed = random.randrange(sys.maxsize // len(model_names))
+  logging.info('Using random seed %d for the data pipeline', random_seed)
+
+  # Predict structure for each of the sequences.
+  for i, fasta_path in enumerate(FLAGS.fasta_paths):
+    is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None
+    fasta_name = fasta_names[i]
+    predict_structure(
+        fasta_path=fasta_path,
+        fasta_name=fasta_name,
+        output_dir_base=FLAGS.output_dir,
+        data_pipeline=data_pipeline,
+        model_runners=model_runners,
+        amber_relaxer=amber_relaxer,
+        benchmark=FLAGS.benchmark,
+        random_seed=random_seed,
+        is_prokaryote=is_prokaryote)
+
+
+if __name__ == '__main__':
+  flags.mark_flags_as_required([
+      'fasta_paths',
+      'output_dir',
+      'data_dir',
+      'uniref90_database_path',
+      'mgnify_database_path',
+      'template_mmcif_dir',
+      'max_template_date',
+      'obsolete_pdbs_path',
+  ])
+
+  app.run(main)
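The script's `__main__` block above marks eight flags as required. As a rough sketch of what a minimal invocation looks like, the following Python wrapper passes only those required flags; every path is a hypothetical placeholder, with database file names taken from the download scripts later in this changeset:

# Minimal sketch of invoking run_alphafold.py with only its required flags.
# All paths are hypothetical placeholders; adjust to your own download layout.
import subprocess

cmd = [
    "python", "docker/alphafold/run_alphafold.py",
    "--fasta_paths=/data/targets/query.fasta",
    "--output_dir=/data/output",
    "--data_dir=/data/alphafold_databases",
    "--uniref90_database_path=/data/alphafold_databases/uniref90/uniref90.fasta",
    "--mgnify_database_path=/data/alphafold_databases/mgnify/mgy_clusters_2018_12.fa",
    "--template_mmcif_dir=/data/alphafold_databases/pdb_mmcif/mmcif_files",
    "--max_template_date=2022-03-01",
    "--obsolete_pdbs_path=/data/alphafold_databases/pdb_mmcif/obsolete.dat",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError on failure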
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/run_alphafold_test.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/run_alphafold_test.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,101 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for run_alphafold."""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import run_alphafold
+import mock
+import numpy as np
+# Internal import (7716).
+
+
+class RunAlphafoldTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ('relax', True),
+      ('no_relax', False),
+  )
+  def test_end_to_end(self, do_relax):
+
+    data_pipeline_mock = mock.Mock()
+    model_runner_mock = mock.Mock()
+    amber_relaxer_mock = mock.Mock()
+
+    data_pipeline_mock.process.return_value = {}
+    model_runner_mock.process_features.return_value = {
+        'aatype': np.zeros((12, 10), dtype=np.int32),
+        'residue_index': np.tile(np.arange(10, dtype=np.int32)[None], (12, 1)),
+    }
+    model_runner_mock.predict.return_value = {
+        'structure_module': {
+            'final_atom_positions': np.zeros((10, 37, 3)),
+            'final_atom_mask': np.ones((10, 37)),
+        },
+        'predicted_lddt': {
+            'logits': np.ones((10, 50)),
+        },
+        'plddt': np.ones(10) * 42,
+        'ranking_confidence': 90,
+        'ptm': np.array(0.),
+        'aligned_confidence_probs': np.zeros((10, 10, 50)),
+        'predicted_aligned_error': np.zeros((10, 10)),
+        'max_predicted_aligned_error': np.array(0.),
+    }
+    model_runner_mock.multimer_mode = False
+    amber_relaxer_mock.process.return_value = ('RELAXED', None, None)
+
+    fasta_path = os.path.join(absltest.get_default_test_tmpdir(),
+                              'target.fasta')
+    with open(fasta_path, 'wt') as f:
+      f.write('>A\nAAAAAAAAAAAAA')
+    fasta_name = 'test'
+
+    out_dir = absltest.get_default_test_tmpdir()
+
+    run_alphafold.predict_structure(
+        fasta_path=fasta_path,
+        fasta_name=fasta_name,
+        output_dir_base=out_dir,
+        data_pipeline=data_pipeline_mock,
+        model_runners={'model1': model_runner_mock},
+        amber_relaxer=amber_relaxer_mock if do_relax else None,
+        benchmark=False,
+        random_seed=0)
+
+    base_output_files = os.listdir(out_dir)
+    self.assertIn('target.fasta', base_output_files)
+    self.assertIn('test', base_output_files)
+
+    target_output_files = os.listdir(os.path.join(out_dir, 'test'))
+    expected_files = [
+        'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
+        'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
+    ]
+    if do_relax:
+      expected_files.append('relaxed_model1.pdb')
+    self.assertCountEqual(expected_files, target_output_files)
+
+    # Check that pLDDT is set in the B-factor column.
+    with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f:
+      for line in f:
+        if line.startswith('ATOM'):
+          self.assertEqual(line[61:66], '42.00')
+
+
+if __name__ == '__main__':
+  absltest.main()
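The test above relies on the convention that per-atom pLDDT values land in the fixed-width B-factor field of the output PDB (its assertion reads `line[61:66]`). A small sketch of reading those values back out, with the output path as a hypothetical example:

# Sketch: recover per-atom pLDDT from the B-factor column of an unrelaxed
# model. The PDB B-factor field occupies columns 61-66 (1-based), i.e. the
# Python slice line[60:66]. The file path is a hypothetical placeholder.
plddts = []
with open("output/test/unrelaxed_model1.pdb") as f:
    for line in f:
        if line.startswith("ATOM"):
            plddts.append(float(line[60:66]))
print(f"mean pLDDT over {len(plddts)} atoms: {sum(plddts) / len(plddts):.2f}")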
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_all_data.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_all_data.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,74 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips all required data for AlphaFold.
+#
+# Usage: bash download_all_data.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+DOWNLOAD_MODE="${2:-full_dbs}"  # Default mode to full_dbs.
+if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]]
+then
+  echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized."
+  exit 1
+fi
+
+SCRIPT_DIR="$(dirname "$(realpath "$0")")"
+
+echo "Downloading AlphaFold parameters..."
+bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"
+
+if [[ "${DOWNLOAD_MODE}" = reduced_dbs ]] ; then
+  echo "Downloading Small BFD..."
+  bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}"
+else
+  echo "Downloading BFD..."
+  bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
+fi
+
+echo "Downloading MGnify..."
+bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading PDB70..."
+bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading PDB mmCIF files..."
+bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading Uniclust30..."
+bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading Uniref90..."
+bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading UniProt..."
+bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
+
+echo "Downloading PDB SeqRes..."
+bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
+
+echo "All data downloaded."
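Each helper script creates one subdirectory under the download directory (the `ROOT_DIR` assignments below). A sketch of verifying a finished download tree against that layout; the download directory itself is a hypothetical placeholder:

# Sketch: check a completed AlphaFold data tree. Subdirectory names are taken
# from the ROOT_DIR values in the download_*.sh scripts in this changeset.
import os

download_dir = "/data/alphafold_databases"  # hypothetical placeholder
mode = "full_dbs"                           # or "reduced_dbs"

expected = ["params", "mgnify", "pdb70", "pdb_mmcif", "uniclust30",
            "uniref90", "uniprot", "pdb_seqres"]
expected.append("small_bfd" if mode == "reduced_dbs" else "bfd")

missing = [d for d in expected
           if not os.path.isdir(os.path.join(download_dir, d))]
print("missing:", missing or "none")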
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_alphafold_params.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_alphafold_params.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,41 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the AlphaFold parameters.
+#
+# Usage: bash download_alphafold_params.sh /path/to/download/directory
+set -e

+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/params"
+SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
+  --directory="${ROOT_DIR}" --preserve-permissions
+rm "${ROOT_DIR}/${BASENAME}"
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_bfd.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_bfd.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,43 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the BFD database for AlphaFold.
+#
+# Usage: bash download_bfd.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/bfd"
+# Mirror of:
+# https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz.
+SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
+  --directory="${ROOT_DIR}"
+rm "${ROOT_DIR}/${BASENAME}"
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_mgnify.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_mgnify.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,43 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the MGnify database for AlphaFold.
+#
+# Usage: bash download_mgnify.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
+# Mirror of:
+# ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
+SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+pushd "${ROOT_DIR}"
+gunzip "${ROOT_DIR}/${BASENAME}"
+popd
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_pdb70.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_pdb70.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,41 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the PDB70 database for AlphaFold.
+#
+# Usage: bash download_pdb70.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/pdb70"
+SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
+  --directory="${ROOT_DIR}"
+rm "${ROOT_DIR}/${BASENAME}"
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_pdb_mmcif.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_pdb_mmcif.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,61 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads, unzips and flattens the PDB database for AlphaFold.
+#
+# Usage: bash download_pdb_mmcif.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+if ! command -v rsync &> /dev/null ; then
+  echo "Error: rsync could not be found. Please install rsync."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/pdb_mmcif"
+RAW_DIR="${ROOT_DIR}/raw"
+MMCIF_DIR="${ROOT_DIR}/mmcif_files"
+
+echo "Running rsync to fetch all mmCIF files (note that the rsync progress estimate might be inaccurate)..."
+mkdir --parents "${RAW_DIR}"
+rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 \
+  rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ \
+  "${RAW_DIR}"
+
+echo "Unzipping all mmCIF files..."
+find "${RAW_DIR}/" -type f -iname "*.gz" -exec gunzip {} +
+
+echo "Flattening all mmCIF files..."
+mkdir --parents "${MMCIF_DIR}"
+find "${RAW_DIR}" -type d -empty -delete  # Delete empty directories.
+for subdir in "${RAW_DIR}"/*; do
+  mv "${subdir}/"*.cif "${MMCIF_DIR}"
+done
+
+# Delete empty download directory structure.
+find "${RAW_DIR}" -type d -empty -delete
+
+aria2c "ftp://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat" --dir="${ROOT_DIR}"
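The "flatten" step above moves every .cif file out of the two-character hash subdirectories produced by the rsync mirror into a single flat mmcif_files/ directory, which is the layout run_alphafold.py expects for --template_mmcif_dir. The same step as a Python sketch, with placeholder paths:

# Sketch: flatten rsync's divided mmCIF layout (raw/<xx>/<id>.cif) into one
# directory, mirroring the shell loop above. Paths are hypothetical.
import pathlib
import shutil

raw_dir = pathlib.Path("/data/alphafold_databases/pdb_mmcif/raw")
mmcif_dir = pathlib.Path("/data/alphafold_databases/pdb_mmcif/mmcif_files")
mmcif_dir.mkdir(parents=True, exist_ok=True)

for cif in raw_dir.glob("*/*.cif"):
    shutil.move(str(cif), str(mmcif_dir / cif.name))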
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_pdb_seqres.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_pdb_seqres.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,38 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the PDB SeqRes database for AlphaFold.
+#
+# Usage: bash download_pdb_seqres.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
+SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_small_bfd.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_small_bfd.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,41 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the Small BFD database for AlphaFold.
+#
+# Usage: bash download_small_bfd.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/small_bfd"
+SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+pushd "${ROOT_DIR}"
+gunzip "${ROOT_DIR}/${BASENAME}"
+popd
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_uniclust30.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_uniclust30.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,43 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the Uniclust30 database for AlphaFold.
+#
+# Usage: bash download_uniclust30.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/uniclust30"
+# Mirror of:
+# http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/uniclust30_2018_08_hhsuite.tar.gz
+SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/uniclust30_2018_08_hhsuite.tar.gz"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
+  --directory="${ROOT_DIR}"
+rm "${ROOT_DIR}/${BASENAME}"
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_uniprot.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_uniprot.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,55 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads, unzips and merges the SwissProt and TrEMBL databases for
+# AlphaFold-Multimer.
+#
+# Usage: bash download_uniprot.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/uniprot"
+
+TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
+TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
+TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"
+
+SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
+SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
+SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
+aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"
+pushd "${ROOT_DIR}"
+gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
+gunzip "${ROOT_DIR}/${SPROT_BASENAME}"
+
+# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
+cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
+mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
+rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
+popd
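The merge step above appends SwissProt onto the much larger TrEMBL file and renames the result to uniprot.fasta. The same operation as a Python sketch, streaming in chunks rather than loading either file whole; the root path is a hypothetical placeholder:

# Sketch: merge TrEMBL + SwissProt into uniprot.fasta, mirroring the
# cat/mv/rm sequence in the script above.
import os
import shutil

root = "/data/alphafold_databases/uniprot"  # hypothetical placeholder
trembl = os.path.join(root, "uniprot_trembl.fasta")
sprot = os.path.join(root, "uniprot_sprot.fasta")

with open(trembl, "ab") as out, open(sprot, "rb") as src:
    shutil.copyfileobj(src, out)   # append SwissProt onto TrEMBL
os.rename(trembl, os.path.join(root, "uniprot.fasta"))
os.remove(sprot)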
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/scripts/download_uniref90.sh
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/scripts/download_uniref90.sh	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,41 @@
+#!/bin/bash
+#
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Downloads and unzips the UniRef90 database for AlphaFold.
+#
+# Usage: bash download_uniref90.sh /path/to/download/directory
+set -e
+
+if [[ $# -eq 0 ]]; then
+  echo "Error: download directory must be provided as an input argument."
+  exit 1
+fi
+
+if ! command -v aria2c &> /dev/null ; then
+  echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
+  exit 1
+fi
+
+DOWNLOAD_DIR="$1"
+ROOT_DIR="${DOWNLOAD_DIR}/uniref90"
+SOURCE_URL="ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz"
+BASENAME=$(basename "${SOURCE_URL}")
+
+mkdir --parents "${ROOT_DIR}"
+aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
+pushd "${ROOT_DIR}"
+gunzip "${ROOT_DIR}/${BASENAME}"
+popd
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/alphafold/setup.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/setup.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,58 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Install script for setuptools."""
+
+from setuptools import find_packages
+from setuptools import setup
+
+setup(
+    name='alphafold',
+    version='2.1.0',
+    description='An implementation of the inference pipeline of AlphaFold v2.0.'
+    'This is a completely new model that was entered as AlphaFold2 in CASP14 '
+    'and published in Nature.',
+    author='DeepMind',
+    author_email='alphafold@deepmind.com',
+    license='Apache License, Version 2.0',
+    url='https://github.com/deepmind/alphafold',
+    packages=find_packages(),
+    install_requires=[
+        'absl-py',
+        'biopython',
+        'chex',
+        'dm-haiku',
+        'dm-tree',
+        'docker',
+        'immutabledict',
+        'jax',
+        'ml-collections',
+        'numpy',
+        'pandas',
+        'scipy',
+        'tensorflow-cpu',
+    ],
+    tests_require=['mock'],
+    classifiers=[
+        'Development Status :: 5 - Production/Stable',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: Apache Software License',
+        'Operating System :: POSIX :: Linux',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+    ],
+)
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/claremcwhite/Dockerfile
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/claremcwhite/Dockerfile	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,87 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARG CUDA_FULL=11.2.2
+FROM nvidia/cuda:${CUDA_FULL}-cudnn8-runtime-ubuntu20.04
+# FROM directive resets ARGS, so we specify again (the value is retained if
+# previously set).
+ARG CUDA_FULL
+ARG CUDA=11.2
+# JAXLIB no longer built for all minor CUDA versions:
+# https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-0166-may-11-2021
+ARG CUDA_JAXLIB=11.1
+
+# Use bash to support string substitution.
+SHELL ["/bin/bash", "-c"]
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+      build-essential \
+      cmake \
+      cuda-command-line-tools-${CUDA/./-} \
+      git \
+      hmmer \
+      kalign \
+      tzdata \
+      wget \
+    && rm -rf /var/lib/apt/lists/*
+
+# Compile HHsuite from source.
+RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \
+    && mkdir /tmp/hh-suite/build
+WORKDIR /tmp/hh-suite/build
+RUN cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \
+    && make -j 4 && make install \
+    && ln -s /opt/hhsuite/bin/* /usr/bin \
+    && rm -rf /tmp/hh-suite
+
+# Install Miniconda package manager.
+RUN wget -q -P /tmp \
+      https://repo.anaconda.com/miniconda/Miniconda3-py38_4.9.2-Linux-x86_64.sh \
+    && bash /tmp/Miniconda3-py38_4.9.2-Linux-x86_64.sh -b -p /opt/conda \
+    && rm /tmp/Miniconda3-py38_4.9.2-Linux-x86_64.sh
+
+# Install conda packages.
+ENV PATH="/opt/conda/bin:$PATH"
+RUN conda update -qy conda \
+    && conda install -y -c conda-forge \
+      openmm=7.5.1 \
+      cudatoolkit==${CUDA_FULL} \
+      pdbfixer \
+      pip
+
+COPY . /app/alphafold
+RUN wget -q -P /app/alphafold/alphafold/common/ \
+      https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
+
+# Install pip packages.
+RUN pip3 install --upgrade pip \
+    && pip3 install -r /app/alphafold/requirements.txt \
+    && pip3 install --upgrade jax jaxlib==0.1.69+cuda${CUDA_JAXLIB/./} -f \
+      https://storage.googleapis.com/jax-releases/jax_releases.html
+
+# Apply OpenMM patch.
+WORKDIR /opt/conda/lib/python3.8/site-packages
+RUN patch -p0 < /app/alphafold/docker/openmm.patch
+
+# We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk
+# with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for
+# details.
+# ENTRYPOINT does not support easily running multiple commands, so instead we
+# write a shell script to wrap them up.
+WORKDIR /app/alphafold
+RUN echo $'#!/bin/bash\n\
+ldconfig\n\
+python /app/alphafold/run_alphafold.py "$@"' > /app/run_alphafold.sh \
+    && chmod +x /app/run_alphafold.sh
+ENTRYPOINT ["/app/run_alphafold.sh"]
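Since the image's ENTRYPOINT forwards its arguments straight to run_alphafold.py, it can be built and run with the docker-py package already pinned in requirements.txt. A rough sketch, assuming the usual Docker daemon setup and the NVIDIA container runtime on the host; the image tag and mount paths are hypothetical:

# Sketch: build and run this Dockerfile with docker-py. The DeviceRequest
# mirrors `docker run --gpus all`; tag and paths are placeholders.
import docker

client = docker.from_env()
image, _ = client.images.build(
    path=".", dockerfile="docker/claremcwhite/Dockerfile", tag="alphafold:local")
logs = client.containers.run(
    image,
    ["--fasta_paths=/data/query.fasta", "--output_dir=/data/output"],
    volumes={"/data": {"bind": "/data", "mode": "rw"}},
    device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])],
)
print(logs.decode())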
diff -r 7ae9d78b06f5 -r 6c92e000d684 docker/claremcwhite/README.md
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/claremcwhite/README.md	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,1 @@
+Cam: I think this is the source of the claremcwhite alphafold container but not sure
diff -r 7ae9d78b06f5 -r 6c92e000d684 validate_fasta.py
--- a/validate_fasta.py	Fri Jan 28 04:56:29 2022 +0000
+++ b/validate_fasta.py	Tue Mar 01 02:53:05 2022 +0000
@@ -1,5 +1,6 @@
+"""Validate input FASTA sequence."""
-
+import re
 import argparse
 from typing import List, TextIO
@@ -11,51 +12,67 @@
 class FastaLoader:
-    def __init__(self):
-        """creates a Fasta() from a file"""
-        self.fastas: List[Fasta] = []
+    def __init__(self, fasta_path: str):
+        """Initialize from FASTA file."""
+        self.fastas = []
+        self.load(fasta_path)
+        print("Loaded FASTA sequences:")
+        for f in self.fastas:
+            print(f.header)
+            print(f.aa_seq)

     def load(self, fasta_path: str):
-        """
-        load function has to be very flexible.
-        file may be normal fasta format (header, seq) or can just be a bare sequence.
-        """
-        with open(fasta_path, 'r') as fp:
-            header, sequence = self.interpret_first_line(fp)
-            line = fp.readline().rstrip('\n')
-
-            while line:
-                if line.startswith('>'):
-                    self.update_fastas(header, sequence)
-                    header = line
-                    sequence = ''
-                else:
-                    sequence += line
-                line = fp.readline().rstrip('\n')
+        """Load bare or FASTA formatted sequence."""
+        with open(fasta_path, 'r') as f:
+            self.content = f.read()
+
+        if "__cn__" in self.content:
+            # Pasted content with escaped characters
+            self.newline = '__cn__'
+            self.caret = '__gt__'
+        else:
+            # Uploaded file with normal content
+            self.newline = '\n'
+            self.caret = '>'
+
+        self.lines = self.content.split(self.newline)
+        header, sequence = self.interpret_first_line()
+
+        i = 0
+        while i < len(self.lines):
+            line = self.lines[i]
+            if line.startswith(self.caret):
+                self.update_fastas(header, sequence)
+                header = '>' + self.strip_header(line)
+                sequence = ''
+            else:
+                sequence += line.strip('\n ')
+            i += 1

         # after reading whole file, header & sequence buffers might be full
         self.update_fastas(header, sequence)
-        return self.fastas

-    def interpret_first_line(self, fp: TextIO):
-        header = ''
-        sequence = ''
-        line = fp.readline().rstrip('\n')
-        if line.startswith('>'):
-            header = line
+    def interpret_first_line(self):
+        line = self.lines[0]
+        if line.startswith(self.caret):
+            header = '>' + self.strip_header(line)
+            return header, ''
         else:
-            sequence += line
-        return header, sequence
-
+            return '', line
+
+    def strip_header(self, line):
+        """Strip characters escaped with underscores from pasted text."""
+        return re.sub(r'\_\_.{2}\_\_', '', line).strip('>')
+
     def update_fastas(self, header: str, sequence: str):
         # if we have a sequence
-        if not sequence == '':
+        if sequence:
             # create generic header if not exists
-            if header == '':
+            if not header:
                 fasta_count = len(self.fastas)
                 header = f'>sequence_{fasta_count}'
-            # create new Fasta
+            # Create new Fasta
             self.fastas.append(Fasta(header, sequence))
@@ -65,9 +82,9 @@
         self.min_length = 30
         self.max_length = 2000
         self.iupac_characters = {
-            'A', 'B', 'C', 'D', 'E', 'F', 'G',
-            'H', 'I', 'K', 'L', 'M', 'N', 'P',
-            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
+            'A', 'B', 'C', 'D', 'E', 'F', 'G',
+            'H', 'I', 'K', 'L', 'M', 'N', 'P',
+            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
             'Y', 'Z', '-'
         }
@@ -76,9 +93,9 @@
         self.validate_num_seqs()
         self.validate_length()
         self.validate_alphabet()
-        # not checking for 'X' nucleotides at the moment.
-        # alphafold can throw an error if it doesn't like it.
-        #self.validate_x()
+        # not checking for 'X' nucleotides at the moment.
+        # alphafold can throw an error if it doesn't like it.
+        #self.validate_x()

     def validate_num_seqs(self) -> None:
         if len(self.fasta_list) > 1:
@@ -93,19 +110,19 @@
             raise Exception(f'Error encountered validating fasta: Sequence too short ({len(fasta.aa_seq)}aa). Must be > 30aa')
         if len(fasta.aa_seq) > self.max_length:
             raise Exception(f'Error encountered validating fasta: Sequence too long ({len(fasta.aa_seq)}aa). Must be < 2000aa')
-
+
     def validate_alphabet(self):
         """
-        Confirms whether the sequence conforms to IUPAC codes.
-        If not, reports the offending character and its position.
-        """
+        Confirms whether the sequence conforms to IUPAC codes.
+        If not, reports the offending character and its position.
+        """
         fasta = self.fasta_list[0]
         for i, char in enumerate(fasta.aa_seq.upper()):
             if char not in self.iupac_characters:
-                raise Exception(f'Error encountered validating fasta: Invalid amino acid found at pos {i}: {char}')
+                raise Exception(f'Error encountered validating fasta: Invalid amino acid found at pos {i}: "{char}"')

     def validate_x(self):
-        """checks if any bases are X. TODO check whether alphafold accepts X bases. """
+        """checks if any bases are X. TODO check whether alphafold accepts X bases. """
         fasta = self.fasta_list[0]
         for i, char in enumerate(fasta.aa_seq.upper()):
             if char == 'X':
@@ -134,28 +151,27 @@
 def main():
     # load fasta file
     args = parse_args()
-    fl = FastaLoader()
-    fastas = fl.load(args.input_fasta)
+    fas = FastaLoader(args.input_fasta)

     # validate
-    fv = FastaValidator(fastas)
+    fv = FastaValidator(fas.fastas)
     fv.validate()

     # write cleaned version
     fw = FastaWriter()
-    fw.write(fastas[0])
+    fw.write(fas.fastas[0])

-
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "input_fasta",
-        help="input fasta file",
+        "input_fasta",
+        help="input fasta file",
         type=str
-    )
+    )
     return parser.parse_args()


 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
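The key behavioural change in this diff is that FastaLoader now handles text pasted into Galaxy, where newlines arrive escaped as "__cn__" and ">" as "__gt__", and strip_header() removes any such "__xx__" escape via re.sub(r'\_\_.{2}\_\_', '', line). A minimal sketch of that path, using a made-up pasted input:

# Sketch: what the new loader logic does with Galaxy-escaped pasted text.
import re

pasted = "__gt__my_protein__cn__MKTAYIAKQR__cn__QISFVKSHFS"
newline, caret = "__cn__", "__gt__"

lines = pasted.split(newline)
# First line starts with the escaped caret, so it is a header; strip escapes.
header = ">" + re.sub(r"\_\_.{2}\_\_", "", lines[0]).strip(">")
sequence = "".join(lines[1:])
print(header)    # >my_protein
print(sequence)  # MKTAYIAKQRQISFVKSHFS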