Skip to content

Commit

Permalink
Improve documentation for jax.lax.stop_gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 17, 2024
1 parent e92a599 commit 3c37d4f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/_tutorials/advanced-autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def meta_loss_fn(params, data):
meta_grads = jax.grad(meta_loss_fn)(params, data)
```


(stopping-gradients)=
### Stopping gradients

Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
Expand Down
49 changes: 37 additions & 12 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,18 +1372,43 @@ def stop_gradient(x: T) -> T:
argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
gradients during forward or reverse-mode automatic differentiation. If there
are multiple nested gradient computations, ``stop_gradient`` stops gradients
for all of them.
For example:
>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
for all of them. For some discussion of where this is useful, refer to
:ref:`stopping-gradients`.
Args:
x: array or pytree of arrays
Returns:
input value is returned unchanged, but within autodiff will be treated as
a constant.
Examples:
Consider a simple function that returns the square of the input value:
>>> def f1(x):
... return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)
The same function with ``stop_gradient`` around ``x`` will be equivalent
under normal evaluation, but return a zero gradient because ``x`` is
effectively treated as a constant:
>>> def f2(x):
... return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)
This is used in a number of places within the JAX codebase; for example
:func:`jax.nn.softmax` internally normalizes the input by its maximum
value, and this maximum value is wrapped in ``stop_gradient`` for
efficiency. Refer to :ref:`stopping-gradients` for more discussion of
the applicability of ``stop_gradient``.
"""
def stop(x):
# only bind primitive on inexact dtypes, to avoid some staging
Expand Down

0 comments on commit 3c37d4f

Please sign in to comment.