-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FFI example project and test on CI.
This PR includes an end-to-end example project which demonstrates the use of the FFI. This complements [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all of the code in one place, as well as demonstrating how FFI extensions can be packaged. Alongside the example project, I have also added a new GitHub Actions workflow to test the example as part of CI. For now, the tests only run on CPU, but once we have GPU runners for GitHub Actions (soon!), I plan on migrating the custom call examples from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case. Similarly, I wanted to start small and this example project only includes exactly the same functions as the tutorial for now, but I think this could be a good place to showcase more advanced examples (including custom calls with state).
- Loading branch information
Showing
7 changed files
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
name: Examples | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
permissions: | ||
contents: read # to fetch code | ||
actions: write # to cancel previous workflows | ||
|
||
jobs: | ||
ffi: | ||
name: FFI examples | ||
runs-on: ubuntu-latest | ||
timeout-minutes: 5 | ||
steps: | ||
- name: Cancel previous | ||
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 | ||
with: | ||
access_token: ${{ github.token }} | ||
if: ${{github.ref != 'refs/heads/main'}} | ||
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 | ||
- name: Set up Python 3.11 | ||
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 | ||
with: | ||
python-version: 3.11 | ||
- name: Get pip cache dir | ||
id: pip-cache | ||
run: | | ||
python -m pip install --upgrade pip wheel | ||
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT | ||
- name: pip cache | ||
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 | ||
with: | ||
path: ${{ steps.pip-cache.outputs.dir }} | ||
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', 'examples/**/pyproject.toml') }} | ||
- name: Install JAX | ||
run: pip install . | ||
- name: Install dependencies | ||
run: python -m pip install ./examples/ffi[test] | ||
- name: Run tests | ||
run: python -m pytest examples/ffi/tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
cmake_minimum_required(VERSION 3.15...3.30) | ||
project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) | ||
|
||
find_package(Python 3.8 REQUIRED COMPONENTS Interpreter Development.Module) | ||
execute_process( | ||
COMMAND "${Python_EXECUTABLE}" | ||
"-c" "from jax.extend import ffi; print(ffi.include_dir())" | ||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) | ||
message(STATUS "XLA include directory: ${XLA_DIR}") | ||
|
||
find_package(nanobind CONFIG REQUIRED) | ||
|
||
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_examples/rms_norm.cc") | ||
target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) | ||
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
[build-system] | ||
requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] | ||
build-backend = "scikit_build_core.build" | ||
|
||
[project] | ||
name = "jax_ffi_examples" | ||
version = "0.0.1" | ||
dependencies = ["jax"] | ||
|
||
[project.optional-dependencies] | ||
test = ["pytest", "absl-py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* Copyright 2024 The JAX Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <cmath> | ||
#include <cstdint> | ||
#include <functional> | ||
#include <numeric> | ||
#include <type_traits> | ||
#include <utility> | ||
|
||
#include "nanobind/nanobind.h" | ||
#include "xla/ffi/api/c_api.h" | ||
#include "xla/ffi/api/ffi.h" | ||
|
||
namespace nb = nanobind; | ||
namespace ffi = xla::ffi; | ||
|
||
// This is the example "library function" that we want to expose to JAX. This | ||
// isn't meant to be a particularly good implementation, it's just here as a | ||
// placeholder for the purposes of this tutorial. | ||
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) { | ||
float sm = 0.0f; | ||
for (int64_t n = 0; n < size; ++n) { | ||
sm += x[n] * x[n]; | ||
} | ||
float scale = 1.0f / std::sqrt(sm / float(size) + eps); | ||
for (int64_t n = 0; n < size; ++n) { | ||
y[n] = x[n] * scale; | ||
} | ||
return scale; | ||
} | ||
|
||
// A helper function for extracting the relevant dimensions from `ffi::Buffer`s. | ||
// In this example, we treat all leading dimensions as batch dimensions, so this | ||
// function returns the total number of elements in the buffer, and the size of | ||
// the last dimension. | ||
template <ffi::DataType T> | ||
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) { | ||
auto dims = buffer.dimensions(); | ||
if (dims.size() == 0) { | ||
return std::make_pair(0, 0); | ||
} | ||
return std::make_pair(buffer.element_count(), dims.back()); | ||
} | ||
|
||
// A wrapper function providing the interface between the XLA FFI call and our | ||
// library function `ComputeRmsNorm` above. This function handles the batch | ||
// dimensions by calling `ComputeRmsNorm` within a loop. | ||
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x, | ||
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) { | ||
auto [totalSize, lastDim] = GetDims(x); | ||
if (lastDim == 0) { | ||
return ffi::Error(ffi::ErrorCode::kInvalidArgument, | ||
"RmsNorm input must be an array"); | ||
} | ||
for (int64_t n = 0; n < totalSize; n += lastDim) { | ||
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); | ||
} | ||
return ffi::Error::Success(); | ||
} | ||
|
||
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare | ||
// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL` | ||
// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`. | ||
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, | ||
ffi::Ffi::Bind() | ||
.Attr<float>("eps") | ||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x | ||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y | ||
); | ||
|
||
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x, | ||
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y, | ||
ffi::Result<ffi::Buffer<ffi::DataType::F32>> res) { | ||
auto [totalSize, lastDim] = GetDims(x); | ||
if (lastDim == 0) { | ||
return ffi::Error(ffi::ErrorCode::kInvalidArgument, | ||
"RmsNormFwd input must be an array"); | ||
} | ||
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { | ||
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), | ||
&(y->typed_data()[n])); | ||
} | ||
return ffi::Error::Success(); | ||
} | ||
|
||
XLA_FFI_DEFINE_HANDLER_SYMBOL( | ||
RmsNormFwd, RmsNormFwdImpl, | ||
ffi::Ffi::Bind() | ||
.Attr<float>("eps") | ||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x | ||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y | ||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // res | ||
); | ||
|
||
void ComputeRmsNormBwd(int64_t size, float res, const float *x, | ||
const float *ct_y, float *ct_x) { | ||
float ct_res = 0.0f; | ||
for (int64_t n = 0; n < size; ++n) { | ||
ct_res += x[n] * ct_y[n]; | ||
} | ||
float factor = ct_res * res * res * res / float(size); | ||
for (int64_t n = 0; n < size; ++n) { | ||
ct_x[n] = res * ct_y[n] - factor * x[n]; | ||
} | ||
} | ||
|
||
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res, | ||
ffi::Buffer<ffi::DataType::F32> x, | ||
ffi::Buffer<ffi::DataType::F32> ct_y, | ||
ffi::Result<ffi::Buffer<ffi::DataType::F32>> ct_x) { | ||
auto [totalSize, lastDim] = GetDims(x); | ||
if (lastDim == 0) { | ||
return ffi::Error(ffi::ErrorCode::kInvalidArgument, | ||
"RmsNormBwd inputs must be arrays"); | ||
} | ||
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { | ||
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), | ||
&(ct_y.typed_data()[n]), &(ct_x->typed_data()[n])); | ||
} | ||
return ffi::Error::Success(); | ||
} | ||
|
||
XLA_FFI_DEFINE_HANDLER_SYMBOL( | ||
RmsNormBwd, RmsNormBwdImpl, | ||
ffi::Ffi::Bind() | ||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // res | ||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x | ||
.Arg<ffi::Buffer<ffi::DataType::F32>>() // ct_y | ||
.Ret<ffi::Buffer<ffi::DataType::F32>>() // ct_x | ||
); | ||
|
||
template <typename T> | ||
nb::capsule EncapsulateFfiHandler(T *fn) { | ||
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>, | ||
"Encapsulated function must be and XLA FFI handler"); | ||
return nb::capsule(reinterpret_cast<void *>(fn)); | ||
} | ||
|
||
NB_MODULE(_rms_norm, m) { | ||
m.def("registrations", []() { | ||
nb::dict registrations; | ||
registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm); | ||
registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd); | ||
registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd); | ||
return registrations; | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from functools import partial | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.extend as jex | ||
import jax.numpy as jnp | ||
|
||
from jax_ffi_examples import _rms_norm | ||
|
||
for name, target in _rms_norm.registrations().items(): | ||
jex.ffi.register_ffi_target(name, target) | ||
|
||
|
||
@partial(jax.custom_vjp, nondiff_argnums=(1,)) | ||
def rms_norm(x, eps=1e-5): | ||
# We only implemented the `float32` version of this function, so we start by | ||
# checking the dtype. This check isn't strictly necessary because type | ||
# checking is also performed by the FFI when decoding input and output | ||
# buffers, but it can be useful to check types in Python to raise more | ||
# informative errors. | ||
if x.dtype != jnp.float32: | ||
raise ValueError("Only the float32 dtype is implemented by rms_norm") | ||
|
||
# In this case, the output of our FFI function is just a single array with the | ||
# same shape and dtype as the input. We discuss a case with a more interesting | ||
# output type below. | ||
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) | ||
|
||
return jex.ffi.ffi_call( | ||
# The target name must be the same string as we used to register the target | ||
# above in `register_custom_call_target` | ||
"rms_norm", | ||
out_type, | ||
x, | ||
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for | ||
# the attribute `eps`. Our FFI function expects this to have the C++ `float` | ||
# type (which corresponds to numpy's `float32` type), and it must be a | ||
# static parameter (i.e. not a JAX array). | ||
eps=np.float32(eps), | ||
# The `vectorized` parameter controls this function's behavior under `vmap` | ||
# as discussed below. | ||
vectorized=True, | ||
) | ||
|
||
|
||
def rms_norm_fwd(x, eps=1e-5): | ||
y, res = jex.ffi.ffi_call( | ||
"rms_norm_fwd", | ||
( | ||
jax.ShapeDtypeStruct(x.shape, x.dtype), | ||
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), | ||
), | ||
x, | ||
eps=np.float32(eps), | ||
vectorized=True, | ||
) | ||
return y, (res, x) | ||
|
||
|
||
def rms_norm_bwd(eps, res, ct): | ||
del eps | ||
res, x = res | ||
assert res.shape == ct.shape[:-1] | ||
assert x.shape == ct.shape | ||
return ( | ||
jex.ffi.ffi_call( | ||
"rms_norm_bwd", | ||
jax.ShapeDtypeStruct(ct.shape, ct.dtype), | ||
res, | ||
x, | ||
ct, | ||
vectorized=True, | ||
), | ||
) | ||
|
||
|
||
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import absltest | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax._src import test_util as jtu | ||
|
||
from jax_ffi_examples import rms_norm | ||
|
||
|
||
def rms_norm_ref(x, eps=1e-5): | ||
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) | ||
return x / scale | ||
|
||
|
||
class RmsNormTests(jtu.JaxTestCase): | ||
def test_basic(self): | ||
x = jnp.linspace(-0.5, 0.5, 15) | ||
self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x)) | ||
|
||
def test_batching(self): | ||
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) | ||
self.assertAllClose(jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)) | ||
|
||
def test_grads(self): | ||
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) | ||
jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",)) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main(testLoader=jtu.JaxTestLoader()) |