From 9171e66d9e4e8a893d9d178541da8247a5334665 Mon Sep 17 00:00:00 2001 From: Keshav Date: Wed, 18 Sep 2024 13:43:00 -0700 Subject: [PATCH] Squashed commit of the following: commit 093c6e9cd90856e840b6251c23422bf0ad7b82ff Merge: e1a77eee2 d0cb3182a Author: Keshav Date: Wed Sep 18 13:37:44 2024 -0700 Merge remote-tracking branch 'upstream/main' into disable_remat_pass commit e1a77eee25aad201d761d67c9550ee194a2656b2 Author: Keshav Date: Wed Sep 18 13:35:37 2024 -0700 minor changes commit d0cb3182aaf0a20727ae521a9b8d06539931fc1b Merge: b51c65357 bef36c431 Author: jax authors Date: Wed Sep 18 13:34:11 2024 -0700 Merge pull request #23736 from hawkinsp:changelog PiperOrigin-RevId: 676111400 commit b51c65357f0ae9659e58e2ff0df871542124cddf Merge: dbc03cf8e 57a4b76d0 Author: jax authors Date: Wed Sep 18 13:33:05 2024 -0700 Merge pull request #23737 from jakevdp:digitize-doc PiperOrigin-RevId: 676111220 commit dbc03cf8e5a7ac8e1e6e8e593e063eb7f54990d1 Author: Dan Foreman-Mackey Date: Wed Sep 18 12:39:58 2024 -0700 Re-land #23261 with appropriate compatibility checks. PiperOrigin-RevId: 676092618 commit b164d67d4a9bd094426ff450fe1f1335d3071d03 Merge: cd04d0f32 541b3a3f7 Author: jax authors Date: Wed Sep 18 12:05:03 2024 -0700 Merge pull request #23247 from kaixih:sliding_window_attn PiperOrigin-RevId: 676079831 commit 57a4b76d09fb1eac160242bf1f31bc8b3841f82d Author: Jake VanderPlas Date: Wed Sep 18 11:59:00 2024 -0700 Improve documentation for jnp.digitize commit bef36c431d752b91372ac58d4cf6a84277dc600e Author: Peter Hawkins Date: Wed Sep 18 18:57:03 2024 +0000 Add Python 3.13 wheels to changelog. commit cd04d0f32e854aa754e37e4b676725655a94e731 Merge: 016c49951 c756d9b70 Author: jax authors Date: Wed Sep 18 10:00:03 2024 -0700 Merge pull request #23726 from hawkinsp:debug PiperOrigin-RevId: 676030839 commit 016c49951f670256ce4750cdfea182e3a2a15325 Author: Sergei Lebedev Date: Wed Sep 18 09:56:44 2024 -0700 Removed leftover usages of GPUGridSpec from Pallas Mosaic GPU tests PiperOrigin-RevId: 676029854 commit 9dd363da1298e4810b693a918fc2e8199094acdb Author: Luke Baumann Date: Wed Sep 18 09:28:25 2024 -0700 Export `jax.lib.xla_extension.ifrt_programs`. PiperOrigin-RevId: 676020419 commit e27f1e9b3a8af39d2791b95a8106106229cad238 Author: jax authors Date: Wed Sep 18 09:03:55 2024 -0700 Change Python version 3.13.0rc2 to 3.13.0-rc.2. The value is taken from [the versions manifest](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json). PiperOrigin-RevId: 676012255 commit 442e8630deff3f89d3ae756aac30bb71e7ba7cf2 Author: Sergei Lebedev Date: Wed Sep 18 08:56:49 2024 -0700 Added a missing branch to `mgpu.FragmentedArray.astype` Previously, an unsupported cast produced a `NameError` instead. PiperOrigin-RevId: 676010161 commit 6236b8f4aef7344467ead740b92a32769a84aee1 Merge: 826843a22 1cc96616b Author: jax authors Date: Wed Sep 18 08:57:38 2024 -0700 Merge pull request #23667 from dfm:always-lower-jnp-dot-to-dot-general PiperOrigin-RevId: 676010154 commit c756d9b7033d1dcf3a389bcb4c9e65f9f5201019 Author: Peter Hawkins Date: Wed Sep 18 15:44:45 2024 +0000 Fix error in debugger tests that is showing up in CI. I'm unsure why this started happening now, but sometimes we get an invalid offset for a frame. Be tolerant of that case. 
commit 826843a22d1dbbd13455a56a639d2de9abef1682 Merge: c191bbcdb 922e652c0 Author: jax authors Date: Wed Sep 18 08:42:39 2024 -0700 Merge pull request #23723 from hawkinsp:setuptools PiperOrigin-RevId: 676005613 commit c191bbcdb162bec58494e818f42feb457cd0a287 Author: Yash Katariya Date: Wed Sep 18 08:40:30 2024 -0700 Make `debug.print` work with static args. Fixes: https://github.com/google/jax/issues/23600 PiperOrigin-RevId: 676005582 commit 1cc96616baab7394117c2ecb9b64db22fd82dc44 Author: Dan Foreman-Mackey Date: Mon Sep 16 14:18:29 2024 -0400 Unconditionally lower jnp.dot to lax.dot_general. https://github.com/google/jax/pull/16721 added a condition to lower calls to `jnp.dot` with scalar inputs to `lax.mul` instead of `lax.dot_general`. AFAICT, https://github.com/google/jax/pull/16826 fixed the issue that this was solving, so this condition should no longer be necessary. Removing this condition simplifies the addition of new arguments to `dot` and `dot_general`, including the `algorithm` parameter that I am currently working on in https://github.com/google/jax/pull/23574, so now seemed like a good time to remove it! commit 922e652c05654b9e2278a88a22cb70b94fdd6f46 Author: Peter Hawkins Date: Wed Sep 18 15:17:49 2024 +0000 Replace plat-name with plat_name. The former seems to elicit a deprecation warning from setuptools recently. commit 69ba060957529fc9babf838af8d47a2626615e62 Author: Dan Foreman-Mackey Date: Wed Sep 18 07:40:58 2024 -0700 Reverts e15ec1e8abe3732d747731c15a36facf4169739e PiperOrigin-RevId: 675987338 commit 44a7f0439c38f465d94388cecfcd5af463c2c2fe Merge: 0a29696a9 2834c135a Author: jax authors Date: Wed Sep 18 07:31:00 2024 -0700 Merge pull request #23708 from jakevdp:sort-complex PiperOrigin-RevId: 675983957 commit 0a29696a97bca016af04f3b3cb59e94dd6b0615b Merge: e15ec1e8a 73c38cb70 Author: jax authors Date: Wed Sep 18 07:08:24 2024 -0700 Merge pull request #23698 from dfm:dev-clang-warning PiperOrigin-RevId: 675977448 commit 2834c135a34636af9de73b421e8fbb6731b20c4e Author: Jake VanderPlas Date: Tue Sep 17 15:32:25 2024 -0700 jnp.sort_complex: fix output for N-dimensional inputs commit 73c38cb7009b52706d3769aaec6d32046aced508 Author: Dan Foreman-Mackey Date: Tue Sep 17 14:00:21 2024 -0400 Add a note to the developer docs making it clear that clang is the only toolchain that is actively supported for source compilation. As discussed in https://github.com/google/jax/issues/23687 commit e15ec1e8abe3732d747731c15a36facf4169739e Merge: 48d8fce73 3f2bc9b60 Author: jax authors Date: Wed Sep 18 06:56:28 2024 -0700 Merge pull request #23261 from joaospinto:stablehlo.tan PiperOrigin-RevId: 675973798 commit 48d8fce731801ae2a90423b2f3cf31544290e0ef Merge: 4e6f69072 271446939 Author: jax authors Date: Wed Sep 18 06:54:28 2024 -0700 Merge pull request #23563 from rajasekharporeddy:testbranch1 PiperOrigin-RevId: 675973225 commit 4e6f6907242653daa485462be4ae68b9a5b01514 Merge: b7c91e90c 611ad6306 Author: jax authors Date: Wed Sep 18 06:35:15 2024 -0700 Merge pull request #23653 from apaszke:torchsaic PiperOrigin-RevId: 675967844 commit b7c91e90c2f3645f5270b6eb7aea5882852eb7b1 Author: Sergei Lebedev Date: Wed Sep 18 06:22:14 2024 -0700 Lookup `shape` and `dtype` directly on `state.AbstractRef` instead of going through `inner_aval` This is just a cleanup. No behavior changes are expected. 
PiperOrigin-RevId: 675964703 commit 611ad630603cffa88aa714bf876340af315dd819 Author: Adam Paszke Date: Fri Sep 6 16:09:58 2024 +0000 Add basic PyTorch integration for Mosaic GPU We have already had most of the relevant pieces and we only needed to connect them together. The most sensitive change is perhaps that I needed to expose one more symbol from the XLA GPU plugin, but I don't think it should be a problem. commit e90336947a7f763226e8609ea96bc49a64fdb2c9 Author: Sergei Lebedev Date: Wed Sep 18 05:25:37 2024 -0700 Pulled `scratch_shapes` into `GridSpec` It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton. PiperOrigin-RevId: 675950199 commit 2714469397c18041d6c5696448abb7abb916ba89 Author: rajasekharporeddy Date: Wed Sep 18 17:06:28 2024 +0530 Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros commit b904599b98cca5fb73a387911afc685b290b623b Author: Sergei Lebedev Date: Wed Sep 18 04:23:25 2024 -0700 `pl.debug_print` no longer restricts values to be scalars This allows printing arrays on Triton and soon on Mosaic GPU. PiperOrigin-RevId: 675935666 commit 988ed2bd75df5fe25b74eaf38075aadff19be207 Author: jax authors Date: Tue Sep 17 21:09:26 2024 -0700 Add support for SMEM windows in Pallas custom pipeline. PiperOrigin-RevId: 675822640 commit f79d85ba8d7a3f59d8b36f2bb0f40cba9aa46b6d Merge: 1b74cfde8 cc28d639c Author: Keshav Date: Tue Sep 17 18:58:33 2024 -0700 Merge remote-tracking branch 'upstream/main' into disable_remat_pass commit cc28d639cdfb725dd8b7177cf71616668f4f60f6 Merge: 8bcdb1285 9d3762bd4 Author: jax authors Date: Tue Sep 17 17:36:36 2024 -0700 Merge pull request #23682 from sharadmv:pallas-async-docs PiperOrigin-RevId: 675770723 commit 1b74cfde8f388984bbcd540ebb2a588f91ba0cf3 Author: Keshav Date: Tue Sep 17 17:23:30 2024 -0700 disable remat hlo pass by default commit 8bcdb1285218d42e051882b33abf65a75649488b Author: jax authors Date: Tue Sep 17 16:50:55 2024 -0700 Add CI jobs for python 3.13.0rc2. PiperOrigin-RevId: 675758096 commit 8b5b71750b009fdd979dfd0abeb43a359a60c664 Author: Yash Katariya Date: Tue Sep 17 16:39:55 2024 -0700 Fix jaxpr equation context propagation in jaxpr equations when `inline=True`. PiperOrigin-RevId: 675754808 commit 86fe463ad7221de7f4078fcbba9c6bf1af0b19ba Author: Parker Schuh Date: Tue Sep 17 16:10:41 2024 -0700 [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums. This allows us to get more cache hits globally. For example: Before: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache miss After: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache hit Reverts b615266175effe4aefeb903620a19f3719a604da PiperOrigin-RevId: 675746175 commit e92a599a96374064d53a7230086992e989af542a Author: Christos Perivolaropoulos Date: Tue Sep 17 15:26:42 2024 -0700 [mosaic_gpu] Better error message for misaligned tma_transpose with dtype. 
PiperOrigin-RevId: 675731295 commit 78646484950fb01b1915114b8fa34054237b841f Merge: 3f2c58b9c 83a7555ff Author: jax authors Date: Tue Sep 17 15:12:50 2024 -0700 Merge pull request #23679 from selamw1:docstring_sort_complex PiperOrigin-RevId: 675726527 commit 83a7555ffd355545c1f3d4642eaf4ab5d18ebcc8 Author: selamw1 Date: Mon Sep 16 16:47:52 2024 -0700 docstring_sort_complex_added input_array_modified commit 9d3762bd476b95a187bab22284e62525901255f7 Author: Sharad Vikram Date: Mon Sep 16 19:18:22 2024 -0700 [Pallas] Add design note for async ops on TPU commit 3f2bc9b60846b8a32c29804ba5e8caac7766add7 Author: Joao Sousa-Pinto Date: Mon Aug 26 17:25:16 2024 -0700 Lower tan to StableHLO instead of CHLO. Fixes #23259 commit 541b3a3f7565b0e3f826b388dd094d22b28efb54 Author: kaixih Date: Mon Aug 26 17:32:38 2024 +0000 New feature --- .bazelrc | 4 + .github/workflows/wheel_win_x64.yml | 2 +- CHANGELOG.md | 9 + build/requirements.in | 2 + build/requirements_lock_3_13.txt | 10 +- docs/developer.md | 32 +- docs/pallas/CHANGELOG.md | 32 +- docs/pallas/async_note.md | 675 ++++++++++++++++++ docs/pallas/index.rst | 7 + jax/_src/api.py | 6 +- jax/_src/compiler.py | 14 + jax/_src/core.py | 3 +- jax/_src/debugger/core.py | 5 + jax/_src/debugging.py | 26 +- jax/_src/deprecations.py | 1 + jax/_src/interpreters/partial_eval.py | 3 +- jax/_src/interpreters/pxla.py | 42 +- jax/_src/lax/lax.py | 12 +- jax/_src/nn/functions.py | 58 +- jax/_src/numpy/lax_numpy.py | 124 +++- jax/_src/pallas/core.py | 32 +- jax/_src/pallas/mosaic/core.py | 22 +- jax/_src/pallas/mosaic/lowering.py | 3 + jax/_src/pallas/mosaic/pipeline.py | 66 +- jax/_src/pallas/mosaic_gpu/__init__.py | 1 - jax/_src/pallas/mosaic_gpu/core.py | 22 +- jax/_src/pallas/mosaic_gpu/lowering.py | 10 +- jax/_src/pallas/pallas_call.py | 16 +- jax/_src/pallas/primitives.py | 12 +- jax/_src/pallas/triton/lowering.py | 13 +- jax/_src/pjit.py | 119 ++- jax/_src/state/discharge.py | 3 +- jax/_src/state/types.py | 18 +- jax/experimental/mosaic/gpu/__init__.py | 121 +++- .../mosaic/gpu/examples/matmul.py | 2 +- .../mosaic/gpu/fragmented_array.py | 2 + jax/experimental/multihost_utils.py | 11 +- jax/extend/BUILD | 6 + jax/extend/ifrt_programs.py | 22 + jax/numpy/__init__.pyi | 3 +- jaxlib/mosaic/gpu/custom_call.cc | 94 ++- jaxlib/tools/BUILD.bazel | 3 +- jaxlib/tools/build_gpu_kernels_wheel.py | 2 +- jaxlib/tools/build_gpu_plugin_wheel.py | 2 +- jaxlib/tools/build_wheel.py | 2 +- jaxlib/tools/gpu_version_script.lds | 11 + tests/debugging_primitives_test.py | 12 + tests/filecheck/math.filecheck.py | 2 +- tests/lax_numpy_test.py | 16 +- tests/memories_test.py | 23 + tests/mosaic/gpu_test.py | 24 + tests/nn_test.py | 19 +- tests/pallas/mosaic_gpu_test.py | 30 +- tests/pjit_test.py | 33 +- 54 files changed, 1536 insertions(+), 308 deletions(-) create mode 100644 docs/pallas/async_note.md create mode 100644 jax/extend/ifrt_programs.py create mode 100644 jaxlib/tools/gpu_version_script.lds diff --git a/.bazelrc b/.bazelrc index 9d5d9664939e..948d92c29c26 100644 --- a/.bazelrc +++ b/.bazelrc @@ -215,6 +215,8 @@ build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" +build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base +build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" build:rbe_linux_cuda_base --config=rbe_linux build:rbe_linux_cuda_base --config=cuda @@ -237,6 +239,8 
@@ build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" +build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base +build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" # These you may need to change for your own GCP project. build:tensorflow_testing_rbe --project_id=tensorflow-testing diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 447ccba4f8c2..367f8e05bf56 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -17,7 +17,7 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.10', '3.11', '3.12'] + pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} diff --git a/CHANGELOG.md b/CHANGELOG.md index b34eee046856..659a8ee04db0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.34 +* New Functionality + * This release includes wheels for Python 3.13. Free-threading mode is not yet + supported. + +* Deprecations + * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike + arguments with `ndim != 1` are now deprecated, and in the future will result + in an error. + * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation in 0.4.30 JAX release. diff --git a/build/requirements.in b/build/requirements.in index f6b5b18b2660..a8d81fa5c670 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -23,3 +23,5 @@ ml_dtypes>=0.4.0 opt_einsum zstandard etils[epath] +# TODO(ybaturina): remove setuptools version +setuptools<71.0.0 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index ef121b73713b..e2369a8001bb 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -732,7 +732,9 @@ zstandard==0.23.0 \ # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==75.1.0 \ - --hash=sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2 \ - --hash=sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538 - # via -r build/test-requirements.txt +setuptools==70.3.0 \ + --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ + --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/docs/developer.md b/docs/developer.md index 53b6f0cf0f45..40ad51e873ca 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -31,23 +31,33 @@ guidance on pip installation (e.g., for GPU and TPU support). ### Building `jaxlib` from source +```{warning} +While it should typically be possible to compile `jaxlib` from source using +most modern compilers, the builds are only tested using clang. Pull requests +are welcomed to improve support for different toolchains, but other compilers +are not actively supported. 
+``` + To build `jaxlib` from source, you must also install some prerequisites: -- a C++ compiler (g++, clang, or MSVC) +- A C++ compiler: - On Ubuntu or Debian you can install the necessary prerequisites with: + As mentioned in the box above, it is best to use a recent version of clang + (at the time of writing, the version we test is 18), but other compilers (e.g. + g++ or MSVC) may work. - ``` - sudo apt install g++ python python3-dev - ``` + On Ubuntu or Debian you can follow the instructions from the + [LLVM](https://apt.llvm.org/) documentation to install the latest stable + version of clang. If you are building on a Mac, make sure XCode and the XCode command line tools are installed. See below for Windows build instructions. -- there is no need to install Python dependencies locally, as your system - Python will be ignored during the build; please check +- Python: for running the build helper script. Note that there is no need to + install Python dependencies locally, as your system Python will be ignored + during the build; please check [Managing hermetic Python](#managing-hermetic-python) for details. To build `jaxlib` for CPU or TPU, you can run: @@ -86,7 +96,7 @@ the `build/build.py` script itself will be processed by your system Python interpreter. By default, the wheel is written to the `dist/` subdirectory of the current directory. -* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN +* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN versions in the configuration options. Bazel will download them and use as target dependencies. @@ -259,8 +269,8 @@ together with their corresponding hashes are specified in `build/requirements_lock_.txt` files ( e.g. `build/requirements_lock_3_12.txt` for `Python 3.12`). -To update the lock files, make sure `build/requirements.in` contains the desired -direct dependencies list and then execute the following command (which will call +To update the lock files, make sure `build/requirements.in` contains the desired +direct dependencies list and then execute the following command (which will call [pip-compile](https://pypi.org/project/pip-tools/) under the hood): ``` @@ -382,7 +392,7 @@ sudo apt-get install libopenblas-dev -y example for `Python 3.13` it should have something like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the `requirements` parameter must always be in `"major.minor"` version format, so - even if you are building Python version `3.13.0rc1` the corresponding + even if you are building Python version `3.13.0rc1` the corresponding `requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`, **not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`. diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index c1ed1385bbbc..43ba3ebd6afb 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,15 +11,31 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> -## Released with jax 0.4.32 +## Released with jax 0.4.34 * Changes - * The kernel function is not allowed to close over constants. Instead, all the needed arrays - must be passed as inputs, with proper block specs ({jax-issue}`#22746`). + + * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments + to be scalars. 
The restrictions on the arguments are backend-specific: + Non-scalar arguments are currently only supported on GPU, when using Triton. * Deprecations -* New functionality: +* New functionality + + * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`, + a PyTree specifying backend-specific temporary objects needed by the + kernel, for example, buffers, synchronization primitives etc. + +## Released with jax 0.4.33 (September 16, 2024) + +## Released with jax 0.4.32 (September 11, 2024) + +* Changes + * The kernel function is not allowed to close over constants. Instead, all the needed arrays + must be passed as inputs, with proper block specs ({jax-issue}`#22746`). + +* New functionality * Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map. @@ -44,10 +60,6 @@ Remember to align the itemized text with the first line of an item within a list * Previously it was possible to import many APIs that are meant to be private, as `jax.experimental.pallas.pallas`. This is not possible anymore. - -* Deprecations - - * New Functionality * Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`. * Improved error messages for the {func}`jax.experimental.pallas.pallas_call` @@ -73,7 +85,3 @@ Remember to align the itemized text with the first line of an item within a list * Added checkify support for {func}`jax.experimental.pallas.pallas_call` in interpret mode ({jax-issue}`#21862`). * Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`). - - - - diff --git a/docs/pallas/async_note.md b/docs/pallas/async_note.md new file mode 100644 index 000000000000..96370ee48625 --- /dev/null +++ b/docs/pallas/async_note.md @@ -0,0 +1,675 @@ +# Pallas Async Operations + +## Background \+ Motivation + +We’d like to expose APIs in Pallas to explicitly overlap computation and communication *across multiple kernels*. + +### XLA Async Decomposition + +As motivation, consider the following JAX pseudocode: + +```py +def f(x): + y = ppermute(x) + z = x + 1 + return y, z +``` + +In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by: + +1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. +2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`, + +resulting in the following program: + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 # happens at the same time as ppermute + y = ppermute_done(fut) + return y, z +``` + +### Async ops inside kernels + +Now imagine we aren’t using XLA’s `ppermute` but have our own custom Pallas `ppermute`. + +```py +def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem): + right_neighbor = ... + descriptor = pltpu.make_remote_async_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor.start() + descriptor.wait_send() + descriptor.wait_recv() + +def ppermute(x): + return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x) +``` + +Currently, we cannot decompose `ppermute` into a `start/done` pair as XLA does, so instead we explicitly **fuse** the `x + 1` into the kernel. + +```py +def add_one(x_ref, z_ref): + z_ref[...] = x_ref[...] + 1 + +def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem): + right_neighbor = ... 
+ descriptor = pltpu.make_remote_async_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor.start() + + # Explicitly schedule inner kernel between start/wait + pltpu.emit_pipeline(add_one)(x_ref, z_ref) + + descriptor.wait_send() + descriptor.wait_recv() + +def ppermute_and_add_one(x): + return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x) + +``` + +The goal is to enable writing separate kernels for starting the `ppermute` and waiting on it to complete, so that we can use a regular old `x + 1` in between (or whatever compute we want). This makes the code more readable, maintainable, and less bug-prone. + +## How do we implement decomposed Pallas async operations (on TPU)? + +The main thing to figure out when implementing decomposed async operations in Pallas is what the `future` that is passed between them contains. Specifically, it must contain some important state about the operation happening in the background. + +If we look at the Pallas code, we can see that we need a “descriptor” to both start and wait on a remote copy. Can we plumb this descriptor out of the Pallas kernel, and then pass it into another one? Well kinda. The underlying TPU hardware tracks async op progress via a pair of semaphores: `send_sem` enables us to wait on when a device is done sending data to its neighbor and `recv_sem` tracks the data transfer sent to a device from their neighbor. If we imagine writing a start kernel and a done kernel, all we’d need to pass from the start to the done would be the semaphores and some information about how much to wait on those semaphores. + +We can do this via extending Pallas to support returning semaphores from kernels. + +```py +def ppermute_start_kernel( + in_ref, send_sem, recv_sem, out_ref, *, axis_name, +): + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor + ).start() + +def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]: + send_sem, recv_sem, out = pl.pallas_call( + functools.partial(ppermute_start_kernel, axis_name=axis_name), + out_shape=( + pltpu.SemaphoreType.DMA(()), + pltpu.SemaphoreType.DMA(()), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.ANY), + ), + )(x) + return send_sem, recv_sem, out +``` + +Note that something subtle is happening here. Pallas is telling XLA that it would like some outputs to be semaphores (a.k.a. sync flags) and XLA will treat them as “reserved” (e.g. while they are alive in the XLA program, those sync flags cannot be allocated by other kernels). They behave similarly to barrier semaphores, which are reserved semaphores managed by XLA. + +Another thing to notice is that we return the output buffer `out` from the start kernel *while it’s being actively copied into*. + +Now we write the `done` kernel that performs the blocking operation. We pass `out` into the kernel to compute the shape needed to block on the semaphore. 
+ +```py +def ppermute_done_kernel(ref, send_sem, recv_sem, _): + pltpu.make_async_copy(ref, ref, send_sem).wait() + pltpu.make_async_copy(ref, ref, recv_sem).wait() + +def ppermute_done(send_sem, recv_sem, out) ->Array: + out = pl.pallas_call( + ppermute_done_kernel, + out_shape=( + jax.ShapeDtypeStruct( + out.shape, + dtype=out.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0:0} + )(out, send_sem, recv_sem) + return out +``` + +Note: we i/o alias the output buffer here to guarantee that the consumers are downstream of the `ppermute_done`. + +We now can implement the decomposed collective permute. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 # happens at the same time as ppermute + y = ppermute_done(fut) + return y, z +``` + +***OR CAN WE?*** + +## Why *doesn’t* this work? + +There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level. + +1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. +2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. +3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. + +We will go over these issues one by one and suggest fixes. + +### Scheduling + +How do we explicitly force ops to happen in a particular order in JAX? Note that this is not a Pallas specific problem, and if we had async ops implemented using an alternative method, we’d still run into this. + +One way is to introduce an optimization barrier into the XLA program. The optimization barrier will prevent XLA moving ops around it. + +Here’s our original code: + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +XLA could choose to execute `x + 1` in any of three places: + +```py +def f(x): + z = x + 1 + fut = ppermute_start(x) + y = ppermute_done(fut) + return y, z + +# OR + +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z + +# OR + +def f(x): + fut = ppermute_start(x) + y = ppermute_done(fut) + z = x + 1 + return y, z +``` + +To force the `x + 1` to happen between the `ppermute` ops, we can use `optimization_barrier`, which is semantically the identity function (i.e. `lambda x: x`) but introduces an explicit data dependency between values. Specifically, if we make the `x` that is used in `x + 1` dependent on the `fut` returned by `ppermute_start`, it must happen after `ppermute_start`. + +We also introduce a dependency that forces the output value `y` to depend on `z`. 
+ +```py +def f(x): + fut = ppermute_start(x) + x, fut = optimization_barrier((x, fut)) # x now depends on fut + z = x + 1 + z, fut = optimization_barrier((z, fut)) # fut now depends on z + y = ppermute_done(fut) + return y, z +``` + +`optimization_barrier` is a good enough hammer for us to explicitly write out schedules. + +### Lifetimes + +Let’s look at our original code again and assume the ops are happening in the correct order. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +Let’s look at which point in the program XLA believes it is okay to free the buffer for `x`. It would be the point after which `x` is no longer used, specifically after `z = x + 1`. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + # XLA can free x here! + y = ppermute_done(fut) + return y, z +``` + +If XLA frees `x` after `z = x + 1` has completed, we run into a very bad problem. The `ppermute` could still be actively copying `x` to the neighbor after `z = x + 1` which means if `x` is freed, the `ppermute` will be reading from garbage memory\! + +How do we extend `x`’s lifetime to the `ppermute_done`? Well we can introduce a data dependency\! We need to modify our kernels a little bit to make this happen. + +First, we rewrite `ppermute_start` to return `x`, aliasing it through the kernel. + +```py +def ppermute_start_kernel( + in_ref, send_sem, recv_sem, out_ref, _, *, axis_name, +): + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor + ).start() + +def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]: + send_sem, recv_sem, x, out = pl.pallas_call( + functools.partial(ppermute_start_kernel, axis_name=axis_name), + out_shape=( + pltpu.SemaphoreType.DMA(()), + pltpu.SemaphoreType.DMA(()), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ), + input_output_aliases={0:2} + )(x) + return send_sem, recv_sem, x, out +``` + +We then have `ppermute_done` take in `x` and do nothing with it. 
+ +```py +def ppermute_done_kernel(_x_ref, ref, send_sem, recv_sem, _): + pltpu.make_async_copy(ref, ref, send_sem).wait() + pltpu.make_async_copy(ref, ref, recv_sem).wait() + +def ppermute_done(send_sem, recv_sem, x, out) -> Array: + out = pl.pallas_call( + ppermute_done_kernel, + out_shape=( + jax.ShapeDtypeStruct( + out.shape, + dtype=out.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={1:0} + )(x, out, send_sem, recv_sem) + return out + +``` + +Now when we write + +```py +def f(x): + *sems, x, out = ppermute_start(x) + z = x + 1 + y = ppermute_done(*sems, x, out) + return y, z +``` + +XLA can no longer free `x` because it is an input to `ppermute_done`\! This means that `x`’s lifetime is tied to the `ppermute` and this code is now correct. + +### Defensive copies + +XLA, in its buffer assignment pass, analyzes which buffers are aliased to each other and inserts copies whenever an operation that aliases one of its inputs is not the final consumer of that input. + +#### Background + +Here’s a simple example. Let’s say we have an op `add_one_inplace` which takes in an array and adds one, but promises to do it in-place. + +The following code would be legal. + +```py +def f(): + x = jnp.arange(...) + y = add_one_inplace(x) + return y +``` + +However, if `x` had a separate consumer as well, the program may not execute correctly. + +```py +def f(): + x = jnp.arange(...) + y = add_one_inplace(x) + return y, x * 2 # another x consumer! +``` + +This is because `x * 2` operates on the original `x` but `add_one_inplace` clobbers the value in `x`. `x * 2` needs to make sure to read the original values of `x`, not the ones after we’ve incremented it by 1\. XLA notices this and inserts a `copy` op (which is semantically the identity but the input and output buffers will be different). + +```py +def f(x): + x2 = copy(x) + y = add_one_inplace(x2) + return y, x * 2 +``` + +This pass in XLA ensures correctness in the presence of ops that perform in-place updates by forcing them to effectively be out-of-place with `copy` ops. + +#### Copies with downstream ops + +Let’s revisit our example where we add 1 while `ppermute`ing. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +If we unpack the future into its components, we’ll see the aliasing patterns: + +```py +def f(x): + *sems, x2, y = ppermute_start(x) + z = x + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +We know that `x` is left unchanged by `ppermute_start` (that is, `x` is identical to `x2`), but XLA does not. In fact, it looks just like our `add_one_inplace` example to XLA, where it conservatively assumes that `ppermute_start` mutated `x` and `x2` is the new aliased result. Therefore, when we do `z = x + 1`, we run into a consumer of the original buffer. XLA therefore introduces a copy\! + +```py +def f(x): + x2 = copy(x) + *sems, x2, y = ppermute_start(x2) + z = x + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +This copy is unnecessary because we know that `x2` is unchanged from `x`. In order to remove this copy, we’d need some mechanism to inform XLA we are just forwarding a value. However, in the absence of that we can rewrite our program a bit to explicitly use `x2` instead of `x`.
+ +```py +def f(x): + *sems, x2, y = ppermute_start(x) + z = x2 + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +Now, XLA doesn’t see a separate consumer of `x` so no more copy is introduced. However, this comes at a major downside in that it forces us to unpack the future coming from `ppermute_start`. It couples the lifetime problem to the copying problem. + +#### Loop aliasing + +Let’s consider a slightly more advanced example. Let’s implement a function that uses a `while_loop` with `ppermute` to send values around a ring. + +```py +def f(x): + def body(i, x): + fut = ppermute_start(x) + y = ppermute_done(fut) + return y + return fori_loop(0, 8, body, x) +``` + +One implementation detail of `fori_loop` is that the inputs and outputs buffers are automatically aliased to each other. Note that we are setting up some additional aliasing in the `ppermute_start` and `ppermute_done` ops. Let’s run our own “buffer assignment” by coloring each of the values in the program to determine how many unique buffers we need. + +First, we’ll unpack the `fut` tuple that has the aliased `x` and `out` buffers. + +```py +def f(x): + def body(i, x): + *sems, x, y = ppermute_start(x) + y = ppermute_done(*sems, x, y) + return y + return fori_loop(0, 8, body, x) +``` + +Let’s now color each of the values according to the unique buffer they are assigned. We have the input/output aliasing coming from `fori_loop`, the `x` aliasing coming from `ppermute_start` and the `y` aliasing coming from `ppermute_done`. + +```py +def f(x): + def body(i, x): + *sems, x, y = ppermute_start(x) + y = ppermute_done((*sems, x, y)) + return y + return fori_loop(0, 8, body, x) +``` + +If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer re-use and defensively insert a copy. + +```py +def f(x): + def body(i, x): + x = copy(x) + *sems, x, y = ppermute_start(x) + y = ppermute_done((*sems, x, y)) + return y + return fori_loop(0, 8, body, x) +``` + +This copy means `x` and `y` are no longer aliased to each other and the program will be correct. However, do we need this copy? How do we introduce a double buffer to avoid expensive copies each iteration? The answer is unrolling\! + +We’ll manually unroll our code. + +```py +def f(x): + def body(i, x): + *sems, x, x2 = ppermute_start(x) + x2 = ppermute_done((*sems, x, x2)) + + *sems, x2, y = ppermute_start(x2) + y = ppermute_done((*sems, x2, y)) + return y + return fori_loop(0, 4, body, x) +``` + +Now if we were to run the same alias analysis, we’ll find that the buffers all no longer alias to each other and that we won’t need to insert defensive copies to be correct. + +Therefore, the simple solution to removing these copies is to use `fori_loop` with `unroll >= 2`. + +```py +def f(x): + def body(i, x): + fut = ppermute_start(x) + y = ppermute_done(fut) + return y + return fori_loop(0, 8, body, x, unroll=2) +``` + +That’s sufficient to implement this loop without extra copies\! + +#### Passing futures across loop boundaries + +Let’s now look at an even more advanced example. 
We’ll implement the same program as before but stagger the loop, where we begin the `ppermute` in a prologue before the loop, and wait on the `ppermute` at the beginning of the loop. + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + x = ppermute_done(fut) + fut = ppermute_start(x) + return fut + fut = fori_loop(0, 7, body, fut) + return ppermute_done(fut) +``` + +In this example, rather than passing a value `x` from one loop to another we are passing a future value. + +Let’s unpack the future again to see what’s happening. + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + *sems, x, out = fut + x = ppermute_done((*sems, x, out)) + (*sems, x, out) = ppermute_start(x) + return (*sems, x, out) + (*sems, x, out) = fori_loop(0, 7, body, x) + return ppermute_done((*sems, x, out)) +``` + +So we’re explicitly threading the semaphores, the input buffer, and the target output buffer as a loop carry. What happens if we run alias analysis now? Well, we’ll run into the same aliasing issue as in the previous section where `x` and `out` will be aliased to each other. XLA will introduce a copy. + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + *sems, x, out = fut + out = copy(out) + x = ppermute_done((*sems, x, out)) + (*sems, x, out) = ppermute_start(x) + return (*sems, x, out) + (*sems, x, out) = fori_loop(0, 7, body, x) + return ppermute_done((*sems, x, out)) +``` + +In this case, we inserted a copy on `out`. However, this is a really bad scenario because `out` is being actively copied into\! Even if we insert a copy on `x`, we will also run into issues because then `x`’s lifetime will not extend to the `ppermute_done`. This is very very bad\! We will not only get copies, but we will also get incorrect results\! + +The solution, as we observed before, is to avoid the copies by avoiding aliasing all the buffers via unrolling. So, if we do: + +```py +def f(x): + fut = ppermute_start(x) + def body(i, fut): + x = ppermute_done(fut) + fut = ppermute_start(x) + return fut + fut = fori_loop(0, 7, body, x, unroll=2) + return ppermute_done(fut) +``` + +our program should now be correct. + +### Putting it all together + +So we’ve come up with some rules of thumb: + +1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value. +2. Use `unroll >= 2` when doing `ppermute`s in a loop body. + +Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result. + +```py +def f(x): + out = jnp.zeros_like(x) + fut = (*sems, x, out) = ppermute_start(x) + out = out + x + def body(i, carry): + out, fut = carry + x = ppermute_done(fut) + fut = (*sems, x, out) = ppermute_start(x) + out = out + x + return out, fut + out, fut = fori_loop(0, 7, body, (out, fut), unroll=2) + return out, ppermute_done(fut) +``` + +Note that in this example, we don’t need `optimization_barrier`s because the loop boundary acts as a scheduling barrier, splitting up the `start`s and `done`s. + +That’s it, we are done\! This will be the official API for doing async ops in Pallas. Thank you everyone\! Mission accomplished\! + +***OR IS IT?*** + +## Revenge of the State + +While it seems we have worked around copies and incorrectness issues by using some clever tricks, we are still in an awkward position. This API is powerful, but has many many footguns and caveats. 
There are likely many more edge cases we will need to deal with that may even require deep knowledge of XLA to predict or understand. Should we release an API like this? Or is there an alternative? + +Well, the answer may have been in front of us this whole time. + +Let’s run through this whole exercise one more time, *except*, let’s write the stateful version. This means each of our custom async ops now operates on `Ref`s instead of values. + +```py +def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]: + ... + +def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None: + ... +``` + +Let’s assume we can implement these in Pallas and see what our new programs will look like. Let’s start with a basic collective permute: + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + fut = ppermute_start_stateful(x_ref, y_ref) + ppermute_done_stateful(*fut, x_ref, y_ref) + return y_ref[...] +``` + +It’s a little bit more verbose than our original value-based version, but it has a few key differences. The first is that we create an “empty” `Ref` to receive the result of the `ppermute`, unlike the value-based version, which creates a value for us. One neat thing is that the lifetime of `x_ref` is clear here: it lives until `ppermute_done_stateful`. We don’t need to “sneak” the `x` value into the op like we did before. + +Another difference becomes more clear when we try adding an op between the `start/done`. + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + fut = ppermute_start_stateful(x_ref, y_ref) + x_ref[...] += 1 + ppermute_done_stateful(*fut, x_ref, y_ref) + return y_ref[...] +``` + +Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved w.r.t. `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. + +The final key difference is evident when we try our loop examples. + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + def body(i, _): + fut = ppermute_start_stateful(x_ref, y_ref) + ppermute_done_stateful(*fut, x_ref, y_ref) + # Now switch to y_ref -> x_ref + fut = ppermute_start_stateful(y_ref, x_ref) + ppermute_done_stateful(*fut, y_ref, x_ref) + fori_loop(0, 8 // 2, body, None) + return x_ref[...] +``` + +Because of the requirement that we have a separate buffer ready to receive the `ppermute`, we were forced to write our code in such a way that unrolls it\! There is no way to write the version in XLA that requires copying because that would involve a `ppermute` that sends from a `Ref` into itself, which doesn’t really make sense. + +To handle this without the manual unrolling, we’d create a scratch buffer with a leading `2` dimension that acts as the send/recv target across iterations, switching each one. This is the same pattern we use internally in Pallas kernels when writing manually overlapped kernels. + +The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\! + +1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` w.r.t. each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. +2.
Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. +3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies. + +Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO. + +## Conclusion + +We’ve gone over some tricky challenges when it comes to async ops in Pallas and JAX. `Ref`s seem like a promising way of representing these ops that circumvents some of the issues that come up with value semantics. However, a downside is that it puts stateful JAX front and center, which we haven’t done yet outside of Pallas. It’s worth thinking whether we should educate users about stateful ops, or provide a more dangerous API. We also don’t know if everything we want to do is expressible via `Ref`s as well. We should also brainstorm alternatives to state to flesh out the design space. For example, what if XLA offered a first-class futures API that respected lifetimes, and it could automatically do things like double buffer loops with futures in them? That might be a viable alternative but the tradeoff would be giving more control to the compiler vs explicit control from the user. diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 467f375d0e43..5969349c962a 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -33,6 +33,13 @@ See also the :class:`jax.experimental.pallas` module API documentation. tpu/index .. toctree:: + :caption: Design Notes + :maxdepth: 1 + + async_note + +.. toctree:: + :caption: Other :maxdepth: 1 CHANGELOG diff --git a/jax/_src/api.py b/jax/_src/api.py index 8ca3803aec35..b548cc43fb3b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2726,7 +2726,8 @@ def clear_backends(): pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() @atexit.register @@ -2755,7 +2756,8 @@ def clear_caches(): util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 108741b5f8fd..2335103939ff 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -53,6 +53,17 @@ ), ) +_ENABLE_COMPILER_REMAT_OPTIMIZATION_PASS = config.bool_flag( + "jax_compiler_enable_remat_pass", + config.bool_env('JAX_COMPILER_ENABLE_REMAT_PASS', False), + help=( + 'Config to enable the rematerialization HLO pass. ' + 'Useful to allow XLA to automatically trade off memory and ' + 'compute when encountering OOM errors. 
However, you are ' + 'likely to get better results manually with jax.checkpoint' + ) +) + # The special XLA-AutoFDO profile version that indicates that a profile is not # available and retrieval should not be attempted. _NO_PROFILE_DONT_RETRIEVE = -1 @@ -199,6 +210,9 @@ def get_compile_options( debug_options.xla_backend_optimization_level = 0 debug_options.xla_llvm_disable_expensive_passes = True debug_options.xla_test_all_input_layouts = False + + if not _ENABLE_COMPILER_REMAT_OPTIMIZATION_PASS.value: + debug_options.xla_disable_hlo_passes = "rematerialization" # XLA-AutoFDO profile version: precedence order is: # 1. Whatever --jax_xla_profile_version is set to. diff --git a/jax/_src/core.py b/jax/_src/core.py index 74d03b8d9464..51933a9f8bbf 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -343,8 +343,7 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, - xla_metadata_lib.current_xla_metadata(), - ) + xla_metadata_lib.current_xla_metadata()) if config.enable_checks.value: assert all(isinstance(x, (Var, Literal)) for x in invars) assert all(isinstance(v, Var) for v in outvars) diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index f6b0a81baf92..1efeed73cbc8 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -112,6 +112,11 @@ def from_frameinfo(cls, frame_info) -> DebuggerFrame: # then we subtract it off from the `lineno` and don't need to subtract 1 # since both start and lineno are 1-indexed. offset = frame_info.lineno - max(start, 1) + if offset >= len(source): + # Sometimes we don't get a valid source/offset pair. This seems to + # happen sometimes when code uses eval(). If that happens, give up. 
+ source = [] + offset = None except OSError: source = [] offset = None diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 3e7082ab10ec..3373496940e2 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding +from jax._src.api_util import shaped_abstractify from jax._src.state import discharge as state_discharge logger = logging.getLogger(__name__) @@ -256,12 +257,29 @@ def debug_callback(callback: Callable[..., None], *args: Any, raise TypeError("first argument to jax.debug.callback must be callable, " f"but got an object of type {type(callback)}") flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - effect = ordered_debug_effect if ordered else debug_effect - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) + static_args, dyn_args = {}, [] + for i, a in enumerate(flat_args): + try: + shaped_abstractify(a) + dyn_args.append(a) + except (AssertionError, TypeError): + static_args[i] = a + + def _flat_callback(*dyn_args): + all_args = [None] * (len(static_args) + len(dyn_args)) + di = iter(dyn_args) + for i in range(len(all_args)): + if i in static_args: + all_args[i] = static_args[i] + else: + all_args[i] = next(di) + assert next(di, None) is None + args, kwargs = tree_util.tree_unflatten(in_tree, all_args) callback(*args, **kwargs) return () - debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect) + + effect = ordered_debug_effect if ordered else debug_effect + debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect) class _DebugPrintFormatChecker(string.Formatter): diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 10850357f677..5f1d132bcbb3 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -132,3 +132,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') +register('jax-numpy-trimzeros-not-1d-array') diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 2d27bf064fce..374816e001ec 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2828,8 +2828,7 @@ def inline_jaxpr_into_trace( outvars = [Var('', v.aval) for v in eqn.outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive, - eqn.params, eqn.effects, src_)) + trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore map(env.setdefault, eqn.outvars, outvars) tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b7d68f73c2a4..944e20fa7faa 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -22,6 +22,7 @@ from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property +import functools import itertools as it import logging import math @@ -61,6 +62,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib 
import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -88,6 +90,7 @@ class WeakRefList(list): logger = logging.getLogger(__name__) Index = Union[int, slice, tuple[Union[int, slice], ...]] +PyTreeDef = tree_util.PyTreeDef NoSharding = sharding_specs.NoSharding Chunked = sharding_specs.Chunked @@ -2904,6 +2907,34 @@ class MeshExecutableFastpathData(NamedTuple): in_device_local_layouts: Sequence[DeviceLocalLayout | None] +@dataclasses.dataclass(frozen=True, kw_only=True) +class JitGlobalCppCacheKeys: + donate_argnums: tuple[int, ...] | None = None + donate_argnames: tuple[str, ...] | None = None + device: xc.Device | None = None + backend: str | None = None + in_shardings_treedef: PyTreeDef | None = None + in_shardings_leaves: tuple[Any, ...] | None = None + out_shardings_treedef: PyTreeDef | None = None + out_shardings_leaves: tuple[Any, ...] | None = None + in_layouts_treedef: PyTreeDef | None = None + in_layouts_leaves: tuple[Any, ...] | None = None + out_layouts_treedef: PyTreeDef | None = None + out_layouts_leaves: tuple[Any, ...] | None = None + use_resource_env: bool = False + + @functools.cached_property + def contains_explicit_attributes(self): + return (self.donate_argnums is not None or + self.donate_argnames is not None or + self.device is not None or + self.backend is not None or + any(not is_unspecified(i) for i in self.in_shardings_leaves) or + any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(i is not None for i in self.in_layouts_leaves) or + any(o is not None for o in self.out_layouts_leaves)) + + def reflatten_outputs_for_dispatch(out_tree, out_flat): # We arrive at dispatch having flattened according to the default # pytree registry, but we want to re-flatten according to our @@ -3017,9 +3048,14 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + if xla_extension_version >= 286: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) + else: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): return shard_args([sharding], [layout], [x])[0] diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8d2c24d6e64c..e356756cd3e1 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2012,9 +2012,19 @@ def _cos_lowering(ctx, x): def _tan_impl(x): return div(sin(x), cos(x)) +def _tan_lowering(ctx, x): + # TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this + # lowering is supported, but export doesn't target a sufficiently up-to-date + # StableHLO version, and the compatibility updates from + # https://github.com/openxla/xla/pull/16649 aren't included in the 0.4.33 + # release. 
+ if ctx.is_forward_compat(): + return _nary_lower_hlo(chlo.tan, ctx, x) + return _nary_lower_hlo(hlo.tan, ctx, x) + tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) +mlir.register_lowering(tan_p, _tan_lowering) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index a5b5aaf31799..c1f4831e5ec0 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -785,6 +785,14 @@ def _get_causal_mask(T, S): mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) return mask[None, None, :, :] +def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]): + query_pos = jnp.array(range(T)) + key_pos = jnp.array(range(S)) + left_window, right_window = local_window_size + left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window + right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window + return jnp.logical_and(right_mask, left_mask)[None, None, :, :] + def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen): q_mask = True kv_mask = True @@ -802,7 +810,8 @@ def _get_padding_mask_encoded(T, q_seqlen): mask = q_indices < q_seqlen[:, None] return mask[:, :, None, None] -def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): +def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, + local_window_size): if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None: return logits @@ -817,6 +826,10 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): mask = _get_causal_mask(T, S) combined_mask = jnp.logical_and(combined_mask, mask) + if local_window_size is not None: + mask = _get_window_mask(T, S, local_window_size) + combined_mask = jnp.logical_and(combined_mask, mask) + if q_seqlen is not None or kv_seqlen is not None: mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen) combined_mask = jnp.logical_and(combined_mask, mask) @@ -826,7 +839,7 @@ def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen): return padded_logits def _dot_product_attention_core(query, key, value, bias, mask, is_causal, - scale, q_seqlen, kv_seqlen): + scale, q_seqlen, kv_seqlen, local_window_size): logits_dtype = jnp.promote_types(query.dtype, jnp.float32) logits = jnp.einsum('BTNH,BSNH->BNTS', query, key, preferred_element_type=logits_dtype) @@ -836,7 +849,8 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, if bias is not None: logits = (logits + bias).astype(logits.dtype) - padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen) + padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, + local_window_size) # Softmax and it is always carried out in fp32. 
padded_logits = padded_logits.astype(jnp.float32) @@ -857,7 +871,8 @@ def _dot_product_attention_xla( is_causal: bool, scale: float, q_seqlen: Array | None, - kv_seqlen: Array | None): + kv_seqlen: Array | None, + local_window_size: tuple[int, int] | None): B, T, N, H = query.shape _, S, K, _ = key.shape @@ -875,11 +890,13 @@ def _reshape_to_grouped(t): return t bias = _reshape_to_grouped(bias) mask = _reshape_to_grouped(mask) - vmapped_fn = jax.vmap(_dot_product_attention_core, - in_axes=(3, None, None, 2, 2, None, None, None, None), - out_axes=3) + vmapped_fn = jax.vmap( + _dot_product_attention_core, + in_axes=(3, None, None, 2, 2, None, None, None, None, None), + out_axes=3, + ) encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale, - q_seqlen, kv_seqlen) + q_seqlen, kv_seqlen, local_window_size) encoded = jnp.reshape(encoded, (B, T, N, H)) return encoded @@ -894,6 +911,7 @@ def dot_product_attention( is_causal: bool = False, query_seq_lengths: ArrayLike | None = None, key_value_seq_lengths: ArrayLike | None = None, + local_window_size: int | tuple[int, int] | None = None, implementation: Literal['xla', 'cudnn'] | None = None) -> Array: r"""Scaled dot product attention function. @@ -943,6 +961,12 @@ def dot_product_attention( :code:`(B)` key_value_seq_lengths: `int32` array of sequence lengths for key and value; shape :code:`(B)` + local_window_size: Window sizes to make self attention to attend to each + token's local window. If set, this specifies the (left_window_size, + right_window_size) for each token. E.g., if local_window_size == (3, 2) + and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend + to [3, 4, 5, c, 7, 8]. If a single int is given, it will be intepreted as + a symmetric window (window_size, window_size). implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults to `None`, which will automatically select the best available backend. @@ -969,6 +993,8 @@ def _ensure_4d(t): query_seq_lengths = jnp.asarray(query_seq_lengths) if key_value_seq_lengths is not None: key_value_seq_lengths = jnp.asarray(key_value_seq_lengths) + if isinstance(local_window_size, int): + local_window_size = (local_window_size, local_window_size) def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], dtype: DType | None, name: str) -> None: @@ -1003,6 +1029,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, q_seqlen=query_seq_lengths, kv_seqlen=key_value_seq_lengths, + local_window_size=local_window_size, ) case 'cudnn': use_padding = ( @@ -1022,9 +1049,21 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], mask_type = MaskType.CAUSAL elif use_padding: mask_type = MaskType.PADDING + # CuDNN supports only the left window with an exclusive boundary when + # causal mask is enabled. 
+ sliding_window = None + if local_window_size is not None: + l_window, r_window = local_window_size + if r_window == 0 or mask_type == MaskType.CAUSAL: + sliding_window = l_window + 1 + else: + raise ValueError(f"cuDNN doesn't support right window: {r_window} " + "when causal mask is not used.") + out = cudnn_dot_product_attention( query_arr, key_arr, value_arr, bias, mask, query_seq_lengths, - key_value_seq_lengths, scale=scale_val, mask_type=mask_type + key_value_seq_lengths, scale=scale_val, mask_type=mask_type, + sliding_window_length=sliding_window, ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select @@ -1033,6 +1072,7 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, scale=scale_val, q_seqlen=query_seq_lengths, kv_seqlen=key_value_seq_lengths, + local_window_size=local_window_size, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 2b42f22c9fd3..47a36a2b83ff 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -7018,7 +7018,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res -def trim_zeros(filt, trim='fb'): +def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: """Trim leading and/or trailing zeros of the input array. JAX implementation of :func:`numpy.trim_zeros`. @@ -7040,14 +7040,26 @@ def trim_zeros(filt, trim='fb'): >>> jnp.trim_zeros(x) Array([2, 0, 1, 4, 3], dtype=int32) """ - filt = core.concrete_or_error(asarray, filt, - "Error arose in the `filt` argument of trim_zeros()") - nz = (filt == 0) + # Non-array inputs are deprecated 2024-09-11 + util.check_arraylike("trim_zeros", filt, emit_warning=True) + core.concrete_or_error(None, filt, + "Error arose in the `filt` argument of trim_zeros()") + filt_arr = jax.numpy.asarray(filt) + del filt + if filt_arr.ndim != 1: + # Added on 2024-09-11 + if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"): + raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.") + warnings.warn( + "Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it " + "works with Arrays having ndim != 1. In the future this will result in an error.", + DeprecationWarning, stacklevel=2) + nz = (filt_arr == 0) if reductions.all(nz): - return empty(0, _dtype(filt)) - start = argmin(nz) if 'f' in trim.lower() else 0 - end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 - return filt[start:len(filt) - end] + return empty(0, filt_arr.dtype) + start: Array | int = argmin(nz) if 'f' in trim.lower() else 0 + end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0 + return filt_arr[start:len(filt_arr) - end] def trim_zeros_tol(filt, tol, trim='fb'): @@ -7397,20 +7409,17 @@ def dot(a: ArrayLike, b: ArrayLike, *, batch_dims = ((), ()) a_ndim, b_ndim = ndim(a), ndim(b) if a_ndim == 0 or b_ndim == 0: - # TODO(jakevdp): lower this case to dot_general as well? 
- # Currently, doing so causes issues in remat tests due to #16805 - if preferred_element_type is not None: - a = a.astype(preferred_element_type) - b = b.astype(preferred_element_type) - result = lax.mul(a, b) + contract_dims: tuple[tuple[int, ...], tuple[int, ...]] = ((), ()) else: if b_ndim == 1: contract_dims = ((a_ndim - 1,), (0,)) else: contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) - result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), - precision=precision, preferred_element_type=preferred_element_type) - return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) + result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), + precision=precision, + preferred_element_type=preferred_element_type) + return lax_internal._convert_element_type(result, preferred_element_type, + output_weak_type) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) @@ -8819,11 +8828,40 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result -@util.implements(np.sort_complex) @jit def sort_complex(a: ArrayLike) -> Array: + """Return a sorted copy of complex array. + + JAX implementation of :func:`numpy.sort_complex`. + + Complex numbers are sorted lexicographically, meaning by their real part + first, and then by their imaginary part if real parts are equal. + + Args: + a: input array. If dtype is not complex, the array will be upcast to complex. + + Returns: + A sorted array of the same shape and complex dtype as the input. If ``a`` + is multi-dimensional, it is sorted along the last axis. + + See also: + - :func:`jax.numpy.sort`: Return a sorted copy of an array. + + Examples: + >>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j]) + >>> jnp.sort_complex(a) + Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64) + + Multi-dimensional arrays are sorted along the last axis: + + >>> a = jnp.array([[5, 3, 4], + ... [6, 9, 2]]) + >>> jnp.sort_complex(a) + Array([[3.+0.j, 4.+0.j, 5.+0.j], + [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64) + """ util.check_arraylike("sort_complex", a) - a = lax.sort(asarray(a), dimension=0) + a = lax.sort(asarray(a)) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) @util.implements(np.lexsort) @@ -10770,11 +10808,46 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) # type: ignore -@util.implements(np.digitize, lax_description=_dedent(""" - Optionally, the ``method`` argument can be used to configure the - underlying :func:`jax.numpy.searchsorted` algorithm.""")) + @partial(jit, static_argnames=('right', 'method')) -def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str = 'scan') -> Array: +def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, + *, method: str | None = None) -> Array: + """Convert an array to bin indices. + + JAX implementation of :func:`numpy.digitize`. + + Args: + x: array of values to digitize. + bins: 1D array of bin edges. Must be monotonically increasing or decreasing. + right: if true, the intervals include the right bin edges. If false (default) + the intervals include the left bin edges. + method: optional method argument to be passed to :func:`~jax.numpy.searchsorted`. + See that function for available options. + + Returns: + An integer array of the same shape as ``x`` indicating the bin number that + the values are in. 
+ + See also: + - :func:`jax.numpy.searchsorted`: find insertion indices for values in a + sorted array. + - :func:`jax.numpy.histogram`: compute frequency of array values within + specified bins. + + Examples: + >>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5]) + >>> bins = jnp.array([1, 2, 3]) + >>> jnp.digitize(x, bins) + Array([1, 2, 2, 1, 3, 3], dtype=int32) + >>> jnp.digitize(x, bins, right=True) + Array([0, 1, 2, 1, 2, 3], dtype=int32) + + ``digitize`` supports reverse-ordered bins as well: + + >>> bins = jnp.array([3, 2, 1]) + >>> jnp.digitize(x, bins) + Array([2, 1, 1, 2, 0, 0], dtype=int32) + """ util.check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") bins_arr = asarray(bins) @@ -10783,10 +10856,11 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, *, method: str if bins_arr.shape[0] == 0: return zeros_like(x, dtype=int32) side = 'right' if not right else 'left' + kwds: dict[str, str] = {} if method is None else {'method': method} return where( bins_arr[-1] >= bins_arr[0], - searchsorted(bins_arr, x, side=side, method=method), - bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, method=method) + searchsorted(bins_arr, x, side=side, **kwds), + bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) ) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 56c47b9401cc..f8ec3b63339a 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -659,9 +659,10 @@ def slice_scratch_ops(self): @property def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: """The shapes of *index, *inputs.""" - index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape, - ia.inner_aval.dtype) - for ia in self.index_map_avals[len(self.grid):]) + index_shapes = ( + jax.ShapeDtypeStruct(ia.shape, ia.dtype) + for ia in self.index_map_avals[len(self.grid) :] + ) inputs_shapes = ( bm.array_shape_dtype for bm in self.block_mappings[:self.num_inputs]) @@ -728,7 +729,16 @@ def _convert_block_spec_to_block_mapping( index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) -@dataclasses.dataclass(init=False) + +class ScratchShape(Protocol): + def get_aval(self) -> jax_core.AbstractValue: + ... + + +ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]] + + +@dataclasses.dataclass(init=False, kw_only=True) class GridSpec: """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`. @@ -741,12 +751,14 @@ class GridSpec: grid_names: tuple[Hashable, ...] | None in_specs: BlockSpecTree out_specs: BlockSpecTree + scratch_shapes: ScratchShapeTree = () def __init__( self, grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), ): # Be more lenient for in/out_specs if isinstance(in_specs, list): @@ -758,6 +770,7 @@ def __init__( self.in_specs = in_specs self.out_specs = out_specs + self.scratch_shapes = tuple(scratch_shapes) grid_names = None if isinstance(grid, int): @@ -773,9 +786,6 @@ def __init__( self.grid = grid # type: ignore self.grid_names = grid_names - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - assert False # Not needed in GridSpec - def _make_scalar_ref_aval(self, aval): assert False # Not needed in GridSpec @@ -820,12 +830,10 @@ def get_grid_mapping( else: num_flat_scalar_prefetch = 0 jaxpr_scalar_ref_avals = () - - scratch_shapes: tuple[Any, ...] 
= getattr(grid_spec, "scratch_shapes", ()) - if scratch_shapes: + if grid_spec.scratch_shapes: flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( - scratch_shapes) - flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes) + grid_spec.scratch_shapes) + flat_scratch_avals = map(lambda s: s.get_aval(), flat_scratch_shapes) num_flat_scratch_operands = len(flat_scratch_avals) jaxpr_scratch_avals = tree_util.tree_unflatten( scratch_tree, flat_scratch_avals) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 61b1dc435e72..b2b892a64f90 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,7 +19,7 @@ import dataclasses import enum import functools -from typing import Any, ClassVar, Hashable, Literal +from typing import Any, ClassVar, Literal import jax from jax._src import core as jax_core @@ -39,6 +39,7 @@ BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec +ScratchShapeTree = pallas_core.ScratchShapeTree AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping @@ -174,14 +175,9 @@ def get_aval(self) -> AbstractMemoryRef: jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) -@dataclasses.dataclass(init=False, unsafe_hash=True) +@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): - grid: TupleGrid - grid_names: tuple[Hashable, ...] | None num_scalar_prefetch: int - in_specs: pallas_core.BlockSpecTree - out_specs: pallas_core.BlockSpecTree - scratch_shapes: tuple[Any, ...] def __init__( self, @@ -189,9 +185,9 @@ def __init__( grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, - scratch_shapes: Any | Sequence[Any] = () + scratch_shapes: ScratchShapeTree = () ): - super().__init__(grid, in_specs, out_specs) + super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) @@ -199,14 +195,6 @@ def _make_scalar_ref_aval(self, aval): return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), TPUMemorySpace.SMEM) - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, MemoryRef): - return obj.get_aval() - if isinstance(obj, SemaphoreType): - return obj.get_aval() - raise ValueError(f"No registered conversion for {type(obj)}. " - "Only VMEM and SemaphoreType are supported.") - @dataclasses.dataclass(frozen=True) class TensorCore: diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 13d861033e90..f76a4d86616a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2737,6 +2737,9 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int): def _debug_print_rule( ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool ): + if any(aval.shape for aval in ctx.avals_in): + raise NotImplementedError("Only scalar values are supported") + primitives.check_debug_print_format(fmt, *args) if has_placeholders: if not all( diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index fca9ee471e6a..e8f2384784eb 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -189,7 +189,7 @@ class BufferedRef: dtype: dtype for buffers. 
buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. - vmem_ref: a double-buffer to hold a working buffer and a dirty buffer used + window_ref: a double-buffer to hold a working buffer and a dirty buffer used to copy into and out of. In the case of a BufferedRef targeting a VMEM reference, this simply points to the existing ref. accum_ref: accumulating buffer used by accumulator BufferedRefs. @@ -210,7 +210,7 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - vmem_ref: REF | None + window_ref: REF | None accum_ref: REF | None current_slot: ArrayRef | None next_slot: ArrayRef | None @@ -218,9 +218,17 @@ class BufferedRef: sem_sends: SemaphoreTuple | None def tree_flatten(self): - return ((self.vmem_ref, self.accum_ref, self.current_slot, - self.next_slot, self.sem_recvs, self.sem_sends), - (self.spec, self.dtype, self.buffer_type)) + return ( + ( + self.window_ref, + self.accum_ref, + self.current_slot, + self.next_slot, + self.sem_recvs, + self.sem_sends, + ), + (self.spec, self.dtype, self.buffer_type), + ) @classmethod def tree_unflatten(cls, meta, data): @@ -252,7 +260,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: spec=spec, dtype=dtype, buffer_type=buffer_type, - vmem_ref=None, # to be bound to existing ref by the pipeline routine + window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, current_slot=None, next_slot=None, @@ -260,11 +268,12 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: sem_sends=None, ) else: + memory_space = SMEM if spec.memory_space == SMEM else VMEM return cls( spec=spec, dtype=dtype, buffer_type=buffer_type, - vmem_ref=VMEM((2,) + block_shape, dtype), + window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), next_slot=SMEM((1,), jnp.int32), @@ -313,9 +322,9 @@ def current_ref(self): buffer_slice = tuple( 0 if x is None else slice(None) for x in self.block_shape) if self.memory_space == VMEM: - return self.vmem_ref.at[buffer_slice] + return self.window_ref.at[buffer_slice] else: - return self.vmem_ref.at[(self.current_slot[0], *buffer_slice)] + return self.window_ref.at[(self.current_slot[0], *buffer_slice)] @property def is_input(self): @@ -341,11 +350,12 @@ def is_accumulator(self): def is_input_output(self): return self.buffer_type == BufferType.INPUT_OUTPUT - def bind_existing_ref(self, vmem_ref, indices): + def bind_existing_ref(self, window_ref, indices): """For handling VMEM references, the pipeline aliases the existing ref.""" if self.memory_space == VMEM: return dataclasses.replace( - self, vmem_ref=vmem_ref.at[self.compute_slice(indices)]) + self, window_ref=window_ref.at[self.compute_slice(indices)] + ) return self def compute_slice(self, grid_indices): @@ -432,8 +442,9 @@ def copy_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.vmem_ref.at[next_slot].at[dst_slice], - self.sem_recvs.at[next_slot]).start() + self.window_ref.at[next_slot].at[dst_slice], + self.sem_recvs.at[next_slot], + ).start() def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" @@ -444,9 +455,10 @@ def copy_out(self, dst_ref, grid_indices): dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) 
tpu_primitives.make_async_copy( - self.vmem_ref.at[slot].at[src_slice], + self.window_ref.at[slot].at[src_slice], dst_ref.at[dst_slice], - self.sem_sends.at[slot]).start() + self.sem_sends.at[slot], + ).start() def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" @@ -456,9 +468,12 @@ def wait_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) current_slot = self.current_slot[0] tpu_primitives.make_async_copy( - src_ref.at[src_slice], # nb: doesn't matter - self.vmem_ref.at[current_slot].at[dst_slice], # only dst shape is important - self.sem_recvs.at[current_slot]).wait() + src_ref.at[src_slice], # nb: doesn't matter + self.window_ref.at[current_slot].at[ + dst_slice + ], # only dst shape is important + self.sem_recvs.at[current_slot], + ).wait() def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" @@ -468,9 +483,10 @@ def wait_out(self, dst_ref, grid_indices): dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter - dst_ref.at[dst_slice], # only dst shape is important - self.sem_sends.at[prev_slot]).wait() + self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + dst_ref.at[dst_slice], # only dst shape is important + self.sem_sends.at[prev_slot], + ).wait() # Accumulator methods # @@ -498,14 +514,14 @@ def accumulate(self): assert self.is_accumulator if self.accum_ref is not None: accum_dtype = jnp.float32 - if self.vmem_ref.dtype == jnp.int32: + if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 # TODO(levskaya): we could generalize init and reduction functions, # could it ever be useful to support more generic monoids? self.current_ref[...] = ( - self.current_ref[...].astype(accum_dtype) + - self.accum_ref[...].astype(accum_dtype) - ).astype(self.vmem_ref.dtype) + self.current_ref[...].astype(accum_dtype) + + self.accum_ref[...].astype(accum_dtype) + ).astype(self.window_ref.dtype) # Helper to tree map over BufferedRefs as leaves. 
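The BufferedRef changes above rename `vmem_ref` to `window_ref` and let the window live in SMEM as well as VMEM. For orientation, here is a minimal, self-contained sketch of the two-slot double-buffering idea that `window_ref`, `current_slot` and `next_slot` implement (plain NumPy with hypothetical names, not Pallas code): one slot is consumed by compute while the other is filled by the next copy.

import numpy as np

def pipelined_double(blocks):
    # `window` plays the role of window_ref: a (2, *block_shape) double buffer.
    window = np.zeros((2, *blocks[0].shape), blocks[0].dtype)
    current = 0
    window[current] = blocks[0]                # prologue: copy_in of block 0
    outs = []
    for step in range(len(blocks)):
        nxt = 1 - current                      # next_slot
        if step + 1 < len(blocks):
            window[nxt] = blocks[step + 1]     # start copy_in of the next block
        outs.append(window[current] * 2.0)     # compute on the current slot
        current = nxt                          # swap slots
    return outs

print(pipelined_double([np.ones((2, 2)), np.full((2, 2), 3.0)])[1])  # [[6. 6.] [6. 6.]]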
diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 11258f741b7f..1bd512834ce5 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -17,7 +17,6 @@ from jax._src.pallas.mosaic_gpu.core import Barrier from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUGridSpec from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 3ef205d336d0..34ad5acf34d6 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -82,7 +82,7 @@ def __init__(self, tiling: tuple[int, ...]): def __call__( self, block_aval: pallas_core.AbstractMemoryRef ) -> pallas_core.AbstractMemoryRef: - block_shape = block_aval.inner_aval.shape # pytype: disable=attribute-error + block_shape = block_aval.shape old_tiled_dims = block_shape[-len(self.tiling) :] num_tiles = tuple( block_dim // tiling_dim @@ -150,26 +150,6 @@ def to_block_mapping( ) -@dataclasses.dataclass(init=False, kw_only=True) -class GPUGridSpec(pallas_core.GridSpec): - scratch_shapes: Sequence[Any] - - def __init__( - self, - grid: pallas_core.Grid = (), - in_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, - out_specs: pallas_core.BlockSpecTree = pallas_core.no_block_spec, - scratch_shapes: Sequence[Any] = () - ): - super().__init__(grid, in_specs, out_specs) - self.scratch_shapes = tuple(scratch_shapes) - - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, (MemoryRef, Barrier)): - return obj.get_aval() - raise TypeError(f"Cannot convert {obj} to an abstract value") - - # TODO(b/354568887): Cosolidate this with TPU's MemoryRef. @dataclasses.dataclass(frozen=True) class MemoryRef: diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 39483e674681..1d76dc8405d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -268,10 +268,7 @@ def lower_jaxpr_to_module( for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] ] in_structs_smem = [ - jax.ShapeDtypeStruct( - [num_stages, *bm.ref_aval.inner_aval.shape], - bm.ref_aval.inner_aval.dtype, - ) + jax.ShapeDtypeStruct([num_stages, *bm.ref_aval.shape], bm.ref_aval.dtype) if in_smem else None for bm, in_smem in zip( @@ -693,8 +690,9 @@ def _debug_print_lowering_rule( fmt, has_placeholders: bool, ): - del ctx - del has_placeholders + del has_placeholders # Unused. 
+ if any(aval.shape for aval in ctx.avals_in): + raise NotImplementedError("Only scalar values are supported") primitives.check_debug_print_format(fmt, *args) mgpu.debug_print(fmt, *args) return () diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index b69fb03f0951..206c0cdee876 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -62,6 +62,7 @@ BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec +ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate # See the docstring for GridMapping for the calling convention @@ -1233,6 +1234,7 @@ def pallas_call( grid: TupleGrid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), input_output_aliases: dict[int, int] = {}, debug: bool = False, interpret: bool = False, @@ -1250,8 +1252,9 @@ def pallas_call( corresponding ``in_specs`` and ``out_specs``. out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape and dtypes of the outputs. - grid_spec: An alternative way to specify ``grid``, ``in_specs``, and - ``out_specs``. If given, those other parameters must not be also given. + grid_spec: An alternative way to specify ``grid``, ``in_specs``, + ``out_specs`` and ``scratch_shapes``. If given, those other parameters + must not be also given. grid: the iteration space, as a tuple of integers. The kernel is executed as many times as ``prod(grid)``. See details at :ref:`pallas_grid`. @@ -1265,6 +1268,9 @@ def pallas_call( The default value for ``out_specs`` specifies the whole array, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``. See details at :ref:`pallas_blockspec`. + scratch_shapes: a PyTree of backend-specific temporary objects required + by the kernel, such as temporary buffers, synchronization primitives, + etc. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. @@ -1305,7 +1311,7 @@ def pallas_call( } if grid_spec is None: - grid_spec = GridSpec(grid, in_specs, out_specs) + grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) else: if grid: raise ValueError( @@ -1319,6 +1325,10 @@ def pallas_call( raise ValueError( "If `grid_spec` is specified, then `out_specs` must " f"be `no_block_spec`. It is {out_specs}") + if scratch_shapes: + raise ValueError( + "If `grid_spec` is specified, then `scratch_shapes` must " + f"be `()`. It is {scratch_shapes}") del grid, in_specs, out_specs grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index fbc389aae3fb..8cba0a36c6e4 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -714,7 +714,7 @@ class PrintEffect(effects.Effect): def debug_print(fmt: str, *args: jax.typing.ArrayLike): - """Prints scalar values from inside a Pallas kernel. + """Prints values from inside a Pallas kernel. Args: fmt: A format string to be included in the output. The restrictions on the @@ -724,11 +724,11 @@ def debug_print(fmt: str, *args: jax.typing.ArrayLike): (``{...}``), since it is always printed before any of the values. 
* On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must contain a placeholder for each value to be printed. Format specs and - conversions are not supported. + conversions are not supported. All values must be scalars. * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit integers. If there are no placeholders, the values are printed after - the format string. - *args: The scalar values to print. + the format string. All values must be scalars. + *args: The values to print. """ # fmt: skip has_placeholders = False if fmt: @@ -771,9 +771,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): @debug_print_p.def_effectful_abstract_eval def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool): - del fmt, has_placeholders - if any(aval.shape for aval in avals): - raise ValueError("Only scalar values are supported") + del avals, fmt, has_placeholders # Unused. return [], {debug_print_effect} diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 5e495f4bef3e..0a23e512dfb3 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -277,6 +277,10 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scalar prefetch not implemented in the Triton backend" ) + if jaxpr.invars[grid_mapping.slice_scratch_ops]: + raise NotImplementedError( + "scratch memory not implemented in the Triton backend" + ) with grid_mapping.trace_env(): jaxpr, _ = pe.dce_jaxpr( jaxpr, [True] * len(jaxpr.outvars), instantiate=True @@ -1202,7 +1206,14 @@ def debug_print_lowering_rule( "pl.debug_print() does not support placeholders when lowering to Triton" ) - tt_dialect.print_(f" {fmt} ", hex=False, args=args) + tt_dialect.print_( + f" {fmt} ", + hex=False, + args=args, + is_signed=ir.DenseI32ArrayAttr.get([ + jnp.issubdtype(aval.dtype, jnp.signedinteger) for aval in ctx.avals_in + ]), + ) return () diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fb76f7931c01..34bf257f639e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -62,6 +62,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -164,7 +165,6 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool abstracted_axes: Any | None - has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit # Hash and compare PjitInfo by identity when used as a cache key. @@ -311,14 +311,39 @@ def _cpp_pjit_evict_fn(self): # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. -_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is only used for jit's with only fun. For example: jax.jit(f) +_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is used for jit where extra arguments are defined other than the +# fun. For example: jax.jit(f, donate_argnums=...) OR +# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the +# capacity might get full very fast because of all the jitted function in JAX +# which might evict train_step for example. 
+_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache + +if xla_extension_version < 286: + def _get_cpp_global_cache(pjit_has_explicit_sharding): + if pjit_has_explicit_sharding: + return xc._xla.PjitFunctionCache() + else: + return _cpp_pjit_cache_fun_only + + def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) +else: + def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -339,11 +364,35 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), - fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(jit_info.has_explicit_sharding)) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, + jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, + jit_info.device, jit_info.backend) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, jit_info.donate_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -351,17 +400,6 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) - - def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, 
is_leaf=lambda x: x is None) layouts, shardings = [], [] @@ -445,10 +483,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, - out_layouts_leaves, device, backend) - return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -466,7 +500,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - has_explicit_sharding=has_explicit_sharding, use_resource_env=use_resource_env) @@ -1706,13 +1739,27 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) - donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) @@ -1753,13 +1800,11 @@ def pjit_staging_rule(trace, *args, **params): params['jaxpr'], params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) - if (params["inline"] and all(is_unspecified(i) for i in params["in_shardings"]) and all(is_unspecified(o) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): - if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. 
If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 6a912abf215b..4231822965b1 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -516,8 +516,7 @@ def eval_jaxpr(*refs): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( eval_jaxpr, [*in_avals, *res_ref_avals]) assert not consts - return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error - for a in res_ref_avals] + return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: assert not jaxpr.constvars, "Jaxpr should not have constvars" diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 05368e978593..8289f858498b 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -196,15 +196,21 @@ def join(self, other): @property def shape(self): - if not isinstance(self.inner_aval, core.ShapedArray): - raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.") - return self.inner_aval.shape + try: + return self.inner_aval.shape # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`." + ) from None @property def dtype(self): - if not isinstance(self.inner_aval, core.UnshapedArray): - raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.") - return self.inner_aval.dtype + try: + return self.inner_aval.dtype # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`." + ) from None @core.aval_property def at(self): diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 2e2941fca5b1..0e263844b18e 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -27,6 +27,7 @@ import tempfile import time from typing import Any, Generic, TypeVar +import weakref import jax from jax._src import config @@ -800,6 +801,21 @@ def main(token_ptr, buffers): return module, out_shape, unwrap_output_tuple +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) + + def as_gpu_kernel( body, grid: tuple[int, int, int], @@ -867,16 +883,97 @@ def kernel(*args): return kernel -def _declare_runtime_functions(): - """Declares the runtime functions that can be used by the generated code.""" - ptr_ty = ir.Type.parse("!llvm.ptr") - i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] - init_tma_desc_type = ir.FunctionType.get(arg_tys, []) - func.FuncOp( - "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" - ) - memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) - func.FuncOp( - "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" +def as_torch_gpu_kernel( + body, + grid: 
tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + try: + import torch + except ImportError: + raise RuntimeError("as_torch_gpu_kernel requires PyTorch") + torch.cuda.init() # Make sure CUDA context is set up. + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + flat_out_types, out_treedef = jax.tree.flatten(out_shape) + expected_arg_treedef = jax.tree.structure(in_shape) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) ) + + # Get our hands on the compilation and unload functions + try: + import jax_plugins.xla_cuda12 as cuda_plugin + except ImportError: + raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " + "that use backend plugins") + dll = ctypes.CDLL(cuda_plugin._get_library_path()) + compile_func = dll.MosaicGpuCompile + compile_func.argtypes = [ctypes.c_void_p] + compile_func.restype = ctypes.POINTER(ctypes.c_void_p) + unload_func = dll.MosaicGpuUnload + unload_func.argtypes = [compile_func.restype] + unload_func.restype = None + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + compiled = compile_func(ctypes.c_char_p(module_asm)) + if compiled is None: + raise RuntimeError("Failed to compile the module") + ctx, launch_ptr = compiled[0], compiled[1] + ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) + launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, np.dtype(dtype).name) + + def apply(*args): + flat_args, arg_treedef = jax.tree.flatten(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() + i = -1 # Define i in case there are no args + device = 'cuda' + for i, arg in enumerate(flat_args): + buffers[i] = arg.data_ptr() + device = arg.device + flat_outs = [] + for i, t in enumerate(flat_out_types, i + 1): + out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) + flat_outs.append(out) + buffers[i] = out.data_ptr() + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + args_ptr[0] = ctx_ptr_ptr + args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) + args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), + ctypes.POINTER(ctypes.c_void_p)) + launch(args_ptr) + return jax.tree.unflatten(out_treedef, flat_outs) + + # Unload the compiled code when the Python function is destroyed. 
+ def unload(_): + unload_func(compiled) + apply.destructor = weakref.ref(apply, unload) + + return apply diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index 52d403cd0131..775b7c2ea898 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -132,7 +132,7 @@ def build_kernel( if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: - raise ValueError("Transpose only supported for only happen for 16bit types") + raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})") if swizzle not in {32, 64, 128}: raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 0b228833cbdb..502373bdc91e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -494,6 +494,8 @@ def astype(self, new_dtype: ir.Type): convert = arith.sitofp elif from_float and to_integer: convert = arith.fptosi + else: + raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") new_registers = np.empty_like(self.registers) match self.layout: case WGMMAFragLayout(): diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 554bf2641769..56003ea7af5d 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -90,19 +90,17 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") +# Identity function is at the top level so that `process_allgather` doesn't +# recompile on every invocation. def _identity_fn(x): return x -@lru_cache(maxsize=128) -def _jitted_identity_fn(sharding): - return jax.jit(_identity_fn, out_shardings=sharding) - def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = _jitted_identity_fn(reps)(inp) + out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) + out = jax.jit(_identity_fn, + out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) diff --git a/jax/extend/BUILD b/jax/extend/BUILD index babe0c8b10d2..59958c1da389 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -80,3 +80,9 @@ pytype_strict_library( srcs = ["ffi.py"], deps = ["//jax"], ) + +pytype_strict_library( + name = "ifrt_programs", + srcs = ["ifrt_programs.py"], + deps = ["//jax/_src/lib"], +) diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py new file mode 100644 index 000000000000..d5fb9245af91 --- /dev/null +++ b/jax/extend/ifrt_programs.py @@ -0,0 +1,22 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.lib import xla_extension as _xe + +ifrt_programs = _xe.ifrt_programs + +del _xe diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index d5b66c1b3b32..c23f659bd3f9 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -300,7 +300,8 @@ def diagonal( def diff(a: ArrayLike, n: int = ..., axis: int = ..., prepend: ArrayLike | None = ..., append: ArrayLike | None = ...) -> Array: ... -def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... +def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ..., *, + method: str | None = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 2e5723b184a8..103f9f78c32f 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -377,10 +377,40 @@ GetKernelCache() { return std::make_pair(&context_cache, &mutex); } + +absl::StatusOr CompileAndInit(const char* module) { + mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + InitContext(&context); + mlir::ParserConfig parse_config(&context); + auto module_op = + mlir::parseSourceString(module, parse_config); + if (!module_op) { + return absl::InternalError("Failed to parse module"); + } + auto maybe_engine = Compile(*module_op); + if (!maybe_engine.ok()) { + return maybe_engine.status(); + } + mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + auto main = execution_engine->lookupPacked("_mlir_ciface_main"); + auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); + if (!init || !main) { + return absl::InternalError("Failed to retrieve kernel function"); + } + void* module_ptr = nullptr; + void* kernel_ptr = nullptr; + void** module_ptr_ptr = &module_ptr; + void** kernel_ptr_ptr = &kernel_ptr; + void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; + reinterpret_cast(*init)(init_args); + return CompiledKernel(std::move(*maybe_engine), kernel_ptr, + reinterpret_cast(*main)); +} + // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CompileAndInit( +absl::StatusOr> CachedCompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -397,33 +427,11 @@ absl::StatusOr> CompileAndInit( absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. 
if (cache->find(key) == cache->end()) { - mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); - InitContext(&context); - mlir::ParserConfig parse_config(&context); - auto module_op = - mlir::parseSourceString(module, parse_config); - if (!module_op) { - return absl::InternalError("Failed to parse module"); - } - auto maybe_engine = Compile(*module_op); - if (!maybe_engine.ok()) { - return maybe_engine.status(); + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); - auto main = execution_engine->lookupPacked("_mlir_ciface_main"); - auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); - if (!init || !main) { - return absl::InternalError("Failed to retrieve kernel function"); - } - void* module_ptr = nullptr; - void* kernel_ptr = nullptr; - void** module_ptr_ptr = &module_ptr; - void** kernel_ptr_ptr = &kernel_ptr; - void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; - reinterpret_cast(*init)(init_args); - cache->insert_or_assign( - key, CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*main))); + cache->insert_or_assign(key, std::move(*compiled)); } return cache->at(key).GetHostLaunch(); } @@ -441,7 +449,7 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, abort(); } CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CompileAndInit(key, opaque + sizeof(KernelHash)); + auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); if (!ctx_and_kernel.ok()) { XlaCustomCallStatusSetFailure(status, ctx_and_kernel.status().message().data(), @@ -456,3 +464,33 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); } // namespace + +extern "C" { + +__attribute__((visibility("default"))) +void** MosaicGpuCompile(const char* module) { + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return nullptr; + } + auto [ctx, launch] = compiled->GetHostLaunch(); + auto tuple_ptr = std::unique_ptr(new void*[3]); + if (!tuple_ptr) { + return nullptr; + } + tuple_ptr.get()[0] = ctx; + tuple_ptr.get()[1] = reinterpret_cast(launch); + tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled)); + if (!tuple_ptr.get()[2]) { + return nullptr; + } + return tuple_ptr.release(); +} + +__attribute__((visibility("default"))) +void MosaicGpuUnload(void** tuple_ptr) { + delete reinterpret_cast(tuple_ptr[2]); + delete[] tuple_ptr; +} + +} // extern "C" diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 8463cba08c5f..4642af12011d 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -64,11 +64,12 @@ py_test( cc_binary( name = "pjrt_c_api_gpu_plugin.so", linkopts = [ - "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)", + "-Wl,--version-script,$(location :gpu_version_script.lds)", "-Wl,--no-undefined", ], linkshared = True, deps = [ + ":gpu_version_script.lds", "@xla//xla/pjrt/c:pjrt_c_api_gpu", "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", "@xla//xla/service:gpu_plugin", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 28d2806a7da9..ced0b76c344c 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -74,7 +74,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} """) diff --git 
a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 73cb8a9e020d..0e2bba0c74d0 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -80,7 +80,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} python-tag=py3 """ ) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 48aab847f3fb..6305b0c24aa8 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -164,7 +164,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} """ ) diff --git a/jaxlib/tools/gpu_version_script.lds b/jaxlib/tools/gpu_version_script.lds new file mode 100644 index 000000000000..8e46b2c590b2 --- /dev/null +++ b/jaxlib/tools/gpu_version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + MosaicGpuCompile; + MosaicGpuUnload; + }; + + local: + *; +}; diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 273c12f1b13c..5532fdf0303f 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -80,6 +80,18 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "x: 2\n") + def test_static_args(self): + @jax.jit + def f(arr): + jax.debug.print("arr {array}, dtype: {dtype}, arr {array2}", + array=arr, dtype=arr.dtype, array2=arr) + arr = jnp.array([1, 2, 3], dtype=jnp.float32) + with jtu.capture_stdout() as output: + f(arr) + jax.effects_barrier() + self.assertEqual( + output(), "arr [1. 2. 3.], dtype: float32, arr [1. 2. 3.]\n") + def test_debug_print_works_with_named_format_strings(self): def f(x): debug_print('x: {x}', x=x) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index e75e8e7d735f..f34b8211eb33 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -419,7 +419,7 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(jnp.bfloat16(0))(lax.sqrt) # CHECK-LABEL: TEST: tan float16[] - # CHECK: chlo.tan + # CHECK: hlo.tan # CHECK-SAME: tensor<f16> print_ir(np.float16(0))(lax.tan) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 01c89caf7a22..f93e28dada71 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1478,6 +1478,12 @@ def testTrimZeros(self, a_shape, dtype, trim): jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + def testTrimZerosNotOneDArray(self): + # TODO: make this an error after the deprecation period.
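A standalone sketch of the behavior the new test_static_args above covers: jax.debug.print can mix traced arrays with static, non-array arguments such as a dtype inside a jitted function.

import jax
import jax.numpy as jnp

@jax.jit
def f(arr):
  # arr is traced; arr.dtype is a static (non-array) argument.
  jax.debug.print("arr {a}, dtype: {d}", a=arr, d=arr.dtype)
  return arr

f(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32))
jax.effects_barrier()  # prints: arr [1. 2. 3.], dtype: float32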
+ with self.assertWarnsRegex(DeprecationWarning, + r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"): + jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]])) + @jtu.sample_product( rank=(1, 2), dtype=default_dtypes, @@ -4289,14 +4295,8 @@ def testSortStableDescending(self): self.assertArraysEqual(jnp.argsort(x), argsorted_stable) self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in one_dim_array_shapes - for axis in [None] - ], - dtype=all_dtypes, - ) - def testSortComplex(self, dtype, shape, axis): + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) + def testSortComplex(self, shape, dtype): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, diff --git a/tests/memories_test.py b/tests/memories_test.py index 68aecfdf669f..3e0f444a1e66 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -742,6 +742,29 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_basic_inline(self): + @compute_on('device_host') + @jax.jit + def g(x): + return x * 2 + + @functools.partial(jax.jit, inline=True) + def h(x): + y = g(x) + return y * 3 + + @jax.jit + def f(x): + return h(x) + + inp = jnp.arange(8) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(jnp.arange(8)).as_text('hlo') + self.assertRegex(lowered_text, + 'to_apply=g.*frontend_attributes={_xla_compute_type="host"}') + def test_compute_on_reduction(self): out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host') diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1d6f6eb9e584..1a29bbb5736d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -19,6 +19,7 @@ import itertools import math import operator +import unittest from absl.testing import absltest, parameterized import jax @@ -1389,5 +1390,28 @@ def kernel(ctx, src, dst, _): jax.block_until_ready(f(xd)) +class TorchTest(TestCase): + + @classmethod + def setUpClass(cls): + try: + import torch + except ImportError: + raise unittest.SkipTest("Test requires PyTorch") + cls.torch = torch + + def test_basic(self): + def kernel(ctx, i_gmem, o_gmem, _): + x = mgpu.FragmentedArray.load_strided(i_gmem) + (x + x).store_untiled(o_gmem) + + ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) + x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') + f = mosaic_gpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + y = f(x) + np.testing.assert_allclose(y.cpu(), x.cpu() * 2) + del y # Make sure the destructor runs successfully. 
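A condensed sketch of the composition exercised by test_compute_on_basic_inline above: a compute_on('device_host') function keeps its host-compute annotation even when reached through a jit with inline=True. The public import path jax.experimental.compute_on is assumed here.

import functools
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on('device_host')
@jax.jit
def g(x):
  return x * 2  # annotated to run as host compute

@functools.partial(jax.jit, inline=True)
def h(x):
  return g(x) * 3

@jax.jit
def f(x):
  return h(x)

print(f(jnp.arange(8)))  # [ 0  6 12 18 24 30 36 42]
# Even with h inlined, the lowered HLO should keep _xla_compute_type="host" on g.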
+ + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/nn_test.py b/tests/nn_test.py index 3722db42671c..be07de184e60 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -38,11 +38,11 @@ config.parse_flags_with_absl() -def _is_required_cudnn_version_satisfied(): +def _is_required_cudnn_version_satisfied(min_cudnn_version): return ( jtu.is_cuda_compute_capability_at_least("8.0") and cuda_versions is not None and - cuda_versions.cudnn_get_version() >= 8904 + cuda_versions.cudnn_get_version() >= min_cudnn_version ) def _check_cudnn_backend(fn, *args, **kwargs): @@ -60,7 +60,7 @@ class NNFunctionsTest(jtu.JaxTestCase): impl=['cudnn', 'xla'], ) def testDotProductAttention(self, dtype, group_num, use_vmap, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): + if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") @@ -102,13 +102,15 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): @parameterized.product( mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), - ('custom', 'padding'), ('bias', 'causal')], + ('custom', 'padding'), ('bias', 'causal'), + ('causal', 'sliding_window')], ) def testDotProductAttentionMask(self, mask_mode): - if not _is_required_cudnn_version_satisfied(): - raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if isinstance(mask_mode, str): mask_mode = (mask_mode,) + min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904 + if not _is_required_cudnn_version_satisfied(min_cudnn_version): + raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") dtype = jnp.bfloat16 B, S, T, N, H = 2, 128, 128, 4, 32 @@ -119,6 +121,7 @@ def testDotProductAttentionMask(self, mask_mode): grad = random.normal(keys[3], (B, T, N, H), dtype) bias, mask = None, None q_seqlen, kv_seqlen = None, None + window_size = None is_causal = 'causal' in mask_mode if 'padding' in mask_mode: @@ -130,6 +133,8 @@ def testDotProductAttentionMask(self, mask_mode): mask = custom_mask[None, None, :, :] if 'bias' in mask_mode: bias = random.normal(keys[4], (1, N, T, S), dtype) + if 'sliding_window' in mask_mode: + window_size = (3, 2) if is_causal else (3, 0) sdpa = nn.dot_product_attention sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) @@ -141,9 +146,11 @@ def testDotProductAttentionMask(self, mask_mode): # Convert the kargs to positional args for the jax.vjp. 
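For context on the integer version floors used in the nn_test changes above: cuDNN encodes 8.9.4 as 8904 and 9.2.0 as 90200. A small illustrative helper (not part of the patch) mirroring the parameterized check, with the installed version stubbed in:

def cudnn_version_satisfied(min_cudnn_version, installed_version):
  # Stands in for _is_required_cudnn_version_satisfied, minus the GPU
  # compute-capability and availability checks, for illustration only.
  return installed_version >= min_cudnn_version

assert cudnn_version_satisfied(8904, 90000)       # regular attention paths
assert not cudnn_version_satisfied(90200, 90000)  # sliding window needs 9.2+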
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref( q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + local_window_size=window_size, ) fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + local_window_size=window_size, ) out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen) out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index bd9df6182793..746c4e93387b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -79,15 +79,14 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): + @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - in_specs=[pl.BlockSpec((128,), lambda *i: i)], - out_specs=pl.BlockSpec((128,), lambda *i: i), - scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], - grid=2, - ), + in_specs=[pl.BlockSpec((128,), lambda *i: i)], + out_specs=pl.BlockSpec((128,), lambda *i: i), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + grid=2, ) def kernel(x_ref, o_ref, scratch_ref): scratch_ref[...] = x_ref[...] + 1 @@ -120,10 +119,8 @@ def test_add_one_with_async_copy_smem_to_gmem(self): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], - ), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 @@ -134,16 +131,15 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_with_async_copy_gmem_to_smem(self): + @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - grid_spec=plgpu.GPUGridSpec( - in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[ - plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], - ), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): plgpu.async_copy_gmem_to_smem( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6c022653581d..11a541f2e5f5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -57,6 +57,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -652,18 +653,16 @@ def testAutodiff(self, mesh, resources): @jtu.with_mesh([('x', 2), ('y', 1)]) def testAutodiffCache(self): - f = pjit( - lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None - ) + f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None) x = jnp.arange(16, dtype=jnp.float32) - jax.grad(f)(x) # Warm up the cache. - before = pjit_lib._pjit_lower_cached.cache_info() - jax.grad(f)(x) - after = pjit_lib._pjit_lower_cached.cache_info() - # One hit for the forward pass, one hit for backward. - self.assertEqual(after.hits, before.hits + 2) - self.assertEqual(after.misses, before.misses) + jax.grad(f)(x) # Warm up the cache. 
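A minimal standalone call showing the local_window_size plumbing that the nn_test changes above exercise. Shapes follow the test; the (left, right) window pair mirrors its causal case.

import jax.numpy as jnp
from jax import nn, random

B, S, T, N, H = 2, 128, 128, 4, 32
keys = random.split(random.key(0), 3)
q = random.normal(keys[0], (B, T, N, H), dtype=jnp.bfloat16)
k = random.normal(keys[1], (B, S, N, H), dtype=jnp.bfloat16)
v = random.normal(keys[2], (B, S, N, H), dtype=jnp.bfloat16)

# Causal attention where each query also sees at most the 3 preceding keys.
out = nn.dot_product_attention(q, k, v, is_causal=True, local_window_size=(3, 0))
print(out.shape)  # (2, 128, 4, 32)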
+ with jtu.count_pjit_cpp_cache_miss() as count: + jax.grad(f)(x) + if xla_extension_version >= 286: + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit + else: + self.assertEqual(count[0], 2) @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -4531,6 +4530,20 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) + @unittest.skipIf(xla_extension_version < 286, + "Requires xla_extension_version >= 286") + def test_global_jit_cpp_cache_hit_out_shardings(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P('x')) + + def f(x): + return x * 2 + + with jtu.count_pjit_cpp_cache_miss() as count: + jax.jit(f, out_shardings=s)(np.arange(8)) + jax.jit(f, out_shardings=s)(np.arange(8)) + self.assertEqual(count[0], 1) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)")
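A sketch of the caching behavior the new pjit tests above assert, reusing the same internal test utilities (jax._src.test_util); the expected count depends on xla_extension_version exactly as in the test, and a multi-device backend is assumed.

import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import test_util as jtu  # internal helpers, as used in the test

mesh = jtu.create_mesh((2,), 'x')  # requires at least two devices
s = NamedSharding(mesh, P('x'))

def f(x):
  return x * 2

with jtu.count_pjit_cpp_cache_miss() as count:
  jax.jit(f, out_shardings=s)(np.arange(8))
  jax.jit(f, out_shardings=s)(np.arange(8))
print(count[0])  # expected 1 with xla_extension_version >= 286 (test is skipped otherwise)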