diff docker/alphafold/alphafold/model/tf/shape_helpers.py @ 1:6c92e000d684 draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author galaxy-australia
date Tue, 01 Mar 2022 02:53:05 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/docker/alphafold/alphafold/model/tf/shape_helpers.py	Tue Mar 01 02:53:05 2022 +0000
@@ -0,0 +1,47 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for dealing with shapes of TensorFlow tensors."""
+import tensorflow.compat.v1 as tf
+
+
+def shape_list(x):
+  """Return list of dimensions of a tensor, statically where possible.
+
+  Like `x.shape.as_list()` but with tensors instead of `None`s.
+
+  Args:
+    x: A tensor.
+  Returns:
+    A list with length equal to the rank of the tensor. The n-th element of the
+    list is an integer when that dimension is statically known otherwise it is
+    the n-th element of `tf.shape(x)`.
+  """
+  x = tf.convert_to_tensor(x)
+
+  # If unknown rank, return dynamic shape
+  if x.get_shape().dims is None:
+    return tf.shape(x)
+
+  static = x.get_shape().as_list()
+  shape = tf.shape(x)
+
+  ret = []
+  for i in range(len(static)):
+    dim = static[i]
+    if dim is None:
+      dim = shape[i]
+    ret.append(dim)
+  return ret
+