Skip to content

Commit

Permalink
[xla:ffi] Add typed_data() method to AnyBuffer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689368086
  • Loading branch information
dfm authored and Google-ML-Automation committed Oct 24, 2024
1 parent d086c65 commit 086726f
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 0 deletions.
7 changes: 7 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,13 @@ class AnyBuffer {

void* untyped_data() const { return buf_->data; }

template <typename T>
T* typed_data() const {
assert(internal::NativeTypeToCApiDataType<T>() == buf_->dtype &&
"Template type must match the underlying buffer dtype");
return reinterpret_cast<T*>(buf_->data);
}

private:
const XLA_FFI_Buffer* buf_;
};
Expand Down
4 changes: 4 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ TEST(FfiTest, AnyBufferArgument) {

auto handler = Ffi::Bind().Arg<AnyBuffer>().To([&](auto buffer) {
EXPECT_EQ(buffer.untyped_data(), storage.data());
EXPECT_EQ(buffer.template typed_data<float>(),
reinterpret_cast<float*>(storage.data()));
EXPECT_EQ(buffer.dimensions().size(), 2);
return Error::Success();
});
Expand Down Expand Up @@ -400,6 +402,8 @@ TEST(FfiTest, AnyBufferResult) {

auto handler = Ffi::Bind().Ret<AnyBuffer>().To([&](Result<AnyBuffer> buffer) {
EXPECT_EQ(buffer->untyped_data(), storage.data());
EXPECT_EQ(buffer->template typed_data<float>(),
reinterpret_cast<float*>(storage.data()));
EXPECT_EQ(buffer->dimensions().size(), 2);
return Error::Success();
});
Expand Down
7 changes: 7 additions & 0 deletions xla/ffi/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ class AnyBuffer {

void* untyped_data() const { return buf_->data; }

template <typename T>
T* typed_data() const {
DCHECK(primitive_util::NativeToPrimitiveType<T>() == element_type())
<< "Template type must match the underlying buffer dtype";
return reinterpret_cast<T*>(buf_->data);
}

se::DeviceMemoryBase device_memory() const {
return se::DeviceMemoryBase(untyped_data(), size_bytes());
}
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ TEST(FfiTest, AnyBufferArgument) {
auto fn = [&](AnyBuffer buffer) {
EXPECT_EQ(buffer.element_type(), PrimitiveType::F32);
EXPECT_EQ(buffer.untyped_data(), storage.data());
EXPECT_EQ(buffer.typed_data<float>(),
reinterpret_cast<float*>(storage.data()));
AnyBuffer::Dimensions dimensions = buffer.dimensions();
EXPECT_EQ(dimensions.size(), 2);
EXPECT_EQ(dimensions[0], 2);
Expand Down

0 comments on commit 086726f

Please sign in to comment.