Skip to content

Commit

Permalink
Add support for passing array attributes via ffi_call
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Sep 26, 2024
1 parent 6f7ad64 commit e170da3
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 5 deletions.
4 changes: 4 additions & 0 deletions examples/ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ find_package(nanobind CONFIG REQUIRED)
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
target_include_directories(_attrs PUBLIC ${XLA_DIR})
install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
66 changes: 66 additions & 0 deletions examples/ffi/src/jax_ffi_example/attrs.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
ffi::Result<ffi::BufferR0<ffi::S32>> res) {
int64_t total = 0;
for (int32_t x : array) {
total += x;
}
res->typed_data()[0] = total;
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
ffi::Ffi::Bind()
.Attr<ffi::Span<const int32_t>>("array")
.Ret<ffi::BufferR0<ffi::S32>>());

ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
ffi::Result<ffi::BufferR0<ffi::S32>> secret,
ffi::Result<ffi::BufferR0<ffi::S32>> count) {
auto maybe_secret = attrs.get<int64_t>("secret");
if (maybe_secret.has_error()) {
return maybe_secret.error();
}
secret->typed_data()[0] = maybe_secret.value();
count->typed_data()[0] = attrs.size();
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
ffi::Ffi::Bind()
.Attrs()
.Ret<ffi::BufferR0<ffi::S32>>()
.Ret<ffi::BufferR0<ffi::S32>>());

NB_MODULE(_attrs, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["array_attr"] =
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
return registrations;
});
}
47 changes: 47 additions & 0 deletions examples/ffi/src/jax_ffi_example/attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example demonstrating the different ways that attributes can be passed to
the FFI.
For example, we can pass arrays, variadic attributes, and user-defined types.
Full support of user-defined types isn't yet supported by XLA, so that example
will be added in the future.
"""

import numpy as np

import jax
import jax.extend as jex

from jax_ffi_example import _attrs

for name, target in _attrs.registrations().items():
jex.ffi.register_ffi_target(name, target)


def array_attr(num: int):
return jex.ffi.ffi_call(
"array_attr",
jax.ShapeDtypeStruct((), np.int32),
array=np.arange(num, dtype=np.int32),
)


def dictionary_attr(**kwargs):
return jex.ffi.ffi_call(
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
**kwargs,
)
48 changes: 48 additions & 0 deletions examples/ffi/tests/attrs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu

from jax_ffi_example import attrs

jax.config.parse_flags_with_absl()


class AttrsTests(jtu.JaxTestCase):
def test_array_attr(self):
self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())

def test_dictionary_attr(self):
secret, count = attrs.dictionary_attr(secret=5)
self.assertEqual(secret, 5)
self.assertEqual(count, 1)

secret, count = attrs.dictionary_attr(secret=3, a_string="hello")
self.assertEqual(secret, 3)
self.assertEqual(count, 2)

with self.assertRaisesRegex(Exception, "Unexpected attribute"):
attrs.dictionary_attr()

with self.assertRaisesRegex(Exception, "Wrong attribute type"):
attrs.dictionary_attr(secret="invalid")


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
25 changes: 23 additions & 2 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax._src import core
from jax._src import dispatch
from jax._src import util
from jax._src.api_util import _HashableByObjectId
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import ad
from jax._src.interpreters import batching
Expand Down Expand Up @@ -237,12 +238,26 @@ def ffi_call(
else:
multiple_results = False
result_avals = _result_avals((result_shape_dtypes,))

# Because we support passing numpy arrays (which are not hashable) as
# attributes, we need to handle that here. For any non-hashable kwargs, we
# wrap them in a hashable object that uses the object's id as the hash. Then,
# in the lowering, we unwrap the object and use the original object.
hashable_kwargs = {}
for k, v in kwargs.items():
try:
hash(v)
except TypeError:
hashable_kwargs[k] = _HashableByObjectId(v)
else:
hashable_kwargs[k] = v

results = ffi_call_p.bind(
*args,
result_avals=result_avals,
vectorized=vectorized,
target_name=target_name,
**kwargs,
**hashable_kwargs,
)
if multiple_results:
return results
Expand Down Expand Up @@ -284,7 +299,13 @@ def ffi_call_lowering(
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
return ffi_lowering(target_name)(ctx, *operands, **kwargs)
unhashable_kwargs = {}
for k, v in kwargs.items():
if isinstance(v, _HashableByObjectId):
unhashable_kwargs[k] = v.val
else:
unhashable_kwargs[k] = v
return ffi_lowering(target_name)(ctx, *operands, **unhashable_kwargs)


ffi_call_p = core.Primitive("ffi_call")
Expand Down
26 changes: 23 additions & 3 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ def ir_constant(val: Any) -> IrValues:
raise TypeError(f"No constant handler for type: {type(val)}")

def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
attr = _numpy_array_attribute(x)
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
return hlo.constant(attr)


Expand Down Expand Up @@ -359,13 +364,26 @@ def _numpy_scalar_attribute(val: Any) -> ir.Attribute:
else:
raise TypeError(f"Unsupported scalar attribute type: {type(val)}")

_dtype_to_array_attr: dict[Any, AttributeHandler] = {
np.dtype(np.bool_): ir.DenseBoolArrayAttr.get,
np.dtype(np.float32): ir.DenseF32ArrayAttr.get,
np.dtype(np.float64): ir.DenseF64ArrayAttr.get,
np.dtype(np.int32): ir.DenseI32ArrayAttr.get,
np.dtype(np.int64): ir.DenseI64ArrayAttr.get,
np.dtype(np.int8): ir.DenseI8ArrayAttr.get,
}

def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute:
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
builder = _dtype_to_array_attr.get(x.dtype, None)
if builder:
return builder(x)
else:
element_type = dtype_to_ir_type(x.dtype)
return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore

def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute:
if 0 in val.strides and val.size > 0:
Expand Down Expand Up @@ -407,6 +425,8 @@ def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute:

register_attribute_handler(list, _sequence_attribute_handler)
register_attribute_handler(tuple, _sequence_attribute_handler)
register_attribute_handler(ir.Attribute, lambda x: x)
register_attribute_handler(ir.Type, lambda x: x)

def ir_attribute(val: Any) -> ir.Attribute:
"""Convert a Python value to an MLIR attribute."""
Expand Down

0 comments on commit e170da3

Please sign in to comment.