comparison docker/alphafold/alphafold/common/protein.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
comparison
equal deleted inserted replaced
0:7ae9d78b06f5 1:6c92e000d684
1 # Copyright 2021 DeepMind Technologies Limited
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Protein data type."""
16 import dataclasses
17 import io
18 from typing import Any, Mapping, Optional
19 from alphafold.common import residue_constants
20 from Bio.PDB import PDBParser
21 import numpy as np
22
# Type aliases for the dictionaries handled in this module.
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format: upper-case
# letters first, then lower-case letters, then digits.
PDB_CHAIN_IDS = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
    '0123456789')
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.
29
30
@dataclasses.dataclass(frozen=True)
class Protein:
  """Immutable bundle of per-residue arrays describing a protein structure.

  Construction fails with ValueError when `chain_index` contains more than
  `PDB_MAX_CHAINS` distinct values.
  """

  # Cartesian coordinates of atoms in angstroms. The atom-type axis follows
  # residue_constants.atom_types.
  # NOTE(review): this comment previously stated that the first three atom
  # types are N, CA, CB -- confirm against residue_constants.atom_types.
  atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

  # Amino-acid type of each residue, encoded as an integer between 0 and 20,
  # where 20 is 'X'.
  aatype: np.ndarray  # [num_res]

  # Binary float mask marking which atoms are present: 1.0 for a present atom
  # and 0.0 otherwise. This should be used for loss masking.
  atom_mask: np.ndarray  # [num_res, num_atom_type]

  # Residue index as used in PDB. It is not necessarily continuous or
  # 0-indexed.
  residue_index: np.ndarray  # [num_res]

  # 0-indexed number of the chain that each residue belongs to.
  chain_index: np.ndarray  # [num_res]

  # B-factors, or temperature factors (in sq. angstroms units), representing
  # displacement from the ground truth mean value. Stored per atom type for
  # every residue, as the shape indicates.
  b_factors: np.ndarray  # [num_res, num_atom_type]

  def __post_init__(self):
    """Rejects instances holding more distinct chains than PDB can label."""
    num_chains = np.unique(self.chain_index).size
    if num_chains <= PDB_MAX_CHAINS:
      return
    raise ValueError(
        f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
        'because these cannot be written to PDB format.')
64
65
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
  """Parses the text of a PDB file into a `Protein`.

  WARNING: Residue names without a known one-letter code are stored as the
  unknown type 'X'. Atoms whose names are not listed in
  residue_constants.atom_types are ignored, and residues left with no
  recognised atoms are skipped.

  Args:
    pdb_str: The contents of the pdb file.
    chain_id: If given (e.g. A), only the chain with this ID is parsed.
      Otherwise all chains are parsed.

  Returns:
    A new `Protein` parsed from the pdb contents.

  Raises:
    ValueError: If the PDB does not hold exactly one model, or if a residue
      in a parsed chain carries an insertion code.
  """
  handle = io.StringIO(pdb_str)
  structure = PDBParser(QUIET=True).get_structure('none', handle)
  all_models = list(structure.get_models())
  if len(all_models) != 1:
    raise ValueError(
        f'Only single model PDBs are supported. Found {len(all_models)} models.')
  model = all_models[0]

  num_atom_types = residue_constants.atom_type_num
  known_atom_names = residue_constants.atom_types
  atom_slot = residue_constants.atom_order

  # Per-residue accumulators; entries are appended in lockstep.
  all_positions = []
  all_restypes = []
  all_masks = []
  all_res_numbers = []
  all_chain_ids = []
  all_b_factors = []

  for chain in model:
    if chain_id is not None and chain.id != chain_id:
      continue
    for res in chain:
      if res.id[2] != ' ':
        raise ValueError(
            f'PDB contains an insertion code at chain {chain.id} and residue '
            f'index {res.id[1]}. These are not supported.')
      one_letter = residue_constants.restype_3to1.get(res.resname, 'X')
      restype_idx = residue_constants.restype_order.get(
          one_letter, residue_constants.restype_num)

      res_positions = np.zeros((num_atom_types, 3))
      res_mask = np.zeros((num_atom_types,))
      res_b_factors = np.zeros((num_atom_types,))
      for atom in res:
        if atom.name not in known_atom_names:
          continue
        slot = atom_slot[atom.name]
        res_positions[slot] = atom.coord
        res_mask[slot] = 1.
        res_b_factors[slot] = atom.bfactor

      if res_mask.sum() < 0.5:
        # If no known atom positions are reported for the residue then skip it.
        continue

      all_restypes.append(restype_idx)
      all_positions.append(res_positions)
      all_masks.append(res_mask)
      all_res_numbers.append(res.id[1])
      all_chain_ids.append(chain.id)
      all_b_factors.append(res_b_factors)

  # Chain IDs are usually characters so map these to ints. np.unique returns
  # the distinct IDs in sorted order, which fixes the numbering.
  index_of_chain = {}
  for position, cid in enumerate(np.unique(all_chain_ids)):
    index_of_chain[cid] = position
  chain_index = np.array([index_of_chain[cid] for cid in all_chain_ids])

  return Protein(
      atom_positions=np.array(all_positions),
      atom_mask=np.array(all_masks),
      aatype=np.array(all_restypes),
      residue_index=np.array(all_res_numbers),
      chain_index=chain_index,
      b_factors=np.array(all_b_factors))
138
139
140 def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
141 chain_end = 'TER'
142 return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
143 f'{chain_name:>1}{residue_index:>4}')
144
145
def to_pdb(prot: Protein) -> str:
  """Converts a `Protein` instance to a PDB string.

  Emits a single MODEL holding one ATOM record for every atom whose mask is
  at least 0.5 and a TER record after each chain, followed by ENDMDL and END.
  Every line is padded to 80 characters. A protein with no residues yields a
  model without any ATOM or TER records.

  Args:
    prot: The protein to convert to PDB.

  Returns:
    PDB string.

  Raises:
    ValueError: If any `aatype` exceeds residue_constants.restype_num, or if
      any chain index is `PDB_MAX_CHAINS` or larger.
  """
  restypes = residue_constants.restypes + ['X']
  res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
  atom_types = residue_constants.atom_types

  pdb_lines = []

  atom_mask = prot.atom_mask
  aatype = prot.aatype
  atom_positions = prot.atom_positions
  residue_index = prot.residue_index.astype(np.int32)
  chain_index = prot.chain_index.astype(np.int32)
  b_factors = prot.b_factors

  if np.any(aatype > residue_constants.restype_num):
    raise ValueError('Invalid aatypes.')

  # Construct a mapping from chain integer indices to chain ID strings.
  chain_ids = {}
  for i in np.unique(chain_index):  # np.unique gives sorted output.
    if i >= PDB_MAX_CHAINS:
      raise ValueError(
          f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
    chain_ids[i] = PDB_CHAIN_IDS[i]

  # PDB is a columnar format, every space matters here! Blank columns are
  # rendered from an empty string at an explicit width so that each gap size
  # is stated in the format spec rather than as a run of literal spaces.
  blank = ''

  # MODEL record: name in columns 1-6, serial right-justified in 11-14.
  model_record = 'MODEL'
  model_serial = 1
  pdb_lines.append(f'{model_record:<6}{blank:4}{model_serial:>4}')

  num_res = aatype.shape[0]
  atom_index = 1
  # Add all atom sites.
  for i in range(num_res):
    # Close the previous chain if in a multichain PDB.
    if i > 0 and chain_index[i - 1] != chain_index[i]:
      pdb_lines.append(_chain_end(
          atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
          residue_index[i - 1]))
      atom_index += 1  # Atom index increases at the TER symbol.

    res_name_3 = res_1to3(aatype[i])
    for atom_name, pos, mask, b_factor in zip(
        atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
      if mask < 0.5:
        continue

      record_type = 'ATOM'
      name = atom_name if len(atom_name) == 4 else f' {atom_name}'
      alt_loc = ''
      insertion_code = ''
      occupancy = 1.00
      element = atom_name[0]  # Protein supports only C, N, O, S, this works.
      charge = ''
      # ATOM record columns: 1-6 name, 7-11 serial, 13-16 atom name,
      # 17 altLoc, 18-20 resName, 22 chainID, 23-26 resSeq, 27 iCode,
      # 28-30 blank, 31-54 x/y/z, 55-60 occupancy, 61-66 tempFactor,
      # 67-76 blank, 77-78 element, 79-80 charge.
      atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                   f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                   f'{residue_index[i]:>4}{insertion_code:>1}{blank:3}'
                   f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                   f'{occupancy:>6.2f}{b_factor:>6.2f}{blank:10}'
                   f'{element:>2}{charge:>2}')
      pdb_lines.append(atom_line)
      atom_index += 1

  # Close the final chain. With no residues there is no chain to close.
  if num_res > 0:
    pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
                                chain_ids[chain_index[-1]], residue_index[-1]))
  pdb_lines.append('ENDMDL')
  pdb_lines.append('END')

  # Pad all lines to 80 characters.
  pdb_lines = [line.ljust(80) for line in pdb_lines]
  return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
224
225
def ideal_atom_mask(prot: Protein) -> np.ndarray:
  """Looks up the expected atom mask for each residue of `prot`.

  `Protein.atom_mask` typically reflects the atoms that are reported in the
  PDB. This function instead builds a mask from the amino-acid sequence, i.e.
  from the heavy atoms that should be present for each residue type.

  Args:
    prot: `Protein` whose fields are `numpy.ndarray` objects.

  Returns:
    An ideal atom mask: residue_constants.STANDARD_ATOM_MASK indexed by
    `prot.aatype`.
  """
  mask_by_restype = residue_constants.STANDARD_ATOM_MASK
  return mask_by_restype[prot.aatype]
240
241
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True) -> Protein:
  """Builds a `Protein` from model inputs and the matching model outputs.

  Positions and atom mask are read from `result['structure_module']`. Residue
  types, residue indices and (when present) `asym_id` chain assignments are
  read from `features`. One is added to the residue indices taken from
  `features`. Without `asym_id`, every residue is placed in chain 0.

  Args:
    features: Dictionary holding model inputs.
    result: Dictionary holding model outputs.
    b_factors: (Optional) B-factors to use for the protein. Defaults to zeros
      shaped like the final atom mask.
    remove_leading_feature_dimension: Whether to remove the leading dimension
      of the `features` values.

  Returns:
    A protein instance.
  """
  structure_out = result['structure_module']

  def _feature(key: str) -> np.ndarray:
    # Fetches a feature, dropping its leading axis when requested.
    value = features[key]
    if remove_leading_feature_dimension:
      return value[0]
    return value

  if 'asym_id' not in features:
    chain_index = np.zeros_like(_feature('aatype'))
  else:
    chain_index = _feature('asym_id')

  if b_factors is None:
    b_factors = np.zeros_like(structure_out['final_atom_mask'])

  return Protein(
      aatype=_feature('aatype'),
      atom_positions=structure_out['final_atom_positions'],
      atom_mask=structure_out['final_atom_mask'],
      residue_index=_feature('residue_index') + 1,
      chain_index=chain_index,
      b_factors=b_factors)