From 473d1c3685570e505bed58512afa05fb7d7a8935 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 28 Mar 2023 13:42:51 -0700 Subject: [PATCH] Turn on PJRT C API by default. I forgot that the default setting is actually in jaxlib: https://github.com/openxla/xla/blob/fbe9a80fdb8c429e8a175962459da348cd560a50/xla/python/xla_client.py#L135 To be able to make this change as a jax-only release, I manually set the env var on Cloud TPU if it isn't already set. --- CHANGELOG.md | 19 +++++++++++++++++++ docs/debugging/print_breakpoint.md | 6 ++---- jax/_src/cloud_tpu_init.py | 3 +++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d95cac920f24..6c31d7d8ffe2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,25 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.8 +* Breaking changes + * A major component of the Cloud TPU runtime has been upgraded. This enables + the following new features on Cloud TPU: + * {func}`jax.debug.print`, {func}`jax.debug.callback`, and + {func}`jax.debug.breakpoint()` now work on Cloud TPU + * Automatic TPU memory defragmentation + + {func}`jax.experimental.host_callback` is no longer supported on Cloud TPU + with the new runtime component. Please file an issue on the [JAX issue + tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs + are insufficient for your use case. + + The old runtime component will be available for at least the next three + months by setting the environment variable + `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new + runtime for any reason, please let us know on the [JAX issue + tracker](https://github.com/google/jax/issues). + + * Changes * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7. diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index 850d3bb78232..2093c7599698 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -17,7 +17,7 @@ def f(x): y = jnp.sin(x) jax.debug.print("🤯 {y} 🤯", y=y) return y - + f(2.) # Prints: # 🤯 2.0 🤯 @@ -225,8 +225,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat #### Limitations * Adding print statements is a manual process * Can have performance impacts -* Unsupported on Cloud TPUs - + ## Interactive inspection with `jax.debug.breakpoint()` **TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: @@ -296,4 +295,3 @@ Because `jax.debug.breakpoint` is a just an application of `jax.debug.callback`, #### Limitations * Need to potentially use many breakpoints to pinpoint the source of an error * Materializes many intermediates -* Unsupported on Cloud TPUs diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index d789ab79c786..e22faf4b1efa 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -67,3 +67,6 @@ def cloud_tpu_init() -> None: os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' + + if 'JAX_USE_PJRT_C_API_ON_TPU' not in os.environ: + os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = 'true'