Skip to content

Commit

Permalink
[nnx] RNN: add broadcast_rngs and state_axes APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 28, 2024
1 parent 5d896bc commit 7b1ec01
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 38 deletions.
78 changes: 43 additions & 35 deletions flax/nnx/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@

"""RNN modules for Flax."""

from typing import (
Any,
TypeVar
)
from typing import Any, TypeVar
from collections.abc import Mapping
from collections.abc import Callable
from functools import partial
from typing_extensions import Protocol
Expand All @@ -27,13 +25,13 @@
import jax.numpy as jnp

from flax import nnx
from flax.nnx import rnglib
from flax.nnx import filterlib, rnglib
from flax.nnx.module import Module
from flax.nnx.nn import initializers
from flax.nnx.nn.linear import Linear
from flax.nnx.nn.activations import sigmoid
from flax.nnx.nn.activations import tanh
from flax.nnx.transforms.iteration import Carry
from flax.nnx.transforms.iteration import Carry, StateAxes
from flax.typing import (
Dtype,
Initializer,
Expand Down Expand Up @@ -594,14 +592,16 @@ class RNN(Module):
"""

def __init__(
self,
cell: RNNCellBase,
time_major: bool = False,
return_carry: bool = False,
reverse: bool = False,
keep_order: bool = False,
unroll: int = 1,
rngs: rnglib.Rngs | None = None,
self,
cell: RNNCellBase,
time_major: bool = False,
return_carry: bool = False,
reverse: bool = False,
keep_order: bool = False,
unroll: int = 1,
rngs: rnglib.Rngs | None = None,
state_axes: Mapping[str, int | type[Carry] | None] | None = None,
broadcast_rngs: filterlib.Filter = None,
):
self.cell = cell
self.time_major = time_major
Expand All @@ -612,19 +612,21 @@ def __init__(
if rngs is None:
rngs = rnglib.Rngs(0)
self.rngs = rngs
self.state_axes = state_axes or {...: Carry}
self.broadcast_rngs = broadcast_rngs

def __call__(
self,
inputs: Array,
*,
initial_carry: Carry | None = None,
seq_lengths: Array | None = None,
return_carry: bool | None = None,
time_major: bool | None = None,
reverse: bool | None = None,
keep_order: bool | None = None,
rngs: rnglib.Rngs | None = None,
):
self,
inputs: Array,
*,
initial_carry: Carry | None = None,
seq_lengths: Array | None = None,
return_carry: bool | None = None,
time_major: bool | None = None,
reverse: bool | None = None,
keep_order: bool | None = None,
rngs: rnglib.Rngs | None = None,
):
if return_carry is None:
return_carry = self.return_carry
if time_major is None:
Expand Down Expand Up @@ -670,20 +672,26 @@ def __call__(
)

slice_carry = seq_lengths is not None and return_carry

def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]:
broadcast_rngs = nnx.All(nnx.RngState, self.broadcast_rngs)
state_axes = StateAxes({broadcast_rngs: None, **self.state_axes}) # type: ignore

# we use split_rngs with splits=1 and squeeze=True to get unique rngs
# every time RNN is called
@nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True)
@nnx.scan(
in_axes=(state_axes, Carry, time_axis),
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
unroll=self.unroll,
)
def scan_fn(
cell: RNNCellBase, carry: Carry, x: Array
) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]:
carry, y = cell(carry, x)
if slice_carry:
return carry, (carry, y)
return carry, y
state_axes = nnx.StateAxes({...: Carry}) # type: ignore[arg-type]
scan = nnx.scan(
scan_fn,
in_axes=(state_axes, Carry, time_axis),
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
unroll=self.unroll,
)
scan_output = scan(self.cell, carry, inputs)

scan_output = scan_fn(self.cell, carry, inputs)

# Next we select the final carry. If a segmentation mask was provided and
# return_carry is True we slice the carry history and select the last valid
Expand Down
19 changes: 16 additions & 3 deletions flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,22 @@ def split_rngs(
*,
splits: int | tuple[int, ...],
only: filterlib.Filter = ...,
squeeze: bool = False,
) -> SplitBackups: ...
@tp.overload
def split_rngs(
*,
splits: int | tuple[int, ...],
only: filterlib.Filter = ...,
squeeze: bool = False,
) -> tp.Callable[[F], F]: ...
def split_rngs(
node: tp.Any = MISSING,
/,
*,
splits: int | tuple[int, ...],
only: filterlib.Filter = ...,
squeeze: bool = False,
) -> SplitBackups | tp.Callable[[F], F]:
"""Splits the (nested) Rng states of the given node.
Expand Down Expand Up @@ -412,13 +415,18 @@ def split_rngs(
def split_rngs_decorator(f: F) -> F:
@functools.wraps(f)
def split_rngs_wrapper(*args, **kwargs):
with split_rngs((args, kwargs), splits=splits, only=only):
with split_rngs(
(args, kwargs), splits=splits, only=only, squeeze=squeeze
):
return f(*args, **kwargs)

return tp.cast(F, split_rngs_wrapper)

return split_rngs_decorator # type: ignore[bad-return-type]

if squeeze and splits != 1:
raise ValueError('squeeze=True is only supported for splits=1')

predicate = filterlib.to_predicate(only)
backups: list[StreamBackup] = []
for path, stream in graph.iter_graph(node):
Expand All @@ -429,8 +437,13 @@ def split_rngs_wrapper(*args, **kwargs):
):
key = stream()
backups.append((stream, stream.key.value, stream.count.value))
stream.key.value = jax.random.split(key, splits)
if isinstance(splits, int):
key = jax.random.split(key, splits)
if squeeze:
key = key[0]
stream.key.value = key
if squeeze:
counts_shape = stream.count.shape
elif isinstance(splits, int):
counts_shape = (splits, *stream.count.shape)
else:
counts_shape = (*splits, *stream.count.shape)
Expand Down

0 comments on commit 7b1ec01

Please sign in to comment.