Skip to content

Commit

Permalink
Add FFI example project and test on CI.
Browse files Browse the repository at this point in the history
This PR includes an end-to-end example project which demonstrates the
use of the FFI. This complements [the FFI
tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all
of the code in one place, as well as demonstrating how FFI extensions
can be packaged. Alongside the example project, I have also added a new
GitHub Actions workflow to test the example as part of CI.

For now, the tests only run on CPU, but once we have GPU runners for
GitHub Actions (soon!), I plan on migrating the custom call examples
from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case.

Similarly, I wanted to start small and this example project only
includes exactly the same functions as the tutorial for now, but I think
this could be a good place to showcase more advanced examples (including
custom calls with state).
  • Loading branch information
dfm committed Sep 4, 2024
1 parent a8a55e0 commit ce494f9
Show file tree
Hide file tree
Showing 11 changed files with 395 additions and 4 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jobs:
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 tests examples
pytest -n auto --tb=short --maxfail=20 --ignore=examples/ffi tests examples
documentation:
Expand Down Expand Up @@ -231,4 +231,3 @@ jobs:
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
46 changes: 46 additions & 0 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Examples

on:
push:
branches:
- main
pull_request:
branches:
- main

permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows

jobs:
ffi:
name: FFI
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
with:
python-version: 3.11
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', 'examples/**/pyproject.toml') }}
- name: Install JAX
run: pip install .
- name: Install dependencies
run: python -m pip install ./examples/ffi[test]
- name: Run tests
run: python -m pytest examples/ffi/tests
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ jobs:
python -m pip install -e ${{ github.workspace }}
python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
pytest -n auto --tb=short --ignore=examples\ffi tests examples
2 changes: 1 addition & 1 deletion .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ jobs:
python -m pip install -e ${{ github.workspace }}\jax
python -m pip install --no-index --find-links ${{ github.workspace }}\jax\dist jaxlib
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
pytest -n auto --tb=short --ignore=examples\ffi tests examples
15 changes: 15 additions & 0 deletions examples/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.15...3.30)
project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)

find_package(Python 3.8 REQUIRED COMPONENTS Interpreter Development.Module)
execute_process(
COMMAND "${Python_EXECUTABLE}"
"-c" "from jax.extend import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")

find_package(nanobind CONFIG REQUIRED)

nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
9 changes: 9 additions & 0 deletions examples/ffi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# End-to-end example usage for JAX's foreign function interface

This directory includes an example project demonstrating the use of JAX's
foreign function interface (FFI). The JAX docs provide more information about
this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
but the example in this directory explicitly demonstrates:

1. One way to package and distribute FFI targets, and
2. Some more advanced use cases.
11 changes: 11 additions & 0 deletions examples/ffi/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[build-system]
requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"]
build-backend = "scikit_build_core.build"

[project]
name = "jax_ffi_example"
version = "0.0.1"
dependencies = ["jax"]

[project.optional-dependencies]
test = ["pytest", "absl-py"]
13 changes: 13 additions & 0 deletions examples/ffi/src/jax_ffi_example/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
160 changes: 160 additions & 0 deletions examples/ffi/src/jax_ffi_example/rms_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include <type_traits>
#include <utility>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

// This is the example "library function" that we want to expose to JAX. This
// isn't meant to be a particularly good implementation, it's just here as a
// placeholder for the purposes of this tutorial.
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
return scale;
}

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
auto dims = buffer.dimensions();
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(buffer.element_count(), dims.back());
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
}
return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
);

ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormFwd input must be an array");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
&(y->typed_data()[n]));
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNormFwd, RmsNormFwdImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
.Ret<ffi::Buffer<ffi::DataType::F32>>() // res
);

void ComputeRmsNormBwd(int64_t size, float res, const float *x,
const float *ct_y, float *ct_x) {
float ct_res = 0.0f;
for (int64_t n = 0; n < size; ++n) {
ct_res += x[n] * ct_y[n];
}
float factor = ct_res * res * res * res / float(size);
for (int64_t n = 0; n < size; ++n) {
ct_x[n] = res * ct_y[n] - factor * x[n];
}
}

ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
ffi::Buffer<ffi::DataType::F32> x,
ffi::Buffer<ffi::DataType::F32> ct_y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> ct_x) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormBwd inputs must be arrays");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
&(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNormBwd, RmsNormBwdImpl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::DataType::F32>>() // res
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Arg<ffi::Buffer<ffi::DataType::F32>>() // ct_y
.Ret<ffi::Buffer<ffi::DataType::F32>>() // ct_x
);

template <typename T>
nb::capsule EncapsulateFfiHandler(T *fn) {
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be and XLA FFI handler");
return nb::capsule(reinterpret_cast<void *>(fn));
}

NB_MODULE(_rms_norm, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm);
registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd);
registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd);
return registrations;
});
}
90 changes: 90 additions & 0 deletions examples/ffi/src/jax_ffi_example/rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as np

import jax
import jax.extend as jex
import jax.numpy as jnp

from jax_ffi_example import _rms_norm

for name, target in _rms_norm.registrations().items():
jex.ffi.register_ffi_target(name, target)


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def rms_norm(x, eps=1e-5):
# We only implemented the `float32` version of this function, so we start by
# checking the dtype. This check isn't strictly necessary because type
# checking is also performed by the FFI when decoding input and output
# buffers, but it can be useful to check types in Python to raise more
# informative errors.
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")

# In this case, the output of our FFI function is just a single array with the
# same shape and dtype as the input.
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

return jex.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_ffi_target`
"rms_norm",
out_type,
x,
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# The `vectorized` parameter controls this function's behavior under `vmap`.
vectorized=True,
)


def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
x,
eps=np.float32(eps),
vectorized=True,
)
return y, (res, x)


def rms_norm_bwd(eps, res, ct):
del eps
res, x = res
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
res,
x,
ct,
vectorized=True,
),
)


rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
Loading

0 comments on commit ce494f9

Please sign in to comment.