# docker/alphafold/alphafold/model/layer_stack.py

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Function to stack repeats of a layer function without shared parameters."""

import collections
import contextlib
import functools
import inspect
from typing import Any, Callable, Optional, Tuple, Union

import haiku as hk
import jax
import jax.numpy as jnp

LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng'])
LayerStackScanned = collections.namedtuple('LayerStackScanned',
                                           ['i', 'args_ys'])
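# `LayerStackCarry` threads the activations and RNG key through the scan, while
# `LayerStackScanned` holds the per-layer loop index and the sliced extra
# inputs for the current layer.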

# WrappedFn should take in arbitrarily nested structures of `jnp.ndarray` and
# return the exact same type. We cannot express this with `typing`, so we just
# use it to inform the user. In reality, the typing below will accept anything.
NestedArray = Any
WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray, ...]]]


def _check_no_varargs(f):
  if list(inspect.signature(
      f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL:
    raise ValueError(
        'The function `f` should not have any `varargs` (that is, *args) '
        'argument. Instead, it should only use explicit positional '
        'arguments.')


@contextlib.contextmanager
def nullcontext():
  """No-op context manager; a stand-in for `contextlib.nullcontext`."""
  yield


def maybe_with_rng(key):
  """Returns an `hk.with_rng` context for `key`, or a no-op context if None."""
  if key is not None:
    return hk.with_rng(key)
  else:
    return nullcontext()


def maybe_fold_in(key, data):
  """Folds `data` into `key` if `key` is not None, otherwise returns None."""
  if key is not None:
    return jax.random.fold_in(key, data)
  else:
    return None


class _LayerStack(hk.Module):
  """Module to compose parameterized functions, implemented as a scan."""

  def __init__(self,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    """Iterate a function `f` `count` times, with non-shared parameters."""
    super().__init__(name=name)
    self._count = count
    self._unroll = unroll

  def __call__(self, x, *args_ys):
    count = self._count
    if hk.running_init():
      # At initialization time, we run just one layer but add an extra first
      # dimension to every initialized tensor, making sure to use different
      # random keys for different slices.
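      # For example, a parameter initialized with shape [h, h] is stored with
      # shape [count, h, h]; the getter below slices out index 0 so the wrapped
      # function still sees an [h, h] parameter during init.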
      def creator(next_creator, shape, dtype, init, context):
        del context

        def multi_init(shape, dtype):
          assert shape[0] == count
          key = hk.maybe_next_rng_key()

          def rng_context_init(slice_idx):
            slice_key = maybe_fold_in(key, slice_idx)
            with maybe_with_rng(slice_key):
              return init(shape[1:], dtype)

          return jax.vmap(rng_context_init)(jnp.arange(count))

        return next_creator((count,) + tuple(shape), dtype, multi_init)

      def getter(next_getter, value, context):
        # At init, slice off the leading stack dimension so the wrapped
        # function sees parameters of their original per-layer shape.
        trailing_dims = len(context.original_shape) + 1
        sliced_value = jax.lax.index_in_dim(
            value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
        return next_getter(sliced_value)

      with hk.experimental.custom_creator(
          creator), hk.experimental.custom_getter(getter):
        if len(args_ys) == 1 and args_ys[0] is None:
          args0 = (None,)
        else:
          args0 = [
              jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
              for ys in args_ys
          ]
        x, z = self._call_wrapped(x, *args0)
        if z is None:
          return x, z

        # Broadcast the single-layer state so the output holds one copy of the
        # state per layer.
        def broadcast_state(layer_state):
          return jnp.broadcast_to(
              layer_state, [count,] + list(layer_state.shape))
        zs = jax.tree_util.tree_map(broadcast_state, z)
        return x, zs
    else:
      # Use scan during apply, threading the RNG key through so that it is
      # unique for each layer.
      def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
        rng = carry.rng

        def getter(next_getter, value, context):
          # Getter slices the full param at the current loop index.
          trailing_dims = len(context.original_shape) + 1
          assert value.shape[value.ndim - trailing_dims] == count, (
              f'Attempting to use a parameter stack of size '
              f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
              f'size {count}.')

          sliced_value = jax.lax.dynamic_index_in_dim(
              value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
          return next_getter(sliced_value)

        with hk.experimental.custom_getter(getter):
          if rng is None:
            out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
          else:
            rng, rng_ = jax.random.split(rng)
            with hk.with_rng(rng_):
              out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
        return LayerStackCarry(x=out_x, rng=rng), z

      carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
      scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                  args_ys=args_ys)
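      # `scanned.i` receives one element of `jnp.arange(count)` per step, so
      # each iteration of the scan sees its own layer index.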

      carry, zs = hk.scan(
          layer, carry, scanned, length=count, unroll=self._unroll)
      return carry.x, zs

  def _call_wrapped(self,
                    x: jnp.ndarray,
                    *args,
                    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    raise NotImplementedError()


class _LayerStackNoState(_LayerStack):
  """_LayerStack impl with no per-layer state provided to the function."""

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    _check_no_varargs(f)
    self._f = f

  @hk.transparent
  def _call_wrapped(self, args, y):
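    # `args` here is the full tuple of positional arguments to the wrapped
    # function; `y` is the per-layer state slot, unused in the stateless case.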
    del y
    ret = self._f(*args)
    if len(args) == 1:
      # If the wrapped function takes a single argument it returns a single
      # value rather than a tuple, so wrap it in a 1-tuple to keep the carry
      # structure consistent.
      ret = (ret,)
    return ret, None


class _LayerStackWithState(_LayerStack):
  """_LayerStack impl with per-layer state provided to the function."""

  def __init__(self,
               f: WrappedFn,
               count: int,
               unroll: int,
               name: Optional[str] = None):
    super().__init__(count=count, unroll=unroll, name=name)
    self._f = f

  @hk.transparent
  def _call_wrapped(self, x, *args):
    return self._f(x, *args)


def layer_stack(num_layers: int,
                with_state=False,
                unroll: int = 1,
                name: Optional[str] = None):
  """Utility to wrap a Haiku function and recursively apply it to an input.

  A function is valid if it uses only explicit position parameters, and
  its return type matches its input type. The position parameters can be
  arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note
  that kwargs are not supported, neither are functions with variable number
  of parameters (specified by `*args`).

  If `with_state=False` then the new, wrapped function can be understood as
  performing the following:
  ```
  for i in range(num_layers):
    x = f(x)
  return x
  ```
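
  The corresponding `layer_stack` usage in this case would be:
  ```
  def f(x):
    ...
    return new_x
  x = layer_stack.layer_stack(num_layers)(f)(x)
  ```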

  And if `with_state=True`, assuming `f` takes two arguments on top of `x`
  (each with a leading dimension of size `num_layers`):
  ```
  for i in range(num_layers):
    x, zs[i] = f(x, ys_0[i], ys_1[i])
  return x, zs
  ```
  The code using `layer_stack` for the above function would be:
  ```
  def f(x, y_0, y_1):
    ...
    return new_x, z
  x, zs = layer_stack.layer_stack(num_layers,
                                  with_state=True)(f)(x, ys_0, ys_1)
  ```

  Crucially, any parameters created inside `f` will not be shared across
  iterations.

  Args:
    num_layers: The number of times to iterate the wrapped function.
    with_state: Whether or not to pass per-layer state to the wrapped function.
    unroll: The unroll parameter passed to `hk.scan`.
    name: Name of the Haiku context.

  Returns:
    Callable that will produce a layer stack when called with a valid function.
  """
  def iterate(f):
    if with_state:
      @functools.wraps(f)
      def wrapped(x, *args):
        for ys in args:
          assert ys.shape[0] == num_layers
        return _LayerStackWithState(
            f, num_layers, unroll=unroll, name=name)(x, *args)
    else:
      _check_no_varargs(f)
      @functools.wraps(f)
      def wrapped(*args):
        ret = _LayerStackNoState(
            f, num_layers, unroll=unroll, name=name)(args, None)[0]
        if len(args) == 1:
          # If the function takes a single argument, we must also return a
          # single value, and not a tuple of length 1.
          ret = ret[0]
        return ret

    return wrapped
  return iterate
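

# A minimal usage sketch (illustrative, not part of the original module): it
# stacks a small residual block with non-shared parameters. The block shape,
# layer count, and RNG key below are assumptions made for this example only.
if __name__ == '__main__':

  def block(x):
    # One residual MLP step; `layer_stack` repeats it `num_layers` times,
    # giving each repeat its own parameters.
    return x + hk.Linear(x.shape[-1])(jax.nn.relu(x))

  def forward(x):
    return layer_stack(num_layers=4)(block)(x)

  model = hk.transform(forward)
  rng = jax.random.PRNGKey(0)
  x = jnp.ones([2, 8])
  params = model.init(rng, x)
  out = model.apply(params, rng, x)
  # Parameters are stacked along a leading dimension of size `num_layers`:
  # here the `hk.Linear` kernel has shape [4, 8, 8].
  print(out.shape)  # (2, 8)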