diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc
index 1ff7902757bd0..cd59543000bbf 100644
--- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc
+++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -123,10 +123,10 @@ class AsyncHostToDeviceTransferManager
  public:
   static absl::StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   Create(absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-         std::optional<absl::Span<const Layout>> device_layouts,
+         std::optional<std::vector<std::optional<Layout>>> device_layouts,
          PjRtStreamExecutorDevice* device, PjRtStreamExecutorClient* client,
          PjRtMemorySpace* memory_space) {
-    if (device_layouts != std::nullopt &&
+    if (device_layouts.has_value() &&
         device_layouts->size() != shape_specs.size()) {
       return InvalidArgument(
           "Number of layouts %d does not match the number of shapes %d",
@@ -153,14 +153,14 @@ class AsyncHostToDeviceTransferManager
           std::make_shared<BufferSequencingEvent>(client->thread_pool()));
       Shape& device_shape = device_shapes.emplace_back(
           ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims));
-      if (device_layouts == std::nullopt) {
+      if (device_layouts.has_value() && (*device_layouts)[i].has_value()) {
+        *device_shape.mutable_layout() = *(*device_layouts)[i];
+      } else {
         TF_ASSIGN_OR_RETURN(device_shape,
                             client->client()
                                 ->backend()
                                 .transfer_manager()
                                 ->ChooseCompactLayoutForShape(device_shape));
-      } else {
-        *device_shape.mutable_layout() = (*device_layouts)[i];
       }
       LocalDeviceState* local_device = device->local_device_state();
       se::Stream* h2d_stream = local_device->host_to_device_stream();
@@ -555,7 +555,7 @@ absl::string_view StreamExecutorGpuClient::platform_version() const {
 absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
 StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
     absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-    std::optional<absl::Span<const Layout>> device_layouts,
+    std::optional<std::vector<std::optional<Layout>>> device_layouts,
     PjRtDevice* device) {
   auto* stream_executor_device =
       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device);
@@ -581,7 +581,7 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
 absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
 StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice(
     absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-    std::optional<absl::Span<const Layout>> device_layouts,
+    std::optional<std::vector<std::optional<Layout>>> device_layouts,
     PjRtMemorySpace* memory_space) {
   CHECK_EQ(memory_space->devices().size(), 1);
   PjRtDevice* device = memory_space->devices()[0];
diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h
index 4c13520c598a9..21b7d1e382be6 100644
--- a/xla/pjrt/gpu/se_gpu_pjrt_client.h
+++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h
@@ -208,7 +208,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<std::vector<std::optional<Layout>>> device_layouts,
       PjRtDevice* device) override;
 
   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
@@ -218,7 +218,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
   absl::StatusOr<std::unique_ptr<PjRtClient::AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const PjRtClient::ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<std::vector<std::optional<Layout>>> device_layouts,
       PjRtMemorySpace* memory_space) override;
 
   absl::StatusOr<std::unique_ptr<PjRtBuffer>>
diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
index c9f64a8586c93..f65c32bee9f7a 100644
--- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
+++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -488,12 +488,12 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsyncWithNonCompactLayout) {
   spec.element_type = src_literal.shape().element_type();
   spec.dims = DimensionVector(src_literal.shape().dimensions().begin(),
                               src_literal.shape().dimensions().end());
+  std::vector<std::optional<xla::Layout>> device_layouts = {
+      std::make_optional(transposed_shape.layout())};
   TF_ASSERT_OK_AND_ASSIGN(
       auto transfer_manager,
       client->CreateBuffersForAsyncHostToDevice(
-          {spec},
-          std::make_optional<absl::Span<const Layout>>(
-              {transposed_shape.layout()}),
+          {spec}, device_layouts,
           client->addressable_devices()[0]->memory_spaces()[0]));
   auto buffer = transfer_manager->RetrieveBuffer(0);
 
diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h
index 383ebe44c3edb..f9e34f0056308 100644
--- a/xla/pjrt/pjrt_client.h
+++ b/xla/pjrt/pjrt_client.h
@@ -765,12 +765,16 @@ class PjRtClient {
   };
 
   // Returns a manager for async transfers into a set of buffers with on-host
-  // shapes defined by 'shape_specs' and optional `device_layouts`. The
-  // `device_layout` is used when non-compact layouts are preferred.
+  // shapes defined by 'shape_specs' and optional `device_layouts`.
+  //
+  // If the desired layout of one or more buffers is not specified in
+  // `device_layouts`, then those buffers will use the default device layout. If
+  // `device_layouts` itself is not specified, then all buffers will use the
+  // default device layout.
   virtual absl::StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<std::vector<std::optional<Layout>>> device_layouts,
       PjRtDevice* device) {
     return absl::UnimplementedError(absl::StrCat(
         "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is "
@@ -782,7 +786,7 @@ class PjRtClient {
   virtual absl::StatusOr<std::unique_ptr<AsyncHostToDeviceTransferManager>>
   CreateBuffersForAsyncHostToDevice(
       absl::Span<const ShapeSpec> shape_specs,
-      std::optional<absl::Span<const Layout>> device_layouts,
+      std::optional<std::vector<std::optional<Layout>>> device_layouts,
       PjRtMemorySpace* memory_space) {
     return absl::UnimplementedError(absl::StrCat(
         "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is "