Allow individual layouts in the device_layouts argument of `PjRtClient::CreateBuffersForAsyncHostToDevice` to be `std::nullopt`, meaning that those buffers should use the default device layout.

This change obviates the need for the caller to generate and specify layouts for buffers that should use the default device layout.
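
For example, after this change a caller can mix explicit and defaulted layouts in a single call. A minimal usage sketch, assuming a `client`, two `ShapeSpec`s (`spec_a`, `spec_b`), and a non-default `custom_layout` are already in hand (these names are illustrative, not from this commit):

    // Buffer 0 gets an explicitly chosen layout; buffer 1 is std::nullopt,
    // so it falls back to the default (compact) device layout.
    std::vector<std::optional<xla::Layout>> device_layouts = {custom_layout,
                                                              std::nullopt};
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        client->CreateBuffersForAsyncHostToDevice(
            {spec_a, spec_b}, device_layouts,
            client->addressable_devices()[0]->memory_spaces()[0]));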

PiperOrigin-RevId: 689477941
Google-ML-Automation committed Oct 24, 2024
1 parent eb3a2b6 commit 17851a6
Showing 4 changed files with 20 additions and 16 deletions.
14 changes: 7 additions & 7 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -123,10 +123,10 @@ class AsyncHostToDeviceTransferManager
  public:
   static absl::StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   Create(absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-         std::optional<absl::Span<const Layout>> device_layouts,
+         std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
          PjRtStreamExecutorDevice* device, PjRtStreamExecutorClient* client,
          PjRtMemorySpace* memory_space) {
-    if (device_layouts != std::nullopt &&
+    if (device_layouts.has_value() &&
         device_layouts->size() != shape_specs.size()) {
       return InvalidArgument(
           "Number of layouts %d does not match the number of shapes %d",
@@ -153,14 +153,14 @@ class AsyncHostToDeviceTransferManager
           std::make_shared<BufferSequencingEvent>(client->thread_pool()));
       Shape& device_shape = device_shapes.emplace_back(
           ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims));
-      if (device_layouts == std::nullopt) {
+      if (device_layouts.has_value() && (*device_layouts)[i].has_value()) {
+        *device_shape.mutable_layout() = *(*device_layouts)[i];
+      } else {
         TF_ASSIGN_OR_RETURN(device_shape,
                             client->client()
                                 ->backend()
                                 .transfer_manager()
                                 ->ChooseCompactLayoutForShape(device_shape));
-      } else {
-        *device_shape.mutable_layout() = (*device_layouts)[i];
       }
       LocalDeviceState* local_device = device->local_device_state();
       se::Stream* h2d_stream = local_device->host_to_device_stream();
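
Read in isolation, the branch above implements a simple per-buffer resolution rule: an explicit layout is honored only when `device_layouts` is present and element `i` is engaged; in every other case the backend's compact layout is chosen. A standalone sketch of that rule (the helper name `ResolveExplicitLayout` is hypothetical, not part of this commit):

    // Hypothetical helper mirroring the logic above: returns the
    // caller-specified layout for buffer i, or std::nullopt to signal that
    // ChooseCompactLayoutForShape should supply the default compact layout.
    std::optional<Layout> ResolveExplicitLayout(
        const std::optional<absl::Span<const std::optional<Layout>>>&
            device_layouts,
        size_t i) {
      if (device_layouts.has_value() && (*device_layouts)[i].has_value()) {
        return (*device_layouts)[i];
      }
      return std::nullopt;
    }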
@@ -555,7 +555,7 @@ absl::string_view StreamExecutorGpuClient::platform_version() const {
 absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
 StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
     absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-    std::optional<absl::Span<const Layout>> device_layouts,
+    std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
     PjRtDevice* device) {
   auto* stream_executor_device =
       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device);
@@ -581,7 +581,7 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
 absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
 StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
     absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-    std::optional<absl::Span<const Layout>> device_layouts,
+    std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
     PjRtMemorySpace* memory_space) {
   CHECK_EQ(memory_space->devices().size(), 1);
   PjRtDevice* device = memory_space->devices()[0];
4 changes: 2 additions & 2 deletions xla/pjrt/gpu/se_gpu_pjrt_client.h
@@ -208,7 +208,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
       PjRtDevice* device) override;

   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
@@ -218,7 +218,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
       PjRtMemorySpace* memory_space) override;

   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
6 changes: 3 additions & 3 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -488,12 +488,12 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsyncWithNonCompactLayout) {
   spec.element_type = src_literal.shape().element_type();
   spec.dims = DimensionVector(src_literal.shape().dimensions().begin(),
                               src_literal.shape().dimensions().end());
+  std::vector<std::optional<xla::Layout>> device_layouts = {
+      std::make_optional(transposed_shape.layout())};
   TF_ASSERT_OK_AND_ASSIGN(
       auto transfer_manager,
       client->CreateBuffersForAsyncHostToDevice(
-          {spec},
-          std::make_optional<absl::Span<const Layout>>(
-              {transposed_shape.layout()}),
+          {spec}, device_layouts,
           client->addressable_devices()[0]->memory_spaces()[0]));
   auto buffer = transfer_manager->RetrieveBuffer(0);

12 changes: 8 additions & 4 deletions xla/pjrt/pjrt_client.h
@@ -765,12 +765,16 @@ class PjRtClient {
   };

   // Returns a manager for async transfers into a set of buffers with on-host
-  // shapes defined by 'shape_specs' and optional `device_layouts`. The
-  // `device_layout` is used when non-compact layouts are preferred.
+  // shapes defined by 'shape_specs' and optional `device_layouts`.
+  //
+  // If the desired layout of one or more buffers is not specified in
+  // `device_layouts`, then those buffers will use the default device layout. If
+  // `device_layouts` itself is not specified, then all buffers will use the
+  // default device layout.
   virtual absl::StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
       PjRtDevice* device) {
     return absl::UnimplementedError(absl::StrCat(
         "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is "
@@ -782,7 +786,7 @@
   virtual absl::StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<absl::Span<const std::optional<Layout>>> device_layouts,
       PjRtMemorySpace* memory_space) {
     return absl::UnimplementedError(absl::StrCat(
         "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is "