From eafa62efd9744cf2527822333a9b77738153acad Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 26 Sep 2024 15:02:58 -0400 Subject: [PATCH] Add support for passing array attributes via ffi_call --- examples/ffi/CMakeLists.txt | 4 + examples/ffi/src/jax_ffi_example/attrs.cc | 100 ++++++++++++++++++++++ examples/ffi/src/jax_ffi_example/attrs.py | 57 ++++++++++++ examples/ffi/tests/attrs_test.py | 54 ++++++++++++ jax/_src/extend/ffi.py | 25 +++++- jax/_src/interpreters/mlir.py | 26 +++++- 6 files changed, 261 insertions(+), 5 deletions(-) create mode 100644 examples/ffi/src/jax_ffi_example/attrs.cc create mode 100644 examples/ffi/src/jax_ffi_example/attrs.py create mode 100644 examples/ffi/tests/attrs_test.py diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 8d9b811374d1..62142fd49034 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -13,3 +13,7 @@ find_package(nanobind CONFIG REQUIRED) nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) + +nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc") +target_include_directories(_attrs PUBLIC ${XLA_DIR}) +install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/attrs.cc new file mode 100644 index 000000000000..c9da16bdea0e --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/attrs.cc @@ -0,0 +1,100 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +ffi::Error ArrayAttrImpl(ffi::Span array, + ffi::Result> res) { + int64_t total = 0; + for (int32_t x : array) { + total += x; + } + res->typed_data()[0] = total; + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, + ffi::Ffi::Bind() + .Attr>("array") + .Ret>()); + +ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, + ffi::Result> secret, + ffi::Result> count) { + auto maybe_secret = attrs.get("secret"); + if (maybe_secret.has_error()) { + return maybe_secret.error(); + } + secret->typed_data()[0] = maybe_secret.value(); + count->typed_data()[0] = attrs.size(); + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, + ffi::Ffi::Bind() + .Attrs() + .Ret>() + .Ret>()); + +// TODO(dfm): User defined structs are not yet supported, but once we plumb +// them through XLA, add this example back in: +// +// struct UserStruct { +// int64_t a; +// int64_t b; +// }; +// +// XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(UserStruct, +// ffi::StructMember("a"), +// ffi::StructMember("b")); + +enum class UserEnum : int64_t { + kFailure = 0, + kSuccess = 1, +}; + +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(UserEnum); + +ffi::Error UserDefinedAttrImpl(UserEnum user_enum, + ffi::Result>) { + if (user_enum != UserEnum::kSuccess) { + return ffi::Error::InvalidArgument("user_enum must be kSuccess"); + } + 
+ return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(UserDefinedAttr, UserDefinedAttrImpl, + ffi::Ffi::Bind() + .Attr("user_enum") + .Ret>()); + +NB_MODULE(_attrs, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["array_attr"] = + nb::capsule(reinterpret_cast(ArrayAttr)); + registrations["dictionary_attr"] = + nb::capsule(reinterpret_cast(DictionaryAttr)); + registrations["user_defined_attr"] = + nb::capsule(reinterpret_cast(UserDefinedAttr)); + return registrations; + }); +} diff --git a/examples/ffi/src/jax_ffi_example/attrs.py b/examples/ffi/src/jax_ffi_example/attrs.py new file mode 100644 index 000000000000..a7be5b841bef --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/attrs.py @@ -0,0 +1,57 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example demonstrating the different ways that attributes can be passed to +the FFI. + +For example, we can pass arrays, variadic attributes, and user-defined types +(caveat: user-defined types aren't yet fully supported by XLA, but +custom Enums are supported for now). 
+""" + +import numpy as np + +import jax +import jax.extend as jex + +from jax_ffi_example import _attrs + +for name, target in _attrs.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +def array_attr(num: int): + return jex.ffi.ffi_call( + "array_attr", + jax.ShapeDtypeStruct((), np.int32), + array=np.arange(num, dtype=np.int32), + ) + + +def dictionary_attr(**kwargs): + return jex.ffi.ffi_call( + "dictionary_attr", + (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), + **kwargs, + ) + + +def user_defined_attr(flag): + return jex.ffi.ffi_call( + "user_defined_attr", + jax.ShapeDtypeStruct((), np.int32), + # Since our enum is backed by an int64 in C++, we can pass a Python int. + # Otherwise, we would want to use the appropriate numpy type here. + user_enum=int(flag), + ) diff --git a/examples/ffi/tests/attrs_test.py b/examples/ffi/tests/attrs_test.py new file mode 100644 index 000000000000..7624d525327a --- /dev/null +++ b/examples/ffi/tests/attrs_test.py @@ -0,0 +1,54 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax_ffi_example import attrs + +jax.config.parse_flags_with_absl() + + +class AttrsTests(jtu.JaxTestCase): + def test_array_attr(self): + self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum()) + self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum()) + + def test_dictionary_attr(self): + secret, count = attrs.dictionary_attr(secret=5) + self.assertEqual(secret, 5) + self.assertEqual(count, 1) + + secret, count = attrs.dictionary_attr(secret=3, a_string="hello") + self.assertEqual(secret, 3) + self.assertEqual(count, 2) + + with self.assertRaisesRegex(Exception, "Unexpected attribute"): + attrs.dictionary_attr() + + with self.assertRaisesRegex(Exception, "Wrong attribute type"): + attrs.dictionary_attr(secret="invalid") + + def test_user_defined_attr(self): + attrs.user_defined_attr(True) # doesn't crash + + with self.assertRaisesRegex(Exception, "user_enum must be kSuccess"): + attrs.user_defined_attr(False) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 833ac4f615a8..79e1e6127dd2 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -23,6 +23,7 @@ from jax._src import core from jax._src import dispatch from jax._src import util +from jax._src.api_util import _HashableByObjectId from jax._src.callback import _check_shape_dtype, callback_batching_rule from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -237,12 +238,26 @@ def ffi_call( else: multiple_results = False result_avals = _result_avals((result_shape_dtypes,)) + + # Because we support passing numpy arrays (which are not hashable) as + # attributes, we need to handle that here. For any non-hashable kwargs, we + # wrap them in a hashable object that uses the object's id as the hash. 
Then, + # in the lowering, we unwrap the object and use the original object. + hashable_kwargs = {} + for k, v in kwargs.items(): + try: + hash(v) + except TypeError: + hashable_kwargs[k] = _HashableByObjectId(v) + else: + hashable_kwargs[k] = v + results = ffi_call_p.bind( *args, result_avals=result_avals, vectorized=vectorized, target_name=target_name, - **kwargs, + **hashable_kwargs, ) if multiple_results: return results @@ -284,7 +299,13 @@ def ffi_call_lowering( **kwargs: Any, ) -> Sequence[ir.Value]: del result_avals, vectorized - return ffi_lowering(target_name)(ctx, *operands, **kwargs) + unhashable_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, _HashableByObjectId): + unhashable_kwargs[k] = v.val + else: + unhashable_kwargs[k] = v + return ffi_lowering(target_name)(ctx, *operands, **unhashable_kwargs) ffi_call_p = core.Primitive("ffi_call") diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index af773365b12d..4aee650f4a29 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -277,7 +277,12 @@ def ir_constant(val: Any) -> IrValues: raise TypeError(f"No constant handler for type: {type(val)}") def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues: - attr = _numpy_array_attribute(x) + element_type = dtype_to_ir_type(x.dtype) + shape = x.shape + if x.dtype == np.bool_: + x = np.packbits(x, bitorder='little') # type: ignore + x = np.ascontiguousarray(x) + attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore return hlo.constant(attr) @@ -359,13 +364,26 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute: else: raise TypeError(f"Unsupported scalar attribute type: {type(val)}") +_dtype_to_array_attr: dict[Any, AttributeHandler] = { + np.dtype(np.bool_): ir.DenseBoolArrayAttr.get, + np.dtype(np.float32): ir.DenseF32ArrayAttr.get, + np.dtype(np.float64): ir.DenseF64ArrayAttr.get, + np.dtype(np.int32): ir.DenseI32ArrayAttr.get, + np.dtype(np.int64): 
ir.DenseI64ArrayAttr.get, + np.dtype(np.int8): ir.DenseI8ArrayAttr.get, +} + def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute: - element_type = dtype_to_ir_type(x.dtype) shape = x.shape if x.dtype == np.bool_: x = np.packbits(x, bitorder='little') # type: ignore x = np.ascontiguousarray(x) - return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore + builder = _dtype_to_array_attr.get(x.dtype, None) + if builder: + return builder(x) + else: + element_type = dtype_to_ir_type(x.dtype) + return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute: if 0 in val.strides and val.size > 0: @@ -407,6 +425,8 @@ def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute: register_attribute_handler(list, _sequence_attribute_handler) register_attribute_handler(tuple, _sequence_attribute_handler) +register_attribute_handler(ir.Attribute, lambda x: x) +register_attribute_handler(ir.Type, lambda x: x) def ir_attribute(val: Any) -> ir.Attribute: """Convert a Python value to an MLIR attribute."""