Add conditionals nnx.while_loop and nnx.switch
IvyZX committed Oct 31, 2024
1 parent 3f3c03b commit a68e45a
Showing 4 changed files with 303 additions and 3 deletions.
4 changes: 3 additions & 1 deletion docs_nnx/api_reference/flax.nnx/transforms.rst
@@ -20,5 +20,7 @@ transforms
.. autofunction:: value_and_grad
.. autofunction:: vmap
.. autofunction:: eval_shape
.. autofunction:: cond
.. autofunction:: custom_vjp
.. autofunction:: cond
.. autofunction:: switch
.. autofunction:: while_loop
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
@@ -150,6 +150,8 @@
from .transforms.iteration import pmap as pmap
from .transforms.transforms import eval_shape as eval_shape
from .transforms.transforms import cond as cond
from .transforms.transforms import switch as switch
from .transforms.transforms import while_loop as while_loop
from .transforms.iteration import StateAxes as StateAxes
from .variablelib import A as A
from .variablelib import BatchStat as BatchStat
144 changes: 143 additions & 1 deletion flax/nnx/transforms/transforms.py
@@ -15,12 +15,15 @@
from __future__ import annotations

from abc import abstractmethod
import dataclasses
import functools
import inspect
import typing as tp

from flax.core import FrozenDict
from flax.nnx import (
extract,
graph,
)
from flax.nnx.module import Module
from flax.nnx.proxy_caller import (
@@ -41,6 +44,7 @@
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
T = tp.TypeVar('T')
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable
Leaves = tp.List[Leaf]
@@ -141,7 +145,7 @@ def _eval_shape_fn(*args, **kwargs):


# -------------------------------
# cond
# cond and switch
# -------------------------------


Expand All @@ -160,3 +164,141 @@ def cond(
*operands,
**kwargs,
)


@general.split_inputs(ctxtag='switch')
def switch(
index,
branches: tp.Sequence[tp.Callable[..., A]],
*operands,
) -> A:
return jax.lax.switch(
index,
[general.merge_inputs(f, ctxtag='switch') for f in branches],
*operands,
)


# -------------------------------
# while_loop
# -------------------------------


@dataclasses.dataclass(eq=False)
class WhileLoopCondFn:
f: tp.Callable[..., tp.Any]

def __post_init__(self):
functools.update_wrapper(self, self.f)

def __call__(self, pure_val):
val = extract.from_tree(pure_val)
out = self.f(val)
return out


def _add_fake_index_mapping(tree: tp.Any):
def per_node_state(ns: extract.NodeStates | tp.Any):
global_index_mapping = {}
if not isinstance(ns, extract.NodeStates):
return ns
assert isinstance(ns._graphdef, graph.NodeDef)

def per_node_def(nd: graph.NodeDef | tp.Any):
if nd.index >= 0:
global_index_mapping[nd.index] = nd.index
for sub_nd in nd.subgraphs.values():
per_node_def(sub_nd)
for l in nd.leaves.values():
if isinstance(l, graph.NodeRef) and l.index >= 0:
global_index_mapping[l.index] = l.index
return

per_node_def(ns._graphdef)
return dataclasses.replace(ns, _graphdef=dataclasses.replace(
ns._graphdef,
index_mapping=FrozenDict(global_index_mapping)
))

return jax.tree.map(per_node_state, tree,
is_leaf=lambda x: isinstance(x, extract.NodeStates))


def _remove_index_mapping(tree: tp.Any):
'''Remove the fake index_mapping from the input to match that of the output.'''
def per_node_state(ns: extract.NodeStates | tp.Any):
if not isinstance(ns, extract.NodeStates):
return ns
assert isinstance(ns._graphdef, graph.NodeDef)
return dataclasses.replace(ns, _graphdef=dataclasses.replace(
ns._graphdef, index_mapping=None
))

return jax.tree.map(per_node_state, tree,
is_leaf=lambda x: isinstance(x, extract.NodeStates))


@dataclasses.dataclass(eq=False)
class WhileLoopBodyFn:
f: tp.Callable[..., tp.Any]

def __post_init__(self):
functools.update_wrapper(self, self.f)

@graph.update_context('while_loop_body')
def __call__(self, pure_val):
# Remove the dummy index mapping that was added outside of the body function.
pure_val_in = _remove_index_mapping(pure_val)

val = extract.from_tree(pure_val_in, ctxtag='while_loop_body')
out = self.f(val)
pure_out = extract.to_tree(out, ctxtag='while_loop_body')

try:
jax.tree.map(lambda a, b: None, pure_val, pure_out)
except ValueError as e:
msg = ("nnx.while_loop requires body function's input and output to "
"have the same reference and pytree structure, but they differ. "
"If the mismatch comes from `index_mapping` field, you might "
"have modified reference structure within the body function, "
"which is not allowed."
f"Detail of the mismatch: \n {str(e)}")
raise ValueError(msg)

return pure_out


@graph.update_context('while_loop')
def while_loop(cond_fun: tp.Callable[[T], tp.Any],
body_fun: tp.Callable[[T], T],
init_val: T) -> T:
"""NNX transform of `jax.lax.while_loop`.
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html
Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the reference structure of `init_val` inside `body_fun`.
Args:
cond_fun: a function for the continue condition of the while loop, taking a
single input of type `T` and outputting a boolean.
body_fun: a function that takes an input of type `T` and outputs a `T`.
Note that both data and modules of `T` must have the same reference
structure between inputs and outputs.
init_val: the initial input for cond_fun and body_fun. Must be of type `T`.
"""

pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')

# Add the expected reference mapping to `pure_init_val` so its pytree
# structure matches `body_fun`'s output, as `jax.lax.while_loop` requires
# matching input and output structures for the carry.
pure_init_val = _add_fake_index_mapping(pure_init_val)

pure_out = jax.lax.while_loop(
WhileLoopCondFn(cond_fun),
WhileLoopBodyFn(body_fun),
pure_init_val,
)
out = extract.from_tree(pure_out, ctxtag='while_loop')
return out
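
For orientation, a minimal usage sketch of the two transforms added above, mirroring the tests below. The layer sizes, branch functions, and loop bound here are illustrative assumptions and not part of the commit.

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
x = jnp.ones((4,))

# nnx.switch selects a branch by integer index; every branch receives the
# same (module, operands) arguments and must return outputs of the same type.
def identity_branch(m, v):
  return m(v)

def doubled_branch(m, v):
  return m(v) * 2

y = nnx.switch(1, (identity_branch, doubled_branch), model, x)  # runs doubled_branch

# nnx.while_loop threads a carry through body_fn while the condition holds;
# the carry's reference and pytree structure must stay the same across iterations.
def body_fn(carry):
  m, v, count = carry
  return m, m(v), count - 1.0

_, out, _ = nnx.while_loop(lambda carry: carry[-1] > 0, body_fn, (model, x, 3.0))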
156 changes: 155 additions & 1 deletion tests/nnx/transforms_test.py
@@ -1673,7 +1673,6 @@ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:

x = jnp.ones((16, 10, 20))
y = rnn_forward(cell, x)
print(y.shape)


class TestRemat(absltest.TestCase):
@@ -2612,6 +2611,161 @@ def no_nothing(env: Env):
)


class TestSwitch(absltest.TestCase):
def test_basic(self):
class RoundTable(nnx.Module):
def __init__(self):
self.next_index = 0
self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
self.linear.kernel.value = jnp.identity(10)
self.rounds_count = nnx.Variable(jnp.array(0))

def __call__(self, x):
def fn0(m, x):
m.rounds_count += 1
return m.linear(x)
def fn1(m, x):
return m.linear(x) * 2
def fn2(m, x):
m.linear.kernel.value = jnp.zeros((10, 10))
return m.linear(x)

# y = nnx.cond(self.next_index.value == 0, fn0, fn1, self, x)
y = nnx.switch(self.next_index, (fn0, fn1, fn2), self, x)
self.next_index = (self.next_index + 1) % 3
return y

model = RoundTable()
x = jnp.ones((10,))
np.testing.assert_array_equal(model(x), x)
assert model.rounds_count.value == 1
assert model.next_index == 1
np.testing.assert_array_equal(model(x), x * 2)
assert model.rounds_count.value == 1
assert model.next_index == 2
np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
assert model.rounds_count.value == 1
assert model.next_index == 0
np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
assert model.rounds_count.value == 2
assert model.next_index == 1


class TestWhileLoop(absltest.TestCase):
def test_basic(self):
def fwd_fn(input):
m, x, c = input
y = m(x)
return m, y, c - 1.0

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
module.kernel.value = jnp.identity(10) * 2
x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

_, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
np.testing.assert_array_equal(y, x * 8)

def test_multiple_objects(self):
def fwd_fn(input):
m1, (w2,), x, c = input
y = m1(x) @ w2
return m1, (w2,), y, c - 1.0

m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
m1.kernel.value = jnp.identity(10) * 2
w2 = nnx.Variable(jnp.identity(10) * 0.5)
x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

_, _, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (m1, (w2,), x, 3.0))
np.testing.assert_allclose(y, x)

def test_nested_module(self):
def fwd_fn(input):
m, x, c = input
y = m(x)
return m, y, c - 1.0

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
module.kernel.value = jnp.identity(10) * 2
module = nnx.Sequential(module)
x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

_, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
np.testing.assert_array_equal(y, x * 8)


def test_shared_module(self):
m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0))
m2.kernel = m1.kernel
module = nnx.Sequential(m1, m2)
self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params

def fwd_fn(input):
m, x, c = input
y = m(x)
m.layers[0].kernel.value = jnp.zeros_like(m.layers[0].kernel.value)
return m, y, c - 1.0

x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
_, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))
self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params
np.testing.assert_array_equal(m1.kernel.value, jnp.zeros((10, 10,)))
np.testing.assert_array_equal(m2.kernel.value, jnp.zeros((10, 10,)))
np.testing.assert_array_equal(y, jnp.zeros((10,)))


def test_value_changed(self):
def fwd_fn(input):
m, x, c = input
m.kernel.value = jnp.zeros_like(m.kernel.value)
y = m(x)
return m, y, c - 1.0

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

_, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
np.testing.assert_array_equal(module.kernel.value, jnp.zeros((10, 10,)))
np.testing.assert_array_equal(y, jnp.zeros((10,)))


def test_ref_changed(self):
def fwd_fn(input):
m, x, c = input
y = m(x)
m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value))
return m, y, c - 1.0

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

with self.assertRaises(ValueError):
_, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))


def test_structure_changed(self):
def fwd_fn(input):
m, x, c = input
m = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(1))
m.kernel.value = jnp.identity(10) * 2
y = m(x)
return m, y, c - 1.0

module = nnx.Linear(10, 10, use_bias=True, rngs=nnx.Rngs(0))
x = 1e1 * jax.random.normal(jax.random.key(0), (10,))

with self.assertRaises(ValueError):
_, y, _ = nnx.while_loop(
lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
class StatefulLinear(nnx.Linear):
