From 7b1ec01e3f0eabbe18f5dee8ab08017d422897f0 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 28 Nov 2024 16:33:48 +0000 Subject: [PATCH] [nnx] RNN: add broadcast_rngs and state_axes APIs --- flax/nnx/nn/recurrent.py | 78 ++++++++++++++++++++++------------------ flax/nnx/rnglib.py | 19 ++++++++-- 2 files changed, 59 insertions(+), 38 deletions(-) diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index ea18805d0f..dd64023448 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -14,10 +14,8 @@ """RNN modules for Flax.""" -from typing import ( - Any, - TypeVar -) +from typing import Any, TypeVar +from collections.abc import Mapping from collections.abc import Callable from functools import partial from typing_extensions import Protocol @@ -27,13 +25,13 @@ import jax.numpy as jnp from flax import nnx -from flax.nnx import rnglib +from flax.nnx import filterlib, rnglib from flax.nnx.module import Module from flax.nnx.nn import initializers from flax.nnx.nn.linear import Linear from flax.nnx.nn.activations import sigmoid from flax.nnx.nn.activations import tanh -from flax.nnx.transforms.iteration import Carry +from flax.nnx.transforms.iteration import Carry, StateAxes from flax.typing import ( Dtype, Initializer, @@ -594,14 +592,16 @@ class RNN(Module): """ def __init__( - self, - cell: RNNCellBase, - time_major: bool = False, - return_carry: bool = False, - reverse: bool = False, - keep_order: bool = False, - unroll: int = 1, - rngs: rnglib.Rngs | None = None, + self, + cell: RNNCellBase, + time_major: bool = False, + return_carry: bool = False, + reverse: bool = False, + keep_order: bool = False, + unroll: int = 1, + rngs: rnglib.Rngs | None = None, + state_axes: Mapping[str, int | type[Carry] | None] | None = None, + broadcast_rngs: filterlib.Filter = None, ): self.cell = cell self.time_major = time_major @@ -612,19 +612,21 @@ def __init__( if rngs is None: rngs = rnglib.Rngs(0) self.rngs = rngs + self.state_axes = state_axes or {...: Carry} + self.broadcast_rngs = broadcast_rngs def __call__( - self, - inputs: Array, - *, - initial_carry: Carry | None = None, - seq_lengths: Array | None = None, - return_carry: bool | None = None, - time_major: bool | None = None, - reverse: bool | None = None, - keep_order: bool | None = None, - rngs: rnglib.Rngs | None = None, - ): + self, + inputs: Array, + *, + initial_carry: Carry | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + rngs: rnglib.Rngs | None = None, + ): if return_carry is None: return_carry = self.return_carry if time_major is None: @@ -670,20 +672,26 @@ def __call__( ) slice_carry = seq_lengths is not None and return_carry - - def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: + broadcast_rngs = nnx.All(nnx.RngState, self.broadcast_rngs) + state_axes = StateAxes({broadcast_rngs: None, **self.state_axes}) # type: ignore + + # we use split_rngs with splits=1 and squeeze=True to get unique rngs + # every time RNN is called + @nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True) + @nnx.scan( + in_axes=(state_axes, Carry, time_axis), + out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), + unroll=self.unroll, + ) + def scan_fn( + cell: RNNCellBase, carry: Carry, x: Array + ) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: carry, y = cell(carry, x) if slice_carry: return carry, (carry, y) return carry, y - state_axes = nnx.StateAxes({...: Carry}) # type: ignore[arg-type] - scan = nnx.scan( - scan_fn, - in_axes=(state_axes, Carry, time_axis), - out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), - unroll=self.unroll, - ) - scan_output = scan(self.cell, carry, inputs) + + scan_output = scan_fn(self.cell, carry, inputs) # Next we select the final carry. If a segmentation mask was provided and # return_carry is True we slice the carry history and select the last valid diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 17bbaf37c8..ab9817acaa 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -302,12 +302,14 @@ def split_rngs( *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., + squeeze: bool = False, ) -> SplitBackups: ... @tp.overload def split_rngs( *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., + squeeze: bool = False, ) -> tp.Callable[[F], F]: ... def split_rngs( node: tp.Any = MISSING, @@ -315,6 +317,7 @@ def split_rngs( *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., + squeeze: bool = False, ) -> SplitBackups | tp.Callable[[F], F]: """Splits the (nested) Rng states of the given node. @@ -412,13 +415,18 @@ def split_rngs( def split_rngs_decorator(f: F) -> F: @functools.wraps(f) def split_rngs_wrapper(*args, **kwargs): - with split_rngs((args, kwargs), splits=splits, only=only): + with split_rngs( + (args, kwargs), splits=splits, only=only, squeeze=squeeze + ): return f(*args, **kwargs) return tp.cast(F, split_rngs_wrapper) return split_rngs_decorator # type: ignore[bad-return-type] + if squeeze and splits != 1: + raise ValueError('squeeze=True is only supported for splits=1') + predicate = filterlib.to_predicate(only) backups: list[StreamBackup] = [] for path, stream in graph.iter_graph(node): @@ -429,8 +437,13 @@ def split_rngs_wrapper(*args, **kwargs): ): key = stream() backups.append((stream, stream.key.value, stream.count.value)) - stream.key.value = jax.random.split(key, splits) - if isinstance(splits, int): + key = jax.random.split(key, splits) + if squeeze: + key = key[0] + stream.key.value = key + if squeeze: + counts_shape = stream.count.shape + elif isinstance(splits, int): counts_shape = (splits, *stream.count.shape) else: counts_shape = (*splits, *stream.count.shape)