Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/alphafold/common/protein.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Protein data type.""" | |
16 import dataclasses | |
17 import io | |
18 from typing import Any, Mapping, Optional | |
19 from alphafold.common import residue_constants | |
20 from Bio.PDB import PDBParser | |
21 import numpy as np | |
22 | |
# Type aliases for model inputs and (nested-dict) model outputs.
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Every chain ID the PDB format can express, in the order in which chain
# indices are assigned to them: upper case, then lower case, then digits.
PDB_CHAIN_IDS = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
    '0123456789')
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.
29 | |
30 | |
@dataclasses.dataclass(frozen=True)
class Protein:
  """Immutable bundle of per-residue arrays describing a protein structure."""

  # Atom coordinates in angstroms. The atom-type axis is ordered like
  # residue_constants.atom_types.
  atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

  # Integer amino-acid type of every residue, in the range 0..20, with 20
  # standing for 'X'.
  aatype: np.ndarray  # [num_res]

  # Float mask flagging which atoms exist: 1.0 for a present atom, 0.0 for an
  # absent one. Intended for loss masking.
  atom_mask: np.ndarray  # [num_res, num_atom_type]

  # Residue number as it appears in the PDB; may have gaps and need not start
  # at zero.
  residue_index: np.ndarray  # [num_res]

  # 0-based index of the chain each residue belongs to.
  chain_index: np.ndarray  # [num_res]

  # B-factors (temperature factors) in square angstroms, describing the
  # displacement from the ground truth mean position.
  b_factors: np.ndarray  # [num_res, num_atom_type]

  def __post_init__(self):
    """Rejects proteins holding more distinct chains than PDB can name."""
    num_chains = np.unique(self.chain_index).size
    if num_chains > PDB_MAX_CHAINS:
      raise ValueError(
          f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
          'because these cannot be written to PDB format.')
64 | |
65 | |
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
  """Builds a `Protein` from the text of a PDB file.

  WARNING: Residues whose name is not found in
  `residue_constants.restype_3to1` are stored as the unknown type, and atoms
  whose name is not in `residue_constants.atom_types` are dropped. Residues
  left without a single recognised atom are skipped entirely.

  Args:
    pdb_str: Contents of the PDB file.
    chain_id: When given (e.g. A), only the chain with that ID is read.
      Otherwise every chain is read.

  Returns:
    The `Protein` parsed from `pdb_str`.

  Raises:
    ValueError: If the PDB does not hold exactly one model, or a residue of a
      parsed chain carries an insertion code.
  """
  structure = PDBParser(QUIET=True).get_structure('none', io.StringIO(pdb_str))
  all_models = list(structure.get_models())
  num_models = len(all_models)
  if num_models != 1:
    raise ValueError(
        f'Only single model PDBs are supported. Found {num_models} models.')

  num_atom_types = residue_constants.atom_type_num

  # Parallel per-residue accumulators, one entry per residue that is kept.
  kept_aatype = []
  kept_positions = []
  kept_mask = []
  kept_res_number = []
  kept_chain_id = []
  kept_b_factors = []

  for chain in all_models[0]:
    wanted = chain_id is None or chain.id == chain_id
    if not wanted:
      continue
    for res in chain:
      # res.id[2] is the insertion code; a single blank means "none".
      if res.id[2] != ' ':
        raise ValueError(
            f'PDB contains an insertion code at chain {chain.id} and residue '
            f'index {res.id[1]}. These are not supported.')

      one_letter = residue_constants.restype_3to1.get(res.resname, 'X')
      restype_idx = residue_constants.restype_order.get(
          one_letter, residue_constants.restype_num)

      coords = np.zeros((num_atom_types, 3))
      present = np.zeros((num_atom_types,))
      temps = np.zeros((num_atom_types,))
      for atom in res:
        if atom.name in residue_constants.atom_types:
          slot = residue_constants.atom_order[atom.name]
          coords[slot] = atom.coord
          present[slot] = 1.
          temps[slot] = atom.bfactor

      # Mask entries are exactly 0 or 1, so "no entry set" means the residue
      # reported no known atom and is skipped.
      if not present.any():
        continue

      kept_aatype.append(restype_idx)
      kept_positions.append(coords)
      kept_mask.append(present)
      kept_res_number.append(res.id[1])
      kept_chain_id.append(chain.id)
      kept_b_factors.append(temps)

  # Chain IDs are usually characters; number them by their rank in the sorted
  # set of IDs that were actually seen.
  id_to_index = {
      cid: rank for rank, cid in enumerate(np.unique(kept_chain_id))}
  chain_index = np.array([id_to_index[cid] for cid in kept_chain_id])

  return Protein(
      atom_positions=np.array(kept_positions),
      atom_mask=np.array(kept_mask),
      aatype=np.array(kept_aatype),
      residue_index=np.array(kept_res_number),
      chain_index=chain_index,
      b_factors=np.array(kept_b_factors))
138 | |
139 | |
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
  """Formats the TER record that terminates a chain.

  The record follows the fixed-column PDB layout: record name in columns 1-6,
  serial number in 7-11, columns 12-17 blank, residue name in 18-20, column 21
  blank, chain ID in 22 and residue sequence number in 23-26.

  Args:
    atom_index: Serial number to give the TER record.
    end_resname: Three-letter name of the last residue of the chain.
    chain_name: One-character chain ID.
    residue_index: Residue sequence number of the last residue of the chain.

  Returns:
    The TER line, without padding to 80 characters and without a newline.
  """
  record_name = 'TER'
  # Built explicitly rather than typed as literal blanks so the six-column gap
  # (columns 12-17) cannot be lost to whitespace normalisation.
  unused_cols = ' ' * 6
  return (f'{record_name:<6}{atom_index:>5}{unused_cols}{end_resname:>3} '
          f'{chain_name:>1}{residue_index:>4}')
144 | |
145 | |
def to_pdb(prot: Protein) -> str:
  """Converts a `Protein` instance to a PDB string.

  The output holds a single MODEL with one ATOM record per atom whose mask is
  at least 0.5, a TER record after every chain, then ENDMDL and END. Every
  line follows the fixed-column PDB layout and is padded to 80 characters. A
  protein with zero residues yields a model with no ATOM or TER records.

  Args:
    prot: The protein to convert to PDB.

  Returns:
    PDB string.

  Raises:
    ValueError: If `prot.aatype` holds a value above
      `residue_constants.restype_num`, or a chain index is not below
      `PDB_MAX_CHAINS`.
  """
  restypes = residue_constants.restypes + ['X']
  res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
  atom_types = residue_constants.atom_types

  pdb_lines = []

  atom_mask = prot.atom_mask
  aatype = prot.aatype
  atom_positions = prot.atom_positions
  residue_index = prot.residue_index.astype(np.int32)
  chain_index = prot.chain_index.astype(np.int32)
  b_factors = prot.b_factors

  if np.any(aatype > residue_constants.restype_num):
    raise ValueError('Invalid aatypes.')

  # Construct a mapping from chain integer indices to chain ID strings.
  chain_ids = {}
  for i in np.unique(chain_index):  # np.unique gives sorted output.
    if i >= PDB_MAX_CHAINS:
      raise ValueError(
          f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
    chain_ids[i] = PDB_CHAIN_IDS[i]

  # PDB is a columnar format, so the blank runs are built explicitly instead
  # of being typed as literal spaces that are easy to lose.
  coord_gap = ' ' * 3  # Columns 28-30, between insertion code and x.
  element_gap = ' ' * 10  # Columns 67-76, between B-factor and element.

  # The model serial number has to sit inside columns 11-14.
  pdb_lines.append('MODEL'.ljust(10) + '1')
  atom_index = 1
  num_res = aatype.shape[0]
  # An empty protein has no first chain to start from.
  last_chain_index = chain_index[0] if num_res else None
  # Add all atom sites.
  for i in range(num_res):
    # Close the previous chain if in a multichain PDB.
    if last_chain_index != chain_index[i]:
      pdb_lines.append(_chain_end(
          atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
          residue_index[i - 1]))
      last_chain_index = chain_index[i]
      atom_index += 1  # Atom index increases at the TER symbol.

    res_name_3 = res_1to3(aatype[i])
    for atom_name, pos, mask, b_factor in zip(
        atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
      if mask < 0.5:
        continue

      record_type = 'ATOM'
      # Names shorter than four characters start in the second name column.
      name = atom_name if len(atom_name) == 4 else f' {atom_name}'
      alt_loc = ''
      insertion_code = ''
      occupancy = 1.00
      element = atom_name[0]  # Protein supports only C, N, O, S, this works.
      charge = ''
      atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                   f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                   f'{residue_index[i]:>4}{insertion_code:>1}{coord_gap}'
                   f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                   f'{occupancy:>6.2f}{b_factor:>6.2f}{element_gap}'
                   f'{element:>2}{charge:>2}')
      pdb_lines.append(atom_line)
      atom_index += 1

  # Close the final chain, if there is one.
  if num_res:
    pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
                                chain_ids[chain_index[-1]], residue_index[-1]))
  pdb_lines.append('ENDMDL')
  pdb_lines.append('END')

  # Pad all lines to 80 characters.
  pdb_lines = [line.ljust(80) for line in pdb_lines]
  return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
224 | |
225 | |
def ideal_atom_mask(prot: Protein) -> np.ndarray:
  """Looks up the atom mask implied by the amino-acid sequence alone.

  `Protein.atom_mask` usually records which atoms were reported in the PDB.
  This function instead returns, per residue, the row of
  `residue_constants.STANDARD_ATOM_MASK` selected by that residue's `aatype`,
  i.e. the heavy atoms that should be present for that amino acid.

  Args:
    prot: `Protein` whose fields are `numpy.ndarray` objects.

  Returns:
    An ideal atom mask.
  """
  standard_mask = residue_constants.STANDARD_ATOM_MASK
  return standard_mask[prot.aatype]
240 | |
241 | |
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True) -> Protein:
  """Builds a `Protein` from model inputs and the matching model outputs.

  Atom positions and the atom mask are read from `result['structure_module']`;
  residue types, residue indices and (if present) `asym_id` come from
  `features`. Without `asym_id` every residue is placed in chain 0. Residue
  indices from `features` are shifted by one.

  Args:
    features: Dictionary holding model inputs.
    result: Dictionary holding model outputs.
    b_factors: (Optional) B-factors to use for the protein. Defaults to zeros
      shaped like the final atom mask.
    remove_leading_feature_dimension: Whether to remove the leading dimension
      of the `features` values.

  Returns:
    A protein instance.
  """
  structure_out = result['structure_module']

  def _feature(key: str) -> np.ndarray:
    # Fetches a feature, dropping its leading axis when requested.
    value = features[key]
    if remove_leading_feature_dimension:
      return value[0]
    return value

  has_chain_info = 'asym_id' in features
  if has_chain_info:
    chain_index = _feature('asym_id')
  else:
    chain_index = np.zeros_like(_feature('aatype'))

  if b_factors is None:
    b_factors = np.zeros_like(structure_out['final_atom_mask'])

  return Protein(
      aatype=_feature('aatype'),
      atom_positions=structure_out['final_atom_positions'],
      atom_mask=structure_out['final_atom_mask'],
      residue_index=_feature('residue_index') + 1,
      chain_index=chain_index,
      b_factors=b_factors)