Update flax.nnx.jit docstring #4369

Open: wants to merge 2 commits into base: main

40 changes: 22 additions & 18 deletions flax/nnx/transforms/compilation.py
@@ -169,12 +169,14 @@ def jit(
 abstracted_axes: tp.Optional[tp.Any] = None,
 ) -> F | tp.Callable[[F], F]:
 """
-Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
+A "lifted" version of ``jax.jit`` that can handle ``nnx.Modules`` / graph nodes as
 arguments.
 
+Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
+
 Args:
-fun: Function to be jitted. ``fun`` should be a pure function, as
-side-effects may only be executed once.
+fun: A function to be `JIT-compiled <https://jax.readthedocs.io/en/latest/jit-compilation.html>`_.
+``fun`` should be a pure function, as side-effects may only be executed once.
 
 The arguments and return value of ``fun`` should be arrays,
 scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
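
Note on the purity requirement in the `fun` entry above: a minimal sketch of why side effects "may only be executed once". It uses plain `jax.jit`, on the assumption that `nnx.jit` traces `fun` the same way. The names `trace_log` and `double` are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical log used only to observe when the Python body actually runs.
trace_log = []

@jax.jit
def double(x):
    # Python side effect: runs while JAX traces the function,
    # not on every call of the compiled executable.
    trace_log.append("traced")
    return x * 2

first = double(jnp.arange(3))
# Same shape and dtype as before, so the cached executable is reused
# and the Python body is not run again.
second = double(jnp.arange(3))
```

After both calls, `trace_log` holds a single entry, which is why anything that must happen on every call should not live in the Python body of a jitted function.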
@@ -186,15 +188,15 @@ def jit(
 JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
 so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
 objects will already satisfy this requirement.
-in_shardings: Pytree of structure matching that of arguments to ``fun``,
+in_shardings: A JAX pytree of structure matching that of arguments to ``fun``,
 with all actual arguments replaced by resource assignment specifications.
 It is also valid to specify a pytree prefix (e.g. one value in place of a
 whole subtree), in which case the leaves get broadcast to all values in
 that subtree.
 
 The ``in_shardings`` argument is optional. JAX will infer the shardings
-from the input :py:class:`jax.Array`'s and defaults to replicating the input
-if the sharding cannot be inferred.
+from the input `jax.Arrays <https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array>`_
+and defaults to replicating the input if the sharding cannot be inferred.
 
 The valid resource assignment specifications are:
 - :py:class:`Sharding`, which will decide how the value
@@ -208,12 +210,14 @@ def jit(
 determine the output shardings.
 
 The size of every dimension has to be a multiple of the total number of
-resources assigned to it. This is similar to pjit's in_shardings.
-out_shardings: Like ``in_shardings``, but specifies resource
-assignment for function outputs. This is similar to pjit's
-out_shardings.
-
-The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
+resources assigned to it. This is similar to `pjit <https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html>`_
+``in_shardings``.
+out_shardings: Similar to ``in_shardings``, but specifies resource
+assignment for function outputs. This is similar to JAX ``pjit``
+``out_shardings``.
+
+The ``out_shardings`` argument is optional. If not specified,
+`jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`_
 will use GSPMD's sharding propagation to figure out what the sharding of the
 output(s) should be.
 static_argnums: An optional int or collection of ints that specify which
@@ -223,7 +227,7 @@ def jit(
 any Python object.
 
 Static arguments should be hashable, meaning both ``__hash__`` and
-``__eq__`` are implemented, and immutable. Calling the jitted function
+``__eq__`` are implemented, and immutable. Calling the JIT-compiled function
 with different values for these constants will trigger recompilation.
 Arguments that are not arrays or containers thereof must be marked as
 static.
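
Note on the `static_argnums` text above: a small sketch of the recompilation behaviour it describes. It again uses plain `jax.jit` and assumes `nnx.jit` forwards `static_argnums` with the same semantics. `seen_factors` and `scale` are hypothetical names for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical record of the static values seen at trace time.
seen_factors = []

@partial(jax.jit, static_argnums=1)
def scale(x, factor):
    # `factor` is a plain Python int here because argument 1 is static.
    seen_factors.append(factor)
    return x * factor

x = jnp.ones(4)
scale(x, 2)           # first call: traces and compiles for factor=2
scale(x, 2)           # same static value: reuses the compiled function
tripled = scale(x, 3) # new static value: triggers a fresh trace
```

Each distinct static value gets its own compiled version, so arguments that vary on every call are usually poor candidates for `static_argnums`.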
@@ -262,18 +266,18 @@ def jit(
 be donated.
 
 For more details on buffer donation see the
-`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
+`JAX FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
 donate_argnames: An optional string or collection of strings specifying
 which named arguments are donated to the computation. See the
 comment on ``donate_argnums`` for details. If not
 provided but ``donate_argnums`` is set, the default is based on calling
 ``inspect.signature(fun)`` to find corresponding named arguments.
-keep_unused: If `False` (the default), arguments that JAX determines to be
-unused by `fun` *may* be dropped from resulting compiled XLA executables.
+keep_unused: If ``False`` (the default), arguments that JAX determines to be
+unused by ``fun`` *may* be dropped from resulting compiled XLA executables.
 Such arguments will not be transferred to the device nor provided to the
 underlying executable. If `True`, unused arguments will not be pruned.
 device: This is an experimental feature and the API is likely to change.
-Optional, the Device the jitted function will run on. (Available devices
+Optional, the Device the JIT-compiled function will run on. (Available devices
 can be retrieved via :py:func:`jax.devices`.) The default is inherited
 from XLA's DeviceAssignment logic and is usually to use
 ``jax.devices()[0]``.
@@ -282,7 +286,7 @@ def jit(
 ``'tpu'``.
 inline: Specify whether this function should be inlined into enclosing
 jaxprs (rather than being represented as an application of the xla_call
-primitive with its own subjaxpr). Default False.
+primitive with its own subjaxpr). Default ``False``.
 
 Returns:
 A wrapped version of ``fun``, set up for just-in-time compilation.