diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml new file mode 100644 index 000000000000..581462e37e01 --- /dev/null +++ b/.github/workflows/examples.yml @@ -0,0 +1,46 @@ +name: Examples + +on: + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + +jobs: + ffi: + name: FFI + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Cancel previous + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + if: ${{github.ref != 'refs/heads/main'}} + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + with: + python-version: 3.11 + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip wheel + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', 'examples/**/pyproject.toml') }} + - name: Install JAX + run: pip install . + - name: Install dependencies + run: python -m pip install ./examples/ffi[test] + - name: Run tests + run: python -m pytest examples/ffi/tests diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt new file mode 100644 index 000000000000..3638a82cb77f --- /dev/null +++ b/examples/ffi/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) + +find_package(Python 3.8 REQUIRED COMPONENTS Interpreter Development.Module) +execute_process( + COMMAND "${Python_EXECUTABLE}" + "-c" "from jax.extend import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) +message(STATUS "XLA include directory: ${XLA_DIR}") + +find_package(nanobind CONFIG REQUIRED) + +nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_examples/rms_norm.cc") +target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) +install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) diff --git a/examples/ffi/README.md b/examples/ffi/README.md new file mode 100644 index 000000000000..cc7018782a25 --- /dev/null +++ b/examples/ffi/README.md @@ -0,0 +1,9 @@ +# End-to-end example usage for JAX's foreign function interface + +This directory includes an example project demonstrating the use of JAX's +foreign function interface (FFI). The JAX docs provide more information about +this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), +but the example in this directory explicitly demonstrates: + +1. One way to package and distribute FFI targets, and +2. Some more advanced use cases. diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml new file mode 100644 index 000000000000..2d700016ecac --- /dev/null +++ b/examples/ffi/pyproject.toml @@ -0,0 +1,11 @@ +[build-system] +requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] +build-backend = "scikit_build_core.build" + +[project] +name = "jax_ffi_examples" +version = "0.0.1" +dependencies = ["jax"] + +[project.optional-dependencies] +test = ["pytest", "absl-py"] diff --git a/examples/ffi/src/jax_ffi_examples/__init__.py b/examples/ffi/src/jax_ffi_examples/__init__.py new file mode 100644 index 000000000000..862a661e24b9 --- /dev/null +++ b/examples/ffi/src/jax_ffi_examples/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/ffi/src/jax_ffi_examples/rms_norm.cc b/examples/ffi/src/jax_ffi_examples/rms_norm.cc new file mode 100644 index 000000000000..5f8e80507412 --- /dev/null +++ b/examples/ffi/src/jax_ffi_examples/rms_norm.cc @@ -0,0 +1,160 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +// This is the example "library function" that we want to expose to JAX. This +// isn't meant to be a particularly good implementation, it's just here as a +// placeholder for the purposes of this tutorial. +float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { + float sm = 0.0f; + for (int64_t n = 0; n < size; ++n) { + sm += x[n] * x[n]; + } + float scale = 1.0f / std::sqrt(sm / float(size) + eps); + for (int64_t n = 0; n < size; ++n) { + y[n] = x[n] * scale; + } + return scale; +} + +// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. +// In this example, we treat all leading dimensions as batch dimensions, so this +// function returns the total number of elements in the buffer, and the size of +// the last dimension. +template +std::pair GetDims(const ffi::Buffer &buffer) { + auto dims = buffer.dimensions(); + if (dims.size() == 0) { + return std::make_pair(0, 0); + } + return std::make_pair(buffer.element_count(), dims.back()); +} + +// A wrapper function providing the interface between the XLA FFI call and our +// library function `ComputeRmsNorm` above. This function handles the batch +// dimensions by calling `ComputeRmsNorm` within a loop. +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::Result> y) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNorm input must be an array"); + } + for (int64_t n = 0; n < totalSize; n += lastDim) { + ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare +// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL` +// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`. +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y +); + +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::Result> y, + ffi::Result> res) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormFwd input must be an array"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), + &(y->typed_data()[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y + .Ret>() // res +); + +void ComputeRmsNormBwd(int64_t size, float res, const float *x, + const float *ct_y, float *ct_x) { + float ct_res = 0.0f; + for (int64_t n = 0; n < size; ++n) { + ct_res += x[n] * ct_y[n]; + } + float factor = ct_res * res * res * res / float(size); + for (int64_t n = 0; n < size; ++n) { + ct_x[n] = res * ct_y[n] - factor * x[n]; + } +} + +ffi::Error RmsNormBwdImpl(ffi::Buffer res, + ffi::Buffer x, + ffi::Buffer ct_y, + ffi::Result> ct_x) { + auto [totalSize, lastDim] = GetDims(x); + if (lastDim == 0) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RmsNormBwd inputs must be arrays"); + } + for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { + ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), + &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n])); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>() // res + .Arg>() // x + .Arg>() // ct_y + .Ret>() // ct_x +); + +template +nb::capsule EncapsulateFfiHandler(T *fn) { + static_assert(std::is_invocable_r_v, + "Encapsulated function must be and XLA FFI handler"); + return nb::capsule(reinterpret_cast(fn)); +} + +NB_MODULE(_rms_norm, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm); + registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd); + registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd); + return registrations; + }); +} diff --git a/examples/ffi/src/jax_ffi_examples/rms_norm.py b/examples/ffi/src/jax_ffi_examples/rms_norm.py new file mode 100644 index 000000000000..eac19b85af20 --- /dev/null +++ b/examples/ffi/src/jax_ffi_examples/rms_norm.py @@ -0,0 +1,92 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import numpy as np + +import jax +import jax.extend as jex +import jax.numpy as jnp + +from jax_ffi_examples import _rms_norm + +for name, target in _rms_norm.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +@partial(jax.custom_vjp, nondiff_argnums=(1,)) +def rms_norm(x, eps=1e-5): + # We only implemented the `float32` version of this function, so we start by + # checking the dtype. This check isn't strictly necessary because type + # checking is also performed by the FFI when decoding input and output + # buffers, but it can be useful to check types in Python to raise more + # informative errors. + if x.dtype != jnp.float32: + raise ValueError("Only the float32 dtype is implemented by rms_norm") + + # In this case, the output of our FFI function is just a single array with the + # same shape and dtype as the input. We discuss a case with a more interesting + # output type below. + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + return jex.ffi.ffi_call( + # The target name must be the same string as we used to register the target + # above in `register_custom_call_target` + "rms_norm", + out_type, + x, + # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). + eps=np.float32(eps), + # The `vectorized` parameter controls this function's behavior under `vmap` + # as discussed below. + vectorized=True, + ) + + +def rms_norm_fwd(x, eps=1e-5): + y, res = jex.ffi.ffi_call( + "rms_norm_fwd", + ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + ), + x, + eps=np.float32(eps), + vectorized=True, + ) + return y, (res, x) + + +def rms_norm_bwd(eps, res, ct): + del eps + res, x = res + assert res.shape == ct.shape[:-1] + assert x.shape == ct.shape + return ( + jex.ffi.ffi_call( + "rms_norm_bwd", + jax.ShapeDtypeStruct(ct.shape, ct.dtype), + res, + x, + ct, + vectorized=True, + ), + ) + + +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py new file mode 100644 index 000000000000..2dc302f3224c --- /dev/null +++ b/examples/ffi/tests/rms_norm_test.py @@ -0,0 +1,44 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax_ffi_examples import rms_norm + + +def rms_norm_ref(x, eps=1e-5): + scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) + return x / scale + + +class RmsNormTests(jtu.JaxTestCase): + def test_basic(self): + x = jnp.linspace(-0.5, 0.5, 15) + self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x)) + + def test_batching(self): + x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) + self.assertAllClose(jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)) + + def test_grads(self): + x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) + jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/pyproject.toml b/pyproject.toml index fb706dbbf8ea..b5fc33632cf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ doctest_optionflags = [ "NUMBER", "NORMALIZE_WHITESPACE" ] -addopts = "--doctest-glob='*.rst'" +addopts = "--doctest-glob='*.rst' --ignore=examples" [tool.pylint.master] extension-pkg-whitelist = "numpy" @@ -115,7 +115,7 @@ ignore = [ "C408", # Unnecessary map usage "C417", - # Unnecessary dict comprehension for iterable + # Unnecessary dict comprehension for iterable "C420", # Object names too complex "C901",