diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index e6560833cb4ae..19eaaf52bb37c 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -477,6 +477,13 @@ class AnyBuffer { void* untyped_data() const { return buf_->data; } + template + T* typed_data() const { + assert(internal::NativeTypeToCApiDataType() == buf_->dtype && + "Template type must match the underlying buffer dtype"); + return reinterpret_cast(buf_->data); + } + private: const XLA_FFI_Buffer* buf_; }; diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 0ad4ecc8d5c82..73fe75ed8247e 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -364,6 +364,8 @@ TEST(FfiTest, AnyBufferArgument) { auto handler = Ffi::Bind().Arg().To([&](auto buffer) { EXPECT_EQ(buffer.untyped_data(), storage.data()); + EXPECT_EQ(buffer.template typed_data(), + reinterpret_cast(storage.data())); EXPECT_EQ(buffer.dimensions().size(), 2); return Error::Success(); }); @@ -400,6 +402,8 @@ TEST(FfiTest, AnyBufferResult) { auto handler = Ffi::Bind().Ret().To([&](Result buffer) { EXPECT_EQ(buffer->untyped_data(), storage.data()); + EXPECT_EQ(buffer->template typed_data(), + reinterpret_cast(storage.data())); EXPECT_EQ(buffer->dimensions().size(), 2); return Error::Success(); }); diff --git a/xla/ffi/ffi.h b/xla/ffi/ffi.h index 17d57671e5170..82bcaaf51013b 100644 --- a/xla/ffi/ffi.h +++ b/xla/ffi/ffi.h @@ -110,6 +110,13 @@ class AnyBuffer { void* untyped_data() const { return buf_->data; } + template + T* typed_data() const { + DCHECK(primitive_util::NativeToPrimitiveType() == element_type()) + << "Template type must match the underlying buffer dtype"; + return reinterpret_cast(buf_->data); + } + se::DeviceMemoryBase device_memory() const { return se::DeviceMemoryBase(untyped_data(), size_bytes()); } diff --git a/xla/ffi/ffi_test.cc b/xla/ffi/ffi_test.cc index c0fcb0057ce36..3683e3860f1c7 100644 --- a/xla/ffi/ffi_test.cc +++ b/xla/ffi/ffi_test.cc @@ -530,6 +530,8 @@ TEST(FfiTest, AnyBufferArgument) { auto fn = [&](AnyBuffer buffer) { EXPECT_EQ(buffer.element_type(), PrimitiveType::F32); EXPECT_EQ(buffer.untyped_data(), storage.data()); + EXPECT_EQ(buffer.typed_data(), + reinterpret_cast(storage.data())); AnyBuffer::Dimensions dimensions = buffer.dimensions(); EXPECT_EQ(dimensions.size(), 2); EXPECT_EQ(dimensions[0], 2);