comparison docker/alphafold/alphafold/model/tf/shape_helpers.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
comparison
equal deleted inserted replaced
0:7ae9d78b06f5 1:6c92e000d684
1 # Copyright 2021 DeepMind Technologies Limited
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 """Utilities for dealing with shapes of TensorFlow tensors."""
16 import tensorflow.compat.v1 as tf
17
18
def shape_list(x):
  """Return list of dimensions of a tensor, statically where possible.

  Like `x.shape.as_list()` but with tensors instead of `None`s.

  Args:
    x: A tensor, or any value accepted by `tf.convert_to_tensor`.

  Returns:
    When the rank of `x` is statically known: a list with length equal to the
    rank of the tensor. The n-th element of the list is an integer when that
    dimension is statically known, otherwise it is the n-th element of
    `tf.shape(x)`.
    When the rank of `x` is unknown: no per-dimension list can be built, so
    the dynamic shape tensor `tf.shape(x)` itself is returned instead of a
    list.
  """
  x = tf.convert_to_tensor(x)

  # If unknown rank, return dynamic shape (a tensor, not a list).
  if x.get_shape().dims is None:
    return tf.shape(x)

  static = x.get_shape().as_list()
  dynamic = tf.shape(x)

  # Keep the static size wherever it is known; only dimensions reported as
  # `None` fall back to the corresponding entry of the dynamic shape.
  return [
      dynamic[i] if dim is None else dim for i, dim in enumerate(static)
  ]
47