Add NNX transforms nnx.while_loop and nnx.switch #4343

Merged · 1 commit · Oct 31, 2024
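For orientation, here is a minimal usage sketch of the two transforms added by this PR, adapted from the `while_loop` docstring and the `TestSwitch`/`TestWhileLoop` tests in the diff below; the helper names (`body_fn`, `identity`, `scale`) are illustrative, not part of the PR.

```python
import jax
import jax.numpy as jnp
from flax import nnx

# nnx.while_loop: the carried value is a pytree that may contain NNX modules.
def body_fn(val):
  module, x, count = val
  return module, module(x), count - 1.0

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (10,))
_, y, _ = nnx.while_loop(lambda val: val[-1] > 0, body_fn, (module, x, 3.0))

# nnx.switch: branches share one signature and may use the modules passed as operands.
def identity(m, x):
  return m(x)

def scale(m, x):
  return m(x) * 2

y2 = nnx.switch(jnp.array(1), (identity, scale), module, x)  # selects `scale`
```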
4 changes: 3 additions & 1 deletion docs_nnx/api_reference/flax.nnx/transforms.rst
@@ -20,5 +20,7 @@ transforms
.. autofunction:: value_and_grad
.. autofunction:: vmap
.. autofunction:: eval_shape
.. autofunction:: cond
.. autofunction:: custom_vjp
.. autofunction:: cond
.. autofunction:: switch
.. autofunction:: while_loop
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
@@ -150,6 +150,8 @@
from .transforms.iteration import pmap as pmap
from .transforms.transforms import eval_shape as eval_shape
from .transforms.transforms import cond as cond
from .transforms.transforms import switch as switch
from .transforms.iteration import while_loop as while_loop
from .transforms.iteration import StateAxes as StateAxes
from .variablelib import A as A
from .variablelib import BatchStat as BatchStat
145 changes: 145 additions & 0 deletions flax/nnx/transforms/iteration.py
@@ -40,6 +40,7 @@
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
T = tp.TypeVar('T')
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable
Leaves = tp.List[Leaf]
@@ -1304,3 +1305,147 @@ def scan_wrapper(*args, **kwargs):
    return out

  return scan_wrapper  # type: ignore





# -------------------------------
# while_loop
# -------------------------------


@dataclasses.dataclass(eq=False)
class WhileLoopCondFn:
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, pure_val):
    val = extract.from_tree(pure_val)
    out = self.f(val)
    return out


def _add_fake_index_mapping(tree: tp.Any):
  def per_node_state(ns: extract.NodeStates | tp.Any):
    global_index_mapping = {}
    if not isinstance(ns, extract.NodeStates) or not isinstance(
      ns._graphdef, graph.NodeDef
    ):
      return ns

    def per_node_def(nd: graph.NodeDef | tp.Any):
      if nd.index >= 0:
        global_index_mapping[nd.index] = nd.index
      for sub_nd in nd.subgraphs.values():
        per_node_def(sub_nd)
      for l in nd.leaves.values():
        if isinstance(l, graph.NodeRef) and l.index >= 0:
          global_index_mapping[l.index] = l.index
      return

    per_node_def(ns._graphdef)
    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
      ns._graphdef,
      index_mapping=FrozenDict(global_index_mapping)
    ))

  return jax.tree.map(per_node_state, tree,
                      is_leaf=lambda x: isinstance(x, extract.NodeStates))


def _remove_index_mapping(tree: tp.Any):
  '''Remove the fake index_mapping from the input so it matches the output.'''
  def per_node_state(ns: extract.NodeStates | tp.Any):
    if not isinstance(ns, extract.NodeStates) or not isinstance(
      ns._graphdef, graph.NodeDef
    ):
      return ns
    assert isinstance(ns._graphdef, graph.NodeDef)
    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
      ns._graphdef, index_mapping=None
    ))

  return jax.tree.map(per_node_state, tree,
                      is_leaf=lambda x: isinstance(x, extract.NodeStates))


@dataclasses.dataclass(eq=False)
class WhileLoopBodyFn:
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  @graph.update_context('while_loop_body')
  def __call__(self, pure_val):
    # Remove the dummy index mapping that was added outside of the body function.
    pure_val_in = _remove_index_mapping(pure_val)

    val = extract.from_tree(pure_val_in, ctxtag='while_loop_body')
    out = self.f(val)
    pure_out = extract.to_tree(out, ctxtag='while_loop_body')

    try:
      jax.tree.map(lambda a, b: None, pure_val, pure_out)
    except ValueError as e:
      msg = ("nnx.while_loop requires the body function's input and output to "
             "have the same reference and pytree structure, but they differ. "
             "If the mismatch comes from the `index_mapping` field, you might "
             "have modified the reference structure within the body function, "
             "which is not allowed. "
             f"Details of the mismatch: \n {str(e)}")
      raise ValueError(msg)

    return pure_out


@graph.update_context('while_loop')
def while_loop(cond_fun: tp.Callable[[T], tp.Any],
               body_fun: tp.Callable[[T], T],
               init_val: T) -> T:
  """NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the reference structure of `init_val` inside `body_fun`.

  Example::

    >>> import jax
    >>> from flax import nnx
    >>> def fwd_fn(input):
    ...   module, x, count = input
    ...   return module, module(x), count - 1.0

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> # `module` will be called three times
    >>> _, y, _ = nnx.while_loop(
    ...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))

  Args:
    cond_fun: a function for the continue condition of the while loop, taking a
      single input of type `T` and outputting a boolean.
    body_fun: a function that takes an input of type `T` and outputs a `T`.
      Note that both data and modules of `T` must have the same reference
      structure between inputs and outputs.
    init_val: the initial input for `cond_fun` and `body_fun`. Must be of type `T`.
  """

  pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')

  # Add the expected reference mapping to `pure_init_val` so its pytree
  # structure matches `body_fun`'s output, which `jax.lax.while_loop` requires.
  pure_init_val = _add_fake_index_mapping(pure_init_val)

  pure_out = jax.lax.while_loop(
    WhileLoopCondFn(cond_fun),
    WhileLoopBodyFn(body_fun),
    pure_init_val,
  )
  out = extract.from_tree(pure_out, ctxtag='while_loop')
  return out
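To make the `index_mapping` bookkeeping above concrete, here is a hedged sketch of the constraint it enforces, mirroring `test_value_changed` and `test_ref_changed` in the test diff below: mutating a `Variable`'s value inside the body is allowed, but replacing the `Variable` with a new object changes the reference structure and raises a `ValueError`.

```python
import jax
import jax.numpy as jnp
from flax import nnx

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (10,))

# Allowed: update the value held by an existing Variable in place.
def ok_body(val):
  m, x, c = val
  m.kernel.value = jnp.zeros_like(m.kernel.value)
  return m, m(x), c - 1.0

_, y, _ = nnx.while_loop(lambda val: val[-1] > 0, ok_body, (module, x, 2.0))

# Not allowed: replace the Variable with a new object; the body's output then
# has a different reference structure than its input.
def bad_body(val):
  m, x, c = val
  y = m(x)
  m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value))  # new reference
  return m, y, c - 1.0

try:
  nnx.while_loop(lambda val: val[-1] > 0, bad_body, (module, x, 2.0))
except ValueError:
  pass  # expected: reference/pytree structure mismatch
```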
16 changes: 15 additions & 1 deletion flax/nnx/transforms/transforms.py
@@ -141,7 +141,7 @@ def _eval_shape_fn(*args, **kwargs):


# -------------------------------
# cond
# cond and switch
# -------------------------------


@@ -160,3 +160,17 @@ def cond(
    *operands,
    **kwargs,
  )


@general.split_inputs(ctxtag='switch')
def switch(
  index,
  branches: tp.Sequence[tp.Callable[..., A]],
  *operands,
) -> A:
  return jax.lax.switch(
    index,
    [general.merge_inputs(f, ctxtag='switch') for f in branches],
    *operands,
  )
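`switch` is added here without a docstring; as a hedged illustration of the intended behavior (condensed from `TestSwitch.test_basic` in the test diff below), every branch shares the same signature, and state updates made by the selected branch propagate back to the caller. The `Counter` module and branch names are illustrative only.

```python
import jax.numpy as jnp
from flax import nnx

class Counter(nnx.Module):  # illustrative module, not part of the PR
  def __init__(self):
    self.count = nnx.Variable(jnp.array(0))
    self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

def bump(m, x):
  m.count += 1  # state update inside a branch
  return m.linear(x)

def plain(m, x):
  return m.linear(x)

m = Counter()
x = jnp.ones((10,))
y = nnx.switch(0, (bump, plain), m, x)  # index 0 selects the `bump` branch
assert m.count.value == 1               # the update is visible to the caller
```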

168 changes: 167 additions & 1 deletion tests/nnx/transforms_test.py
@@ -1673,7 +1673,6 @@ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:

    x = jnp.ones((16, 10, 20))
    y = rnn_forward(cell, x)
    print(y.shape)


class TestRemat(absltest.TestCase):
@@ -2612,6 +2611,173 @@ def no_nothing(env: Env):
)


class TestSwitch(absltest.TestCase):
  def test_basic(self):
    class RoundTable(nnx.Module):
      def __init__(self):
        self.next_index = 0
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
        self.linear.kernel.value = jnp.identity(10)
        self.rounds_count = nnx.Variable(jnp.array(0))

      def __call__(self, x):
        def fn0(m, x):
          m.rounds_count += 1
          return m.linear(x)
        def fn1(m, x):
          return m.linear(x) * 2
        def fn2(m, x):
          m.linear.kernel.value = jnp.zeros((10, 10))
          return m.linear(x)

        # y = nnx.cond(self.next_index.value == 0, fn0, fn1, self, x)
        y = nnx.switch(self.next_index, (fn0, fn1, fn2), self, x)
        self.next_index = (self.next_index + 1) % 3
        return y

    model = RoundTable()
    x = jnp.ones((10,))
    np.testing.assert_array_equal(model(x), x)
    assert model.rounds_count.value == 1
    assert model.next_index == 1
    np.testing.assert_array_equal(model(x), x * 2)
    assert model.rounds_count.value == 1
    assert model.next_index == 2
    np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
    assert model.rounds_count.value == 1
    assert model.next_index == 0
    np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
    assert model.rounds_count.value == 2
    assert model.next_index == 1


class TestWhileLoop(absltest.TestCase):
  def test_basic(self):
    def fwd_fn(input):
      m, x, c = input
      y = m(x)
      return m, y, c - 1.0

    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    module.kernel.value = jnp.identity(10) * 2
    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

    _, y, _ = nnx.while_loop(
      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
    np.testing.assert_array_equal(y, x * 8)

  def test_multiple_objects(self):
    def fwd_fn(input):
      m1, (w2,), x, c = input
      y = m1(x) @ w2
      return m1, (w2,), y, c - 1.0

    m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    m1.kernel.value = jnp.identity(10) * 2
    w2 = nnx.Variable(jnp.identity(10) * 0.5)
    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

    _, _, y, _ = nnx.while_loop(
      lambda input: input[-1] > 0, fwd_fn, (m1, (w2,), x, 3.0))
    np.testing.assert_allclose(y, x)

  def test_nested_module(self):
    def fwd_fn(input):
      m, x, c = input
      y = m(x)
      return m, y, c - 1.0

    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    module.kernel.value = jnp.identity(10) * 2
    module = nnx.Sequential(module)
    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

    _, y, _ = nnx.while_loop(
      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
    np.testing.assert_array_equal(y, x * 8)


  def test_shared_module(self):
    m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0))
    m2.kernel = m1.kernel
    module = nnx.Sequential(m1, m2)
    self.assertLen(jax.tree.leaves(nnx.state(module)), 2)  # only m1 params

    def fwd_fn(input):
      m, x, c = input
      y = m(x)
      m.layers[0].kernel.value = jnp.zeros_like(m.layers[0].kernel.value)
      return m, y, c - 1.0

    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
    _, y, _ = nnx.while_loop(
      lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))
    self.assertLen(jax.tree.leaves(nnx.state(module)), 2)  # only m1 params
    np.testing.assert_array_equal(m1.kernel.value, jnp.zeros((10, 10,)))
    np.testing.assert_array_equal(m2.kernel.value, jnp.zeros((10, 10,)))
    np.testing.assert_array_equal(y, jnp.zeros((10,)))


  def test_value_changed(self):
    def fwd_fn(input):
      m, x, c = input
      m.kernel.value = jnp.zeros_like(m.kernel.value)
      y = m(x)
      return m, y, c - 1.0

    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

    _, y, _ = nnx.while_loop(
      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
    np.testing.assert_array_equal(module.kernel.value, jnp.zeros((10, 10,)))
    np.testing.assert_array_equal(y, jnp.zeros((10,)))


  def test_ref_changed(self):
    def fwd_fn(input):
      m, x, c = input
      y = m(x)
      m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value))
      return m, y, c - 1.0

    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

    with self.assertRaises(ValueError):
      _, y, _ = nnx.while_loop(
        lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))


  def test_structure_changed(self):
    def fwd_fn(input):
      m, x, c = input
      m = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(1))
      m.kernel.value = jnp.identity(10) * 2
      y = m(x)
      return m, y, c - 1.0

    module = nnx.Linear(10, 10, use_bias=True, rngs=nnx.Rngs(0))
    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

    with self.assertRaises(ValueError):
      _, y, _ = nnx.while_loop(
        lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))

  def test_repeated_object(self):
    m = nnx.Linear(10, 10, rngs=nnx.Rngs(0))

    def body_fn(val):
      count, m, _ = val
      return count + 1, m, m

    count, m, _ = nnx.while_loop(
      lambda val: val[0] < 2,
      body_fn,
      (0, m, m),
    )

class TestSplitMergeInputs(absltest.TestCase):
  def test_split_inputs(self):
    class StatefulLinear(nnx.Linear):