From a57cabeef0516380fb0a7bd37b1c9af1b1024dba Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:23:49 +0000 Subject: [PATCH 1/2] Lint flax.nnx.jit docstring --- flax/nnx/transforms/compilation.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index e5ce20f8e3..9febe4d38f 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -169,11 +169,13 @@ def jit( abstracted_axes: tp.Optional[tp.Any] = None, ) -> F | tp.Callable[[F], F]: """ - Lifted version of ``jax.jit`` that can handle Modules / graph nodes as + A "lifted" version of ``jax.jit`` that can handle ``nnx.Modules`` / graph nodes as arguments. + Learn more in `Flax NNX vs JAX Transformations `_. + Args: - fun: Function to be jitted. ``fun`` should be a pure function, as + fun: A function to be JIT-compiled. ``fun`` should be a pure function, as side-effects may only be executed once. The arguments and return value of ``fun`` should be arrays, @@ -186,7 +188,7 @@ def jit( JAX keeps a weak reference to ``fun`` for use as a compilation cache key, so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` objects will already satisfy this requirement. - in_shardings: Pytree of structure matching that of arguments to ``fun``, + in_shardings: A JAX pytree of structure matching that of arguments to ``fun``, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in @@ -209,9 +211,9 @@ def jit( The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit's in_shardings. - out_shardings: Like ``in_shardings``, but specifies resource - assignment for function outputs. This is similar to pjit's - out_shardings. + out_shardings: Similar to ``in_shardings``, but specifies resource + assignment for function outputs. This is similar to JAX ``pjit`` + ``out_shardings``. The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` will use GSPMD's sharding propagation to figure out what the sharding of the @@ -223,7 +225,7 @@ def jit( any Python object. Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Calling the jitted function + ``__eq__`` are implemented, and immutable. Calling the JIT-compiled function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static. @@ -262,18 +264,18 @@ def jit( be donated. For more details on buffer donation see the - `FAQ `_. + `JAX FAQ `_. donate_argnames: An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not provided but ``donate_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. - keep_unused: If `False` (the default), arguments that JAX determines to be - unused by `fun` *may* be dropped from resulting compiled XLA executables. + keep_unused: If ``False`` (the default), arguments that JAX determines to be + unused by ``fun`` *may* be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If `True`, unused arguments will not be pruned. device: This is an experimental feature and the API is likely to change. - Optional, the Device the jitted function will run on. (Available devices + Optional, the Device the JIT-compiled function will run on. (Available devices can be retrieved via :py:func:`jax.devices`.) The default is inherited from XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``. @@ -282,7 +284,7 @@ def jit( ``'tpu'``. inline: Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call - primitive with its own subjaxpr). Default False. + primitive with its own subjaxpr). Default ``False``. Returns: A wrapped version of ``fun``, set up for just-in-time compilation. From 12e74a51d81c5f638e7ba4b74160c45f00874722 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:02:33 +0000 Subject: [PATCH 2/2] Update flax.nnx.jit docstring --- flax/nnx/transforms/compilation.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 9febe4d38f..0bb867d491 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -175,8 +175,8 @@ def jit( Learn more in `Flax NNX vs JAX Transformations `_. Args: - fun: A function to be JIT-compiled. ``fun`` should be a pure function, as - side-effects may only be executed once. + fun: A function to be `JIT-compiled `_. + ``fun`` should be a pure function, as side-effects may only be executed once. The arguments and return value of ``fun`` should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. @@ -195,8 +195,8 @@ def jit( that subtree. The ``in_shardings`` argument is optional. JAX will infer the shardings - from the input :py:class:`jax.Array`'s and defaults to replicating the input - if the sharding cannot be inferred. + from the input `jax.Arrays `_ + and defaults to replicating the input if the sharding cannot be inferred. The valid resource assignment specifications are: - :py:class:`Sharding`, which will decide how the value @@ -210,12 +210,14 @@ def jit( determine the output shardings. The size of every dimension has to be a multiple of the total number of - resources assigned to it. This is similar to pjit's in_shardings. + resources assigned to it. This is similar to `pjit `_ + ``in_shardings``. out_shardings: Similar to ``in_shardings``, but specifies resource assignment for function outputs. This is similar to JAX ``pjit`` ``out_shardings``. - The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` + The ``out_shardings`` argument is optional. If not specified, + `jax.jit `_ will use GSPMD's sharding propagation to figure out what the sharding of the output(s) should be. static_argnums: An optional int or collection of ints that specify which