Skip to content

Commit

Permalink
Merge pull request #23298 from froystig:docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 670052227
  • Loading branch information
jax authors committed Sep 2, 2024
2 parents ebd5948 + 180573a commit c092cbf
Show file tree
Hide file tree
Showing 21 changed files with 43 additions and 49 deletions.
10 changes: 5 additions & 5 deletions docs/debugging.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ This section introduces you to a set of built-in JAX debugging methods — {func

Let's begin with {func}`jax.debug.print`.

## JAX `debug.print` for high-level
## `jax.debug.print` for simple inspection

**TL;DR** Here is a rule of thumb:
Here is a rule of thumb:

- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python {func}`print` for static values, such as dtypes and array shapes.
Expand Down Expand Up @@ -111,9 +111,9 @@ f(1, 2)
To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.


## JAX `debug.breakpoint` for `pdb`-like debugging
## `jax.debug.breakpoint` for `pdb`-like debugging

**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.
**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack.

Expand Down Expand Up @@ -160,7 +160,7 @@ f(2., 1.) # ==> No breakpoint
f(2., 0.) # ==> Pauses during execution
```

## JAX `debug.callback` for more control during debugging
## `jax.debug.callback` for more control during debugging

Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using
the more flexible {func}`jax.debug.callback`, which gives greater control over the
Expand Down
2 changes: 1 addition & 1 deletion docs/debugging/checkify_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<!--* freshness: { reviewed: '2023-02-28' } *-->

**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:

```python
from jax.experimental import checkify
Expand Down
4 changes: 2 additions & 2 deletions docs/debugging/flags.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily.

## `jax_debug_nans` configuration option and context manager

**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code).

`jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN.

Expand Down Expand Up @@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!

## `jax_disable_jit` configuration option and context manager

**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`
**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`

`jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`).

Expand Down
26 changes: 16 additions & 10 deletions docs/debugging/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

<!--* freshness: { reviewed: '2024-04-11' } *-->

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more.
Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more.

Table of contents:

* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with jax.experimental.checkify](checkify_guide)
* [Throwing Python errors with JAX’s debug flags](flags)

## [Interactive inspection with `jax.debug`](print_breakpoint)
## Interactive inspection with `jax.debug`

**TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
Complete guide [here](print_breakpoint)

**Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions,
and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:

```python
Expand All @@ -34,11 +36,13 @@ Table of contents:
# 🤯 0.9092974662780762 🤯
```

Click [here](print_breakpoint) to learn more!
[Read more](print_breakpoint).

## Functional error checks with `jax.experimental.checkify`

## [Functional error checks with `jax.experimental.checkify`](checkify_guide)
Complete guide [here](checkify_guide)

**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:

```python
from jax.experimental import checkify
Expand Down Expand Up @@ -77,11 +81,13 @@ Click [here](print_breakpoint) to learn more!
# ValueError: nan generated by primitive sin at <...>:8 (f)
```

Click [here](checkify_guide) to learn more!
[Read more](checkify_guide).

## Throwing Python errors with JAX's debug flags

## [Throwing Python errors with JAX's debug flags](flags)
Complete guide [here](flags)

**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.

```python
import jax
Expand All @@ -92,7 +98,7 @@ def f(x, y):
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
```

Click [here](flags) to learn more!
[Read more](flags).

```{toctree}
:caption: Read more
Expand Down
10 changes: 5 additions & 5 deletions docs/debugging/print_breakpoint.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# `jax.debug.print` and `jax.debug.breakpoint`
# Compiled prints and breakpoints

<!--* freshness: { reviewed: '2024-03-13' } *-->

The {mod}`jax.debug` package offers some useful tools for inspecting values
inside of JIT-ted functions.
inside of compiled functions.

## Debugging with `jax.debug.print` and other debugging callbacks

**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions:
**Summary:** Use {func}`jax.debug.print` to print traced array values to
stdout in compiled (e.g. `jax.jit` or `jax.pmap`-decorated) functions:

```python
import jax
Expand All @@ -26,7 +27,6 @@ f(2.)
# 🤯 0.9092974662780762 🤯
```

<!-- mattjj added this line -->
With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead!

Semantically, `jax.debug.print` is roughly equivalent to the following Python function
Expand Down Expand Up @@ -236,7 +236,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat

## Interactive inspection with `jax.debug.breakpoint()`

**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:

```python
@jax.jit
Expand Down
2 changes: 1 addition & 1 deletion docs/distributed_data_loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ kernelspec:
name: python3
---

# Distributed data loading in multi-host / multi-process environments
# Distributed data loading

<!--* freshness: { reviewed: '2024-05-16' } *-->

Expand Down
2 changes: 1 addition & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and
cross-platform, and `jaxlib` which contains compiled binaries, and requires
different builds for different operating systems and accelerators.

**TL;DR** For most users, a typical JAX installation may look something like this:
**Summary:** For most users, a typical JAX installation may look something like this:

* **CPU-only (Linux/macOS/Windows)**
```
Expand Down
2 changes: 1 addition & 1 deletion docs/investigating_a_regression.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Here is a suggested investigation strategy:

## Nightly investigation

This can be done by using [JAX-Toolbox nightly
This can be done by using the [NVIDIA JAX-Toolbox nightly
containers](https://github.com/NVIDIA/JAX-Toolbox).

- Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days.
Expand Down
2 changes: 1 addition & 1 deletion docs/multi_process.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Using JAX in multi-host and multi-process environments
# Multi-host and multi-process environments

<!--* freshness: { reviewed: '2024-06-10' } *-->

Expand Down
2 changes: 0 additions & 2 deletions docs/notebooks/Common_Gotchas_in_JAX.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
"id": "4k5PVzEo2uJO"
},
"source": [
"*levskaya@ mattjj@*\n",
"\n",
"When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n",
"\n",
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n",
Expand Down
2 changes: 0 additions & 2 deletions docs/notebooks/Common_Gotchas_in_JAX.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ kernelspec:

+++ {"id": "4k5PVzEo2uJO"}

*levskaya@ mattjj@*

When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_"una anima di pura programmazione funzionale"_](https://www.sscardapane.it/iaml-backup/jax-intro/).

__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).
Expand Down
6 changes: 2 additions & 4 deletions docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
"id": "LqiaKasFjH82"
},
"source": [
"# Custom derivative rules for JAX-transformable Python functions\n",
"# Custom derivative rules\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n",
"\n",
"There are two ways to define differentiation rules in JAX:\n",
"\n",
"1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n",
Expand All @@ -30,7 +28,7 @@
"id": "9Fg3NFNY-2RY"
},
"source": [
"## TL;DR"
"## Summary"
]
},
{
Expand Down
6 changes: 2 additions & 4 deletions docs/notebooks/Custom_derivative_rules_for_Python_code.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ kernelspec:

+++ {"id": "LqiaKasFjH82"}

# Custom derivative rules for JAX-transformable Python functions
# Custom derivative rules

<!--* freshness: { reviewed: '2024-04-08' } *-->

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)

*mattjj@ Mar 19 2020, last updated Oct 14 2020*

There are two ways to define differentiation rules in JAX:

1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
Expand All @@ -32,7 +30,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co

+++ {"id": "9Fg3NFNY-2RY"}

## TL;DR
## Summary

+++ {"id": "ZgMNRtXyWIW8"}

Expand Down
2 changes: 0 additions & 2 deletions docs/notebooks/autodiff_cookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n",
"*alexbw@, mattjj@* \n",
"\n",
"JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics."
]
},
Expand Down
2 changes: 0 additions & 2 deletions docs/notebooks/autodiff_cookbook.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ kernelspec:

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)

*alexbw@, mattjj@*

JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.

```{code-cell} ipython3
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/autodiff_remat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"id": "qaIsQSh1XoKF"
},
"source": [
"### TL;DR\n",
"### Summary\n",
"\n",
"Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/autodiff_remat.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import jax.numpy as jnp

+++ {"id": "qaIsQSh1XoKF"}

### TL;DR
### Summary

Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.

Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/shard_map.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "41a7e222",
"metadata": {},
"source": [
"# SPMD multi-device parallelism with `shard_map`\n",
"# Manual parallelism with `shard_map`\n",
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/shard_map.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ kernelspec:
name: python3
---

# SPMD multi-device parallelism with `shard_map`
# Manual parallelism with `shard_map`

<!--* freshness: { reviewed: '2024-04-08' } *-->

Expand Down
2 changes: 1 addition & 1 deletion docs/sharded-computation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"metadata": {},
"source": [
"(sharded-computation)=\n",
"# Introduction to sharded computation\n",
"# Introduction to parallel programming\n",
"\n",
"<!--* freshness: { reviewed: '2024-05-10' } *-->\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/sharded-computation.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ kernelspec:
---

(sharded-computation)=
# Introduction to sharded computation
# Introduction to parallel programming

<!--* freshness: { reviewed: '2024-05-10' } *-->

Expand Down

0 comments on commit c092cbf

Please sign in to comment.