From 672a013b3a90782a7ebc7c5646fb9b44d8183e3a Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 10:46:47 -0700 Subject: [PATCH 1/5] docs: prefer "summary" to "tl;dr" This was common and we typically just mean "summary." --- docs/debugging.md | 4 ++-- docs/debugging/checkify_guide.md | 2 +- docs/debugging/flags.md | 4 ++-- docs/debugging/index.md | 8 ++++---- docs/debugging/print_breakpoint.md | 5 +++-- docs/installation.md | 2 +- .../Custom_derivative_rules_for_Python_code.ipynb | 2 +- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 +- docs/notebooks/autodiff_remat.ipynb | 2 +- docs/notebooks/autodiff_remat.md | 2 +- 10 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 94384035ca9d..1e8501f99e39 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -23,7 +23,7 @@ Let's begin with {func}`jax.debug.print`. ## JAX `debug.print` for high-level -**TL;DR** Here is a rule of thumb: +Here is a rule of thumb: - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. - Use Python {func}`print` for static values, such as dtypes and array shapes. @@ -113,7 +113,7 @@ To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`a ## JAX `debug.breakpoint` for `pdb`-like debugging -**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. +**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index 2dad9b863b06..8b012e97ef28 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -2,7 +2,7 @@ -**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: +**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 1cf1829e5152..13e34a6c3ac4 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). `jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN. @@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ## `jax_disable_jit` configuration option and context manager -**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` +**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` `jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`). diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 724827f837e3..46523d681512 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -2,7 +2,7 @@ -Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. +Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more. Table of contents: @@ -12,7 +12,7 @@ Table of contents: ## [Interactive inspection with `jax.debug`](print_breakpoint) - **TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, + **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack: ```python @@ -38,7 +38,7 @@ Click [here](print_breakpoint) to learn more! ## [Functional error checks with `jax.experimental.checkify`](checkify_guide) - **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: + **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify @@ -81,7 +81,7 @@ Click [here](checkify_guide) to learn more! ## [Throwing Python errors with JAX's debug flags](flags) -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python import jax diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index d7cb68bd1b0b..d33498697bb3 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -7,7 +7,8 @@ inside of JIT-ted functions. ## Debugging with `jax.debug.print` and other debugging callbacks -**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions: +**Summary:** Use {func}`jax.debug.print` to print traced array values to +stdout in compiled (e.g. `jax.jit` or `jax.pmap`-decorated) functions: ```python import jax @@ -236,7 +237,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat ## Interactive inspection with `jax.debug.breakpoint()` -**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: +**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: ```python @jax.jit diff --git a/docs/installation.md b/docs/installation.md index bd0473d89201..4a831750e42b 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib` which contains compiled binaries, and requires different builds for different operating systems and accelerators. -**TL;DR** For most users, a typical JAX installation may look something like this: +**Summary:** For most users, a typical JAX installation may look something like this: * **CPU-only (Linux/macOS/Windows)** ``` diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index ec85f6e63159..6767b33a20f0 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -30,7 +30,7 @@ "id": "9Fg3NFNY-2RY" }, "source": [ - "## TL;DR" + "## Summary" ] }, { diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 6c948650fcc1..000d48c49b18 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -32,7 +32,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co +++ {"id": "9Fg3NFNY-2RY"} -## TL;DR +## Summary +++ {"id": "ZgMNRtXyWIW8"} diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index 041cf65314f2..82381838a5aa 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -27,7 +27,7 @@ "id": "qaIsQSh1XoKF" }, "source": [ - "### TL;DR\n", + "### Summary\n", "\n", "Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n", "\n", diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index a4fb27c58128..0a6c84b2d88f 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -24,7 +24,7 @@ import jax.numpy as jnp +++ {"id": "qaIsQSh1XoKF"} -### TL;DR +### Summary Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs. From 763d600508a8307a381f5d1f75efae3e75ffd43b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 10:59:05 -0700 Subject: [PATCH 2/5] docs: remove inline authors/dates on misc doc pages Git history covers this better and automatically. --- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 2 -- docs/notebooks/Common_Gotchas_in_JAX.md | 2 -- docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb | 2 -- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 -- docs/notebooks/autodiff_cookbook.ipynb | 2 -- docs/notebooks/autodiff_cookbook.md | 2 -- 6 files changed, 12 deletions(-) diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index c143b520af51..0cffc22f1e8d 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -19,8 +19,6 @@ "id": "4k5PVzEo2uJO" }, "source": [ - "*levskaya@ mattjj@*\n", - "\n", "When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n", "\n", "__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 3324fdb53bcd..543d9ecb1558 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -22,8 +22,6 @@ kernelspec: +++ {"id": "4k5PVzEo2uJO"} -*levskaya@ mattjj@* - When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_"una anima di pura programmazione funzionale"_](https://www.sscardapane.it/iaml-backup/jax-intro/). __JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU). diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 6767b33a20f0..88d446723dba 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -12,8 +12,6 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", - "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", - "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 000d48c49b18..fdf4b3ed0c8d 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -19,8 +19,6 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) -*mattjj@ Mar 19 2020, last updated Oct 14 2020* - There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 478d84935b7f..3f2f0fd5650d 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -12,8 +12,6 @@ "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", - "*alexbw@, mattjj@* \n", - "\n", "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." ] }, diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 6615e65352d7..496d676f794a 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -20,8 +20,6 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) -*alexbw@, mattjj@* - JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics. ```{code-cell} ipython3 From 42de34263fd2f3bd9976d193f0a9dd892596202b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 11:01:43 -0700 Subject: [PATCH 3/5] docs: runtime debugging tweaks Mainly make titles/headings easier to read, by swapping code for words and not using headings as links. --- docs/debugging/index.md | 18 ++++++++++++------ docs/debugging/print_breakpoint.md | 5 ++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 46523d681512..bcf561d06807 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -10,7 +10,9 @@ Table of contents: * [Functional error checks with jax.experimental.checkify](checkify_guide) * [Throwing Python errors with JAX’s debug flags](flags) -## [Interactive inspection with `jax.debug`](print_breakpoint) +## Interactive inspection with `jax.debug` + +Complete guide [here](print_breakpoint) **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack: @@ -34,9 +36,11 @@ Table of contents: # 🤯 0.9092974662780762 🤯 ``` -Click [here](print_breakpoint) to learn more! +[Read more](print_breakpoint). + +## Functional error checks with `jax.experimental.checkify` -## [Functional error checks with `jax.experimental.checkify`](checkify_guide) +Complete guide [here](checkify_guide) **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: @@ -77,9 +81,11 @@ Click [here](print_breakpoint) to learn more! # ValueError: nan generated by primitive sin at <...>:8 (f) ``` -Click [here](checkify_guide) to learn more! +[Read more](checkify_guide). + +## Throwing Python errors with JAX's debug flags -## [Throwing Python errors with JAX's debug flags](flags) +Complete guide [here](flags) **Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. @@ -92,7 +98,7 @@ def f(x, y): jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ``` -Click [here](flags) to learn more! +[Read more](flags). ```{toctree} :caption: Read more diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index d33498697bb3..73ac0262851d 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -1,9 +1,9 @@ -# `jax.debug.print` and `jax.debug.breakpoint` +# Compiled prints and breakpoints The {mod}`jax.debug` package offers some useful tools for inspecting values -inside of JIT-ted functions. +inside of compiled functions. ## Debugging with `jax.debug.print` and other debugging callbacks @@ -27,7 +27,6 @@ f(2.) # 🤯 0.9092974662780762 🤯 ``` - With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead! Semantically, `jax.debug.print` is roughly equivalent to the following Python function From 7b297912008ffe2a2b6f0417a73494bd88ccb696 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 11:19:53 -0700 Subject: [PATCH 4/5] docs: shorten/clarify some page titles Shorten the titles for derivative rules and manual parallelism, and go for canonical wording in the parallel programming intro (typically we "shard" data, and "partition" computation, as part of parallel programming). --- docs/debugging.md | 6 +++--- docs/distributed_data_loading.md | 2 +- docs/multi_process.md | 2 +- .../notebooks/Custom_derivative_rules_for_Python_code.ipynb | 2 +- docs/notebooks/Custom_derivative_rules_for_Python_code.md | 2 +- docs/notebooks/shard_map.ipynb | 2 +- docs/notebooks/shard_map.md | 2 +- docs/sharded-computation.ipynb | 2 +- docs/sharded-computation.md | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 1e8501f99e39..d07f42da5c85 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -21,7 +21,7 @@ This section introduces you to a set of built-in JAX debugging methods — {func Let's begin with {func}`jax.debug.print`. -## JAX `debug.print` for high-level +## `jax.debug.print` for simple inspection Here is a rule of thumb: @@ -111,7 +111,7 @@ f(1, 2) To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. -## JAX `debug.breakpoint` for `pdb`-like debugging +## `jax.debug.breakpoint` for `pdb`-like debugging **Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. @@ -160,7 +160,7 @@ f(2., 1.) # ==> No breakpoint f(2., 0.) # ==> Pauses during execution ``` -## JAX `debug.callback` for more control during debugging +## `jax.debug.callback` for more control during debugging Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using the more flexible {func}`jax.debug.callback`, which gives greater control over the diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index 14fb1bb55c35..d7b88be44178 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -12,7 +12,7 @@ kernelspec: name: python3 --- -# Distributed data loading in multi-host / multi-process environments +# Distributed data loading diff --git a/docs/multi_process.md b/docs/multi_process.md index 7d7083bde10f..32cfae126784 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,4 +1,4 @@ -# Using JAX in multi-host and multi-process environments +# Multi-host and multi-process environments diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 88d446723dba..dd7a36e57079 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -6,7 +6,7 @@ "id": "LqiaKasFjH82" }, "source": [ - "# Custom derivative rules for JAX-transformable Python functions\n", + "# Custom derivative rules\n", "\n", "\n", "\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index fdf4b3ed0c8d..930887af1e1b 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "LqiaKasFjH82"} -# Custom derivative rules for JAX-transformable Python functions +# Custom derivative rules diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 157d6c567b24..1315783c340c 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -5,7 +5,7 @@ "id": "41a7e222", "metadata": {}, "source": [ - "# SPMD multi-device parallelism with `shard_map`\n", + "# Manual parallelism with `shard_map`\n", "\n", "\n", "\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 21e4111d5bf9..96667e709ac6 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -14,7 +14,7 @@ kernelspec: name: python3 --- -# SPMD multi-device parallelism with `shard_map` +# Manual parallelism with `shard_map` diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index b7bc919ebcde..22d9156f607b 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -5,7 +5,7 @@ "metadata": {}, "source": [ "(sharded-computation)=\n", - "# Introduction to sharded computation\n", + "# Introduction to parallel programming\n", "\n", "\n", "\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 4f8c1b0201bc..e6e1948f9902 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -12,7 +12,7 @@ kernelspec: --- (sharded-computation)= -# Introduction to sharded computation +# Introduction to parallel programming From 180573a8c89f6dc42f42e52cd92c90bee7135daf Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 28 Aug 2024 11:21:52 -0700 Subject: [PATCH 5/5] docs: tweak JAX-Toolbox url anchor for clarity --- docs/investigating_a_regression.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 9b712056e9bc..389cc0b5a9e8 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -25,7 +25,7 @@ Here is a suggested investigation strategy: ## Nightly investigation -This can be done by using [JAX-Toolbox nightly +This can be done by using the [NVIDIA JAX-Toolbox nightly containers](https://github.com/NVIDIA/JAX-Toolbox). - Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days.