diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index e5ce20f8e3..0bb867d491 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -169,12 +169,14 @@ def jit( abstracted_axes: tp.Optional[tp.Any] = None, ) -> F | tp.Callable[[F], F]: """ - Lifted version of ``jax.jit`` that can handle Modules / graph nodes as + A "lifted" version of ``jax.jit`` that can handle ``nnx.Modules`` / graph nodes as arguments. + Learn more in `Flax NNX vs JAX Transformations `_. + Args: - fun: Function to be jitted. ``fun`` should be a pure function, as - side-effects may only be executed once. + fun: A function to be `JIT-compiled `_. + ``fun`` should be a pure function, as side-effects may only be executed once. The arguments and return value of ``fun`` should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. @@ -186,15 +188,15 @@ def jit( JAX keeps a weak reference to ``fun`` for use as a compilation cache key, so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` objects will already satisfy this requirement. - in_shardings: Pytree of structure matching that of arguments to ``fun``, + in_shardings: A JAX pytree of structure matching that of arguments to ``fun``, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree. The ``in_shardings`` argument is optional. JAX will infer the shardings - from the input :py:class:`jax.Array`'s and defaults to replicating the input - if the sharding cannot be inferred. + from the input `jax.Arrays `_ + and defaults to replicating the input if the sharding cannot be inferred. The valid resource assignment specifications are: - :py:class:`Sharding`, which will decide how the value @@ -208,12 +210,14 @@ def jit( determine the output shardings. The size of every dimension has to be a multiple of the total number of - resources assigned to it. This is similar to pjit's in_shardings. - out_shardings: Like ``in_shardings``, but specifies resource - assignment for function outputs. This is similar to pjit's - out_shardings. - - The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` + resources assigned to it. This is similar to `pjit `_ + ``in_shardings``. + out_shardings: Similar to ``in_shardings``, but specifies resource + assignment for function outputs. This is similar to JAX ``pjit`` + ``out_shardings``. + + The ``out_shardings`` argument is optional. If not specified, + `jax.jit `_ will use GSPMD's sharding propagation to figure out what the sharding of the output(s) should be. static_argnums: An optional int or collection of ints that specify which @@ -223,7 +227,7 @@ def jit( any Python object. Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Calling the jitted function + ``__eq__`` are implemented, and immutable. Calling the JIT-compiled function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static. @@ -262,18 +266,18 @@ def jit( be donated. For more details on buffer donation see the - `FAQ `_. + `JAX FAQ `_. donate_argnames: An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not provided but ``donate_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. - keep_unused: If `False` (the default), arguments that JAX determines to be - unused by `fun` *may* be dropped from resulting compiled XLA executables. + keep_unused: If ``False`` (the default), arguments that JAX determines to be + unused by ``fun`` *may* be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If `True`, unused arguments will not be pruned. device: This is an experimental feature and the API is likely to change. - Optional, the Device the jitted function will run on. (Available devices + Optional, the Device the JIT-compiled function will run on. (Available devices can be retrieved via :py:func:`jax.devices`.) The default is inherited from XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``. @@ -282,7 +286,7 @@ def jit( ``'tpu'``. inline: Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call - primitive with its own subjaxpr). Default False. + primitive with its own subjaxpr). Default ``False``. Returns: A wrapped version of ``fun``, set up for just-in-time compilation.