From 3c37d4f20e5f2f9370181263fd02952648c7aae0 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 17 Sep 2024 15:58:14 -0700
Subject: [PATCH] Improve documentation for jax.lax.stop_gradient

---
 docs/_tutorials/advanced-autodiff.md |  2 +-
 jax/_src/lax/lax.py                  | 49 +++++++++++++++++++++-------
 2 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md
index 180f65f5d492..da5cd0feaa1a 100644
--- a/docs/_tutorials/advanced-autodiff.md
+++ b/docs/_tutorials/advanced-autodiff.md
@@ -77,7 +77,7 @@ def meta_loss_fn(params, data):
 
 meta_grads = jax.grad(meta_loss_fn)(params, data)
 ```
-
+(stopping-gradients)=
 ### Stopping gradients
 
 Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 8d2c24d6e64c..c791c668e68b 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1372,18 +1372,43 @@ def stop_gradient(x: T) -> T:
   argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
   gradients during forward or reverse-mode automatic differentiation. If there
   are multiple nested gradient computations, ``stop_gradient`` stops gradients
-  for all of them.
-
-  For example:
-
-  >>> jax.grad(lambda x: x**2)(3.)
-  Array(6., dtype=float32, weak_type=True)
-  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
-  Array(0., dtype=float32, weak_type=True)
-  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
-  Array(2., dtype=float32, weak_type=True)
-  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
-  Array(0., dtype=float32, weak_type=True)
+  for all of them. For some discussion of where this is useful, refer to
+  :ref:`stopping-gradients`.
+
+  Args:
+    x: array or pytree of arrays
+
+  Returns:
+    input value is returned unchanged, but within autodiff will be treated as
+    a constant.
+
+  Examples:
+    Consider a simple function that returns the square of the input value:
+
+    >>> def f1(x):
+    ...   return x ** 2
+    >>> x = jnp.float32(3.0)
+    >>> f1(x)
+    Array(9.0, dtype=float32)
+    >>> jax.grad(f1)(x)
+    Array(6.0, dtype=float32)
+
+    The same function with ``stop_gradient`` around ``x`` will be equivalent
+    under normal evaluation, but return a zero gradient because ``x`` is
+    effectively treated as a constant:
+
+    >>> def f2(x):
+    ...   return jax.lax.stop_gradient(x) ** 2
+    >>> f2(x)
+    Array(9.0, dtype=float32)
+    >>> jax.grad(f2)(x)
+    Array(0.0, dtype=float32)
+
+    This is used in a number of places within the JAX codebase; for example
+    :func:`jax.nn.softmax` internally normalizes the input by its maximum
+    value, and this maximum value is wrapped in ``stop_gradient`` for
+    efficiency. Refer to :ref:`stopping-gradients` for more discussion of
+    the applicability of ``stop_gradient``.
   """
   def stop(x):
     # only bind primitive on inexact dtypes, to avoid some staging
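
As a rough illustration of the `jax.nn.softmax` use case mentioned in the new docstring text, the sketch below shows the general pattern of wrapping the normalizing maximum in `stop_gradient`. It is not the actual `jax.nn.softmax` source, and the name `stable_softmax` is hypothetical:

    import jax
    import jax.numpy as jnp

    def stable_softmax(x, axis=-1):
      # Subtract the max for numerical stability. Wrapping the max in
      # stop_gradient skips differentiating through jnp.max in the backward
      # pass; because softmax is invariant to subtracting a constant, the max
      # term contributes no net gradient anyway, so the result is unchanged.
      x_max = jax.lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True))
      unnormalized = jnp.exp(x - x_max)
      return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

    x = jnp.array([1.0, 2.0, 3.0])
    print(stable_softmax(x))                # same values as jax.nn.softmax(x)
    print(jax.jacobian(stable_softmax)(x))  # gradients still flow through x

Because subtracting a constant cancels out of softmax, stopping the gradient of the max changes nothing numerically while avoiding unnecessary work during differentiation, which is the efficiency point the docstring makes.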