-
Notifications
You must be signed in to change notification settings - Fork 645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add NNX transforms nnx.while_loop
and nnx.switch
#4343
Conversation
c0e6ba2
to
338bfb8
Compare
flax/nnx/transforms/transforms.py
Outdated
global_index_mapping = {} | ||
if not isinstance(ns, extract.NodeStates): | ||
return ns | ||
assert isinstance(ns._graphdef, graph.NodeDef) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you have a repeated NNX object in the pytree processed by to_tree
you will get a NodeRef
here. I tested this example that fails (can we add it to the test cases?).
def test_repeated_object(self):
m = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
def body_fn(val):
count, m, _ = val
return count + 1, m, m
count, m, _ = nnx.while_loop(
lambda val: val[0] < 2,
body_fn,
(0, m, m),
)
To fix it you can probably just return.
assert isinstance(ns._graphdef, graph.NodeDef) | |
if not isinstance(ns._graphdef, graph.NodeDef): | |
return ns |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch, thanks!
flax/nnx/transforms/transforms.py
Outdated
def per_node_state(ns: extract.NodeStates | tp.Any): | ||
if not isinstance(ns, extract.NodeStates): | ||
return ns | ||
assert isinstance(ns._graphdef, graph.NodeDef) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as before.
flax/nnx/transforms/transforms.py
Outdated
init_val: T) -> T: | ||
"""NNX transform of `jax.lax.while_loop`. | ||
|
||
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More friendly link.
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html | |
See `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or maybe we link be above mention?
flax/nnx/transforms/transforms.py
Outdated
"""NNX transform of `jax.lax.while_loop`. | ||
|
||
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a simple example usage here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
BTW: Can we move |
Implemented
nnx.switch
similar tonnx.cond
Implemented
nnx.while_loop
jax.lax.while_loop
, no reference structure change is allowed insidennx.while_loop
for NNX objects.