Skip to content

Commit

Permalink
Reverts 590b36f
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697010441
  • Loading branch information
Google-ML-Automation committed Nov 15, 2024
1 parent 4d5f691 commit 4c8cc57
Show file tree
Hide file tree
Showing 9 changed files with 7 additions and 122 deletions.
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ cc_library(
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/tsl/framework:allocator",
"@com_google_absl//absl/cleanup",
Expand Down
2 changes: 0 additions & 2 deletions xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,9 @@ xla_test(
"//xla/service:custom_call_target_registry",
"//xla/stream_executor/gpu:gpu_init",
"//xla/tests:literal_test_util",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:status_matchers",
Expand Down
3 changes: 0 additions & 3 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# PJRT C API changelog

## 0.56 (Nov 11, 2024)
* Added ``PJRT_Buffer_CopyRawToHost``

## 0.55
* Added types F8E4M3 and F8E3M4.

Expand Down
17 changes: 1 addition & 16 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 56
#define PJRT_API_MINOR 55

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -1759,20 +1759,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IsDeleted_Args, is_deleted);
// True if and only if PJRT_Buffer_Delete has previously been called.
typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args);

struct PJRT_Buffer_CopyRawToHost_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_Buffer* buffer;
void* dst;
int64_t offset;
int64_t transfer_size;
PJRT_Event* event; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyRawToHost_Args, event);

typedef PJRT_Error* PJRT_Buffer_CopyRawToHost(
PJRT_Buffer_CopyRawToHost_Args* args);

struct PJRT_Buffer_CopyToDevice_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
Expand Down Expand Up @@ -2220,7 +2206,6 @@ typedef struct PJRT_Api {
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_Memory);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyRawToHost);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToDevice);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_ToHostBuffer);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu);
Expand Down
73 changes: 0 additions & 73 deletions xla/pjrt/c/pjrt_c_api_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ limitations under the License.

#include "xla/pjrt/c/pjrt_c_api_gpu.h"

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
Expand All @@ -32,7 +30,6 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/client/client_library.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/execution_context.h"
Expand All @@ -56,18 +53,13 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/stream_executor/gpu/gpu_init.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"

namespace pjrt {
namespace {

using ::testing::HasSubstr;
using ::testing::IsNull;
using ::tsl::testing::StatusIs;

#ifdef TENSORFLOW_USE_ROCM
const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); },
/*platform_name=*/"rocm"),
Expand Down Expand Up @@ -164,71 +156,6 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) {
xla::LiteralUtil::CreateR1<float>(float_data), *literal));
}

class PjrtCApiGpuBufferTest : public PjrtCApiGpuTest {
public:
PjrtCApiGpuBufferTest() : PjrtCApiGpuTest() {
auto buffer_and_event = create_buffer();
buffer_ = std::move(buffer_and_event.first);
event_ = buffer_and_event.second;
}

~PjrtCApiGpuBufferTest() override {
// event_ needs to complete before the client is destroyed; otherwise there
// is a data race between destroying the client and trying to access the
// host context in the client for the callback after host to device transfer
// is completed.
TF_EXPECT_OK(event_.Await());
// buffer_ must be destroyed before the client is destroyed or else the
// unique_ptr for buffer_ will go out of scope causing heap-use-after-free
// error.
buffer_.reset(nullptr);
}

std::unique_ptr<PJRT_Buffer, PJRT_BufferDeleter> buffer_;
xla::PjRtFuture<> event_;
};

TEST_F(PjrtCApiGpuBufferTest, CopyRawToHost) {
size_t alignment = buffer_->buffer->GetOnDeviceSizeInBytes().value();
PJRT_Buffer_CopyRawToHost_Args args;
args.struct_size = PJRT_Buffer_CopyRawToHost_Args_STRUCT_SIZE;
args.extension_start = nullptr;
args.buffer = buffer_.get();
args.dst = aligned_alloc(alignment, 0);
args.offset = 0;
args.transfer_size = alignment;
PJRT_Error* error = api_->PJRT_Buffer_CopyRawToHost(&args);
ASSERT_THAT(error, IsNull());
xla::PjRtFuture<> copy_to_host_event =
ConvertCEventToCppFuture(args.event, api_);
TF_EXPECT_OK(copy_to_host_event.Await());
EXPECT_EQ(*(static_cast<float*>(args.dst)), 41);
free(args.dst);
}

TEST_F(PjrtCApiGpuBufferTest, CopyRawToHostWithInvalidOffset) {
size_t alignment = buffer_->buffer->GetOnDeviceSizeInBytes().value();
PJRT_Buffer_CopyRawToHost_Args args;
args.struct_size = PJRT_Buffer_CopyRawToHost_Args_STRUCT_SIZE;
args.extension_start = nullptr;
args.buffer = buffer_.get();
args.dst = aligned_alloc(alignment, 0);
args.offset = alignment + 1; // offset is invalid
args.transfer_size = alignment;
PJRT_Error* error = api_->PJRT_Buffer_CopyRawToHost(&args);
ASSERT_EQ(error, nullptr);
xla::PjRtFuture<> copy_to_host_event =
ConvertCEventToCppFuture(args.event, api_);
absl::Status status = copy_to_host_event.Await();
std::string expected_message = absl::StrFormat(
"Copy raw buffer called on buffer size %lld with "
"invalid offset %lld, transfer size %lld",
alignment, args.offset, args.transfer_size);
EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr(expected_message)));
free(args.dst);
}

TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) {
PJRT_ExecuteContext_Create_Args create_arg;
create_arg.struct_size = PJRT_ExecuteContext_Create_Args_STRUCT_SIZE;
Expand Down
11 changes: 0 additions & 11 deletions xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1742,16 +1742,6 @@ PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args) {
return nullptr;
}

PJRT_Error* PJRT_Buffer_CopyRawToHost(PJRT_Buffer_CopyRawToHost_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_Buffer_CopyRawToHost_Args",
PJRT_Buffer_CopyRawToHost_Args_STRUCT_SIZE, args->struct_size));
xla::PjRtFuture<> wrapped_promise = args->buffer->buffer->CopyRawToHost(
args->dst, args->offset, args->transfer_size);
args->event = new PJRT_Event{std::move(wrapped_promise)};
return nullptr;
}

PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args) {
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_Buffer_CopyToDevice_Args",
Expand Down Expand Up @@ -2471,7 +2461,6 @@ PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn,
/*PJRT_Buffer_Memory=*/pjrt::PJRT_Buffer_Memory,
/*PJRT_Buffer_Delete=*/pjrt::PJRT_Buffer_Delete,
/*PJRT_Buffer_IsDeleted=*/pjrt::PJRT_Buffer_IsDeleted,
/*PJRT_Buffer_CopyRawToHost=*/pjrt::PJRT_Buffer_CopyRawToHost,
/*PJRT_Buffer_CopyToDevice=*/pjrt::PJRT_Buffer_CopyToDevice,
/*PJRT_Buffer_ToHostBuffer=*/pjrt::PJRT_Buffer_ToHostBuffer,
/*PJRT_Buffer_IsOnCpu=*/pjrt::PJRT_Buffer_IsOnCpu,
Expand Down
1 change: 0 additions & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args);
PJRT_Error* PJRT_Buffer_Memory(PJRT_Buffer_Memory_Args* args);
PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args);
PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args);
PJRT_Error* PJRT_Buffer_CopyRawToHost(PJRT_Buffer_CopyRawToHost_Args* args);
PJRT_Error* PJRT_Buffer_CopyToDevice(PJRT_Buffer_CopyToDevice_Args* args);
PJRT_Error* PJRT_Buffer_CopyToMemory(PJRT_Buffer_CopyToMemory_Args* args);
PJRT_Error* PJRT_Buffer_ToHostBuffer(PJRT_Buffer_ToHostBuffer_Args* args);
Expand Down
15 changes: 0 additions & 15 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2006,21 +2006,6 @@ bool PjRtCApiBuffer::IsDeleted() {
return args.is_deleted;
}

PjRtFuture<> PjRtCApiBuffer::CopyRawToHost(void* dst, int64_t offset,
int64_t transfer_size) {
PJRT_Buffer_CopyRawToHost_Args args;
args.struct_size = PJRT_Buffer_CopyRawToHost_Args_STRUCT_SIZE;
args.extension_start = nullptr;
args.buffer = buffer_.get();
args.dst = dst;
args.offset = offset;
args.transfer_size = transfer_size;
const PJRT_Api* api = pjrt_c_api();
RETURN_FUTURE_IF_ERROR(api->PJRT_Buffer_CopyRawToHost(&args), api);
CHECK(args.event != nullptr);
return pjrt::ConvertCEventToCppFuture(args.event, api);
}

absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCApiBuffer::CopyToDevice(
PjRtDevice* dst_device) {
if (dst_device->client() == client_) {
Expand Down
6 changes: 5 additions & 1 deletion xla/pjrt/pjrt_c_api_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,11 @@ class PjRtCApiBuffer : public PjRtBuffer {
absl::StatusOr<size_t> GetOnDeviceSizeInBytes() const override;

PjRtFuture<> CopyRawToHost(void* dst, int64_t offset,
int64_t transfer_size) override;
int64_t transfer_size) override {
return PjRtFuture<>(Unimplemented(
"PJRT C API does not support CopyRawToHost. Please report an issue at "
"https://github.com/google/jax/issues if you need this feature."));
}

void Delete() override;

Expand Down

0 comments on commit 4c8cc57

Please sign in to comment.