Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/alphafold/model/tf/shape_helpers.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Utilities for dealing with shapes of TensorFlow tensors.""" | |
16 import tensorflow.compat.v1 as tf | |
17 | |
18 | |
def shape_list(x):
  """Return list of dimensions of a tensor, statically where possible.

  Like `x.shape.as_list()` but with tensors instead of `None`s.

  Args:
    x: A tensor, or any value accepted by `tf.convert_to_tensor`.
  Returns:
    If the rank of `x` is statically known: a list with length equal to the
    rank of the tensor. The n-th element of the list is an integer when that
    dimension is statically known, otherwise it is the n-th element of
    `tf.shape(x)`.
    If the rank of `x` is NOT statically known: `tf.shape(x)` itself, i.e. a
    tensor rather than a list, since no static length is available.
  """
  x = tf.convert_to_tensor(x)

  # With unknown rank there is no static length to build a list from, so fall
  # back to the fully dynamic shape tensor.
  if x.get_shape().dims is None:
    return tf.shape(x)

  static = x.get_shape().as_list()
  shape = tf.shape(x)

  # Prefer the static (Python int) size; only use the dynamic shape element
  # for dimensions whose size is not statically known (`None`).
  return [dim if dim is not None else shape[i]
          for i, dim in enumerate(static)]
47 |