Mercurial > repos > galaxy-australia > alphafold2
diff docker/alphafold/alphafold/model/tf/shape_helpers.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/docker/alphafold/alphafold/model/tf/shape_helpers.py Tue Mar 01 02:53:05 2022 +0000 @@ -0,0 +1,47 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with shapes of TensorFlow tensors.""" +import tensorflow.compat.v1 as tf + + +def shape_list(x): + """Return list of dimensions of a tensor, statically where possible. + + Like `x.shape.as_list()` but with tensors instead of `None`s. + + Args: + x: A tensor. + Returns: + A list with length equal to the rank of the tensor. The n-th element of the + list is an integer when that dimension is statically known otherwise it is + the n-th element of `tf.shape(x)`. + """ + x = tf.convert_to_tensor(x) + + # If unknown rank, return dynamic shape + if x.get_shape().dims is None: + return tf.shape(x) + + static = x.get_shape().as_list() + shape = tf.shape(x) + + ret = [] + for i in range(len(static)): + dim = static[i] + if dim is None: + dim = shape[i] + ret.append(dim) + return ret +