Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lint flax.nnx.while_loop docstring #4371

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,10 +1406,10 @@ def __call__(self, pure_val):
def while_loop(cond_fun: tp.Callable[[T], tp.Any],
body_fun: tp.Callable[[T], T],
init_val: T) -> T:
"""NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.
"""A Flax NNX transformation of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.

Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the variable reference structure of `init_val` inside `body_fun`.
change the variable reference structure of ``init_val`` inside ``body_fun``.

Example::

Expand All @@ -1427,12 +1427,12 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],


Args:
cond_fun: a function for the continue condition of the while loop, taking a
single input of type `T` and outputting a boolean.
body_fun: a function that takes an input of type `T` and outputs an `T`.
Note that both data and modules of `T` must have the same reference
cond_fun: A function for the continue condition of the while loop, taking a
single input of type ``T`` and outputting a boolean.
body_fun: A function that takes an input of type ``T`` and outputs an ``T``.
Note that both data and modules of ``T`` must have the same reference
structure between inputs and outputs.
init_val: the initial input for cond_fun and body_fun. Must be of type `T`.
init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of type ``T``.

"""

Expand Down Expand Up @@ -1537,4 +1537,4 @@ def fori_loop(lower: int, upper: int,
ForiLoopBodyFn(body_fun), pure_init_val,
unroll=unroll)
out = extract.from_tree(pure_out, ctxtag='fori_loop')
return out
return out
Loading