Skip to content

Commit

Permalink
Add support for passing array attributes via ffi_call
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Sep 26, 2024
1 parent 6f7ad64 commit eafa62e
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 5 deletions.
4 changes: 4 additions & 0 deletions examples/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ find_package(nanobind CONFIG REQUIRED)
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
target_include_directories(_attrs PUBLIC ${XLA_DIR})
install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
100 changes: 100 additions & 0 deletions examples/ffi/src/jax_ffi_example/attrs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
ffi::Result<ffi::BufferR0<ffi::S32>> res) {
int64_t total = 0;
for (int32_t x : array) {
total += x;
}
res->typed_data()[0] = total;
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
ffi::Ffi::Bind()
.Attr<ffi::Span<const int32_t>>("array")
.Ret<ffi::BufferR0<ffi::S32>>());

ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
ffi::Result<ffi::BufferR0<ffi::S32>> secret,
ffi::Result<ffi::BufferR0<ffi::S32>> count) {
auto maybe_secret = attrs.get<int64_t>("secret");
if (maybe_secret.has_error()) {
return maybe_secret.error();
}
secret->typed_data()[0] = maybe_secret.value();
count->typed_data()[0] = attrs.size();
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
ffi::Ffi::Bind()
.Attrs()
.Ret<ffi::BufferR0<ffi::S32>>()
.Ret<ffi::BufferR0<ffi::S32>>());

// TODO(dfm): User defined structs are not yet supported, but once we plumb
// them through XLA, add this example back in:
//
// struct UserStruct {
// int64_t a;
// int64_t b;
// };
//
// XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(UserStruct,
// ffi::StructMember<int64_t>("a"),
// ffi::StructMember<int64_t>("b"));

enum class UserEnum : int64_t {
kFailure = 0,
kSuccess = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(UserEnum);

ffi::Error UserDefinedAttrImpl(UserEnum user_enum,
ffi::Result<ffi::BufferR0<ffi::S32>>) {
if (user_enum != UserEnum::kSuccess) {
return ffi::Error::InvalidArgument("user_enum must be kSuccess");
}
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(UserDefinedAttr, UserDefinedAttrImpl,
ffi::Ffi::Bind()
.Attr<UserEnum>("user_enum")
.Ret<ffi::BufferR0<ffi::S32>>());

NB_MODULE(_attrs, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["array_attr"] =
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
registrations["user_defined_attr"] =
nb::capsule(reinterpret_cast<void *>(UserDefinedAttr));
return registrations;
});
}
57 changes: 57 additions & 0 deletions examples/ffi/src/jax_ffi_example/attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example demonstrating the different ways that attributes can be passed to
the FFI.
For example, we can pass arrays, variadic attributes, and user-defined types
(caveat: full support of user-defined types isn't yet supported by XLA, but
custom Enums are supported for now).
"""

import numpy as np

import jax
import jax.extend as jex

from jax_ffi_example import _attrs

for name, target in _attrs.registrations().items():
jex.ffi.register_ffi_target(name, target)


def array_attr(num: int):
return jex.ffi.ffi_call(
"array_attr",
jax.ShapeDtypeStruct((), np.int32),
array=np.arange(num, dtype=np.int32),
)


def dictionary_attr(**kwargs):
return jex.ffi.ffi_call(
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
**kwargs,
)


def user_defined_attr(flag):
return jex.ffi.ffi_call(
"user_defined_attr",
jax.ShapeDtypeStruct((), np.int32),
# Since our enum is backed by and int64 in C++, we can pass a Python int.
# Otherwise, we would want to use the appropriate numpy type here.
user_enum=int(flag),
)
54 changes: 54 additions & 0 deletions examples/ffi/tests/attrs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu

from jax_ffi_example import attrs

jax.config.parse_flags_with_absl()


class AttrsTests(jtu.JaxTestCase):
def test_array_attr(self):
self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())

def test_dictionary_attr(self):
secret, count = attrs.dictionary_attr(secret=5)
self.assertEqual(secret, 5)
self.assertEqual(count, 1)

secret, count = attrs.dictionary_attr(secret=3, a_string="hello")
self.assertEqual(secret, 3)
self.assertEqual(count, 2)

with self.assertRaisesRegex(Exception, "Unexpected attribute"):
attrs.dictionary_attr()

with self.assertRaisesRegex(Exception, "Wrong attribute type"):
attrs.dictionary_attr(secret="invalid")

def test_user_defined_attr(self):
attrs.user_defined_attr(True) # doesn't crash

with self.assertRaisesRegex(Exception, "user_enum must be kSuccess"):
attrs.user_defined_attr(False)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
25 changes: 23 additions & 2 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax._src import core
from jax._src import dispatch
from jax._src import util
from jax._src.api_util import _HashableByObjectId
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import ad
from jax._src.interpreters import batching
Expand Down Expand Up @@ -237,12 +238,26 @@ def ffi_call(
else:
multiple_results = False
result_avals = _result_avals((result_shape_dtypes,))

# Because we support passing numpy arrays (which are not hashable) as
# attributes, we need to handle that here. For any non-hashable kwargs, we
# wrap them in a hashable object that uses the object's id as the hash. Then,
# in the lowering, we unwrap the object and use the original object.
hashable_kwargs = {}
for k, v in kwargs.items():
try:
hash(v)
except TypeError:
hashable_kwargs[k] = _HashableByObjectId(v)
else:
hashable_kwargs[k] = v

results = ffi_call_p.bind(
*args,
result_avals=result_avals,
vectorized=vectorized,
target_name=target_name,
**kwargs,
**hashable_kwargs,
)
if multiple_results:
return results
Expand Down Expand Up @@ -284,7 +299,13 @@ def ffi_call_lowering(
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
return ffi_lowering(target_name)(ctx, *operands, **kwargs)
unhashable_kwargs = {}
for k, v in kwargs.items():
if isinstance(v, _HashableByObjectId):
unhashable_kwargs[k] = v.val
else:
unhashable_kwargs[k] = v
return ffi_lowering(target_name)(ctx, *operands, **unhashable_kwargs)


ffi_call_p = core.Primitive("ffi_call")
Expand Down
26 changes: 23 additions & 3 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ def ir_constant(val: Any) -> IrValues:
raise TypeError(f"No constant handler for type: {type(val)}")

def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
attr = _numpy_array_attribute(x)
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
return hlo.constant(attr)


Expand Down Expand Up @@ -359,13 +364,26 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute:
else:
raise TypeError(f"Unsupported scalar attribute type: {type(val)}")

_dtype_to_array_attr: dict[Any, AttributeHandler] = {
np.dtype(np.bool_): ir.DenseBoolArrayAttr.get,
np.dtype(np.float32): ir.DenseF32ArrayAttr.get,
np.dtype(np.float64): ir.DenseF64ArrayAttr.get,
np.dtype(np.int32): ir.DenseI32ArrayAttr.get,
np.dtype(np.int64): ir.DenseI64ArrayAttr.get,
np.dtype(np.int8): ir.DenseI8ArrayAttr.get,
}

def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute:
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
builder = _dtype_to_array_attr.get(x.dtype, None)
if builder:
return builder(x)
else:
element_type = dtype_to_ir_type(x.dtype)
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore

def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute:
if 0 in val.strides and val.size > 0:
Expand Down Expand Up @@ -407,6 +425,8 @@ def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute:

register_attribute_handler(list, _sequence_attribute_handler)
register_attribute_handler(tuple, _sequence_attribute_handler)
register_attribute_handler(ir.Attribute, lambda x: x)
register_attribute_handler(ir.Type, lambda x: x)

def ir_attribute(val: Any) -> ir.Attribute:
"""Convert a Python value to an MLIR attribute."""
Expand Down

0 comments on commit eafa62e

Please sign in to comment.