Skip to content

Commit

Permalink
Merge pull request #4353 from IvyZX:conds
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691978691
  • Loading branch information
Flax Authors committed Oct 31, 2024
2 parents f740ab3 + 6ad09a3 commit 591cd40
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs_nnx/api_reference/flax.nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ transforms
.. autofunction:: cond
.. autofunction:: switch
.. autofunction:: while_loop
.. autofunction:: fori_loop
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
from .transforms.transforms import cond as cond
from .transforms.transforms import switch as switch
from .transforms.iteration import while_loop as while_loop
from .transforms.iteration import fori_loop as fori_loop
from .transforms.iteration import StateAxes as StateAxes
from .variablelib import A as A
from .variablelib import BatchStat as BatchStat
Expand Down
91 changes: 90 additions & 1 deletion flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,7 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],
"""NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.
Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the reference structure of `init_val` inside `body_fun`.
change the variable reference structure of `init_val` inside `body_fun`.
Example::
Expand Down Expand Up @@ -1448,4 +1448,93 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],
pure_init_val,
)
out = extract.from_tree(pure_out, ctxtag='while_loop')
return out


@dataclasses.dataclass(eq=False)
class ForiLoopBodyFn:
f: tp.Callable[..., tp.Any]

def __post_init__(self):
functools.update_wrapper(self, self.f)

@graph.update_context('fori_loop_body')
def __call__(self, i, pure_val):
# Removing the dummy index mapping being added outside of body function.
pure_val_in = _remove_index_mapping(pure_val)

val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body')
out = self.f(i, val)
pure_out = extract.to_tree(out, ctxtag='fori_loop_body')

try:
jax.tree.map(lambda a, b: None, pure_val, pure_out)
except ValueError as e:
msg = ("nnx.fori_loop requires body function's input and output to "
"have the same reference and pytree structure, but they differ. "
"If the mismatch comes from `index_mapping` field, you might "
"have modified reference structure within the body function, "
"which is not allowed."
f"Detail of the mismatch: \n {str(e)}")
raise ValueError(msg)

return pure_out


@graph.update_context('fori_loop')
def fori_loop(lower: int, upper: int,
body_fun: tp.Callable[[int, T], T],
init_val: T,
*,
unroll: int | bool | None = None) -> T:
"""NNX transform of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.
Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the variable reference structure of `init_val` inside `body_fun`.
Example::
>>> import jax
>>> from flax import nnx
>>> def fwd_fn(i, input):
... m, x = input
... m.kernel.value = jnp.identity(10) * i
... return m, m(x)
>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
>>> np.testing.assert_array_equal(y, x * 2 * 3)
Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
body_fun: a function that takes an input of type `T` and outputs an `T`.
Note that both data and modules of `T` must have the same reference
structure between inputs and outputs.
init_val: the initial input for body_fun. Must be of type `T`.
unroll: An optional integer or boolean that determines how much to unroll
the loop. If an integer is provided, it determines how many unrolled
loop iterations to run within a single rolled iteration of the loop. If a
boolean is provided, it will determine if the loop is competely unrolled
(i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`).
This argument is only applicable if the loop bounds are statically known.
Returns:
Loop value from the final iteration, of type ``T``.
"""

pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop')

# Adding the expected reference mapping to `pure_init_val` to match
# `body_fun`'s output pytree structure, to make JAX happy.
pure_init_val = _add_fake_index_mapping(pure_init_val)

pure_out = jax.lax.fori_loop(lower, upper,
ForiLoopBodyFn(body_fun), pure_init_val,
unroll=unroll)
out = extract.from_tree(pure_out, ctxtag='fori_loop')
return out
13 changes: 13 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2922,6 +2922,19 @@ def body_fn(val):
(0, m, m),
)

def test_fori_loop_basic(self):
def fwd_fn(i, input):
m, x = input
m.kernel.value = jnp.identity(10) * i
return m, m(x)

module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (10,))

_, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
np.testing.assert_array_equal(y, x * 2 * 3)


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
class StatefulLinear(nnx.Linear):
Expand Down

0 comments on commit 591cd40

Please sign in to comment.