Skip to content

Commit

Permalink
Merge pull request #15270 from skye:pjrt_c_api
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 520156646
  • Loading branch information
jax authors committed Mar 28, 2023
2 parents 5ae2e79 + 473d1c3 commit 8c4fed6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.8

* Breaking changes
* A major component of the Cloud TPU runtime has been upgraded. This enables
the following new features on Cloud TPU:
* {func}`jax.debug.print`, {func}`jax.debug.callback`, and
{func}`jax.debug.breakpoint()` now work on Cloud TPU
* Automatic TPU memory defragmentation

{func}`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the [JAX issue
tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs
are insufficient for your use case.

The old runtime component will be available for at least the next three
months by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the [JAX issue
tracker](https://github.com/google/jax/issues).


* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.

Expand Down
6 changes: 2 additions & 4 deletions docs/debugging/print_breakpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def f(x):
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y

f(2.)
# Prints:
# 🤯 2.0 🤯
Expand Down Expand Up @@ -225,8 +225,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat
#### Limitations
* Adding print statements is a manual process
* Can have performance impacts
* Unsupported on Cloud TPUs


## Interactive inspection with `jax.debug.breakpoint()`

**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values:
Expand Down Expand Up @@ -296,4 +295,3 @@ Because `jax.debug.breakpoint` is a just an application of `jax.debug.callback`,
#### Limitations
* Need to potentially use many breakpoints to pinpoint the source of an error
* Materializes many intermediates
* Unsupported on Cloud TPUs
3 changes: 3 additions & 0 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'

if 'JAX_USE_PJRT_C_API_ON_TPU' not in os.environ:
os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = 'true'

0 comments on commit 8c4fed6

Please sign in to comment.