diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 73d2c1143c34f..ee233faf5ff0b 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -292,6 +292,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_dot_merger_threshold_mb(32); opts.set_xla_enable_fast_math(false); opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1); + opts.set_xla_pjrt_allow_auto_layout_in_hlo(false); return opts; } @@ -2056,6 +2057,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list, debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(), "This controls how many in-flight collectives " "latency hiding scheduler can schedule.")); + flag_list->push_back(tsl::Flag( + "xla_pjrt_allow_auto_layout_in_hlo", + bool_setter_for(&DebugOptions::set_xla_pjrt_allow_auto_layout_in_hlo), + debug_options->xla_pjrt_allow_auto_layout_in_hlo(), + "Experimental: Make unset entry computation layout mean auto layout " + "instead of default layout in HLO when run through PjRT. In other cases " + "(StableHLO or non-PjRT) the auto layout is already used.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index 914a0ff7cbb45..ae7cdd94e9dd8 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -180,6 +180,7 @@ xla_cc_test( "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", "//xla/pjrt/distributed:in_memory_key_value_store", + "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:gpu_plugin", "//xla/service:platform_util", "//xla/stream_executor:device_memory", diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index f65c32bee9f7a..f75452299b580 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -53,6 +53,7 @@ limitations under the License. 
#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_stream_executor_client.h" +#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -1753,5 +1754,37 @@ TEST(StreamExecutorGpuClientTest, GetDefaultLayout) { EXPECT_EQ(layout.element_size_in_bits(), 4); } +TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) { + const char* hlo_text = R"( + HloModule DotLayout, + entry_computation_layout={(f32[2,3,5],f32[3,4,5])->f32[5,2,4]{2,1,0}} + + ENTRY dot { + p0 = f32[2,3,5]{2,1,0} parameter(0) + p1 = f32[3,4,5]{2,1,0} parameter(1) + ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1), + lhs_batch_dims={2}, lhs_contracting_dims={1}, + rhs_batch_dims={2}, rhs_contracting_dims={0} + })"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> m, + ParseAndReturnUnverifiedModule( + hlo_text, {}, HloParserOptions().set_fill_missing_layouts(false))); + + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + CompileOptions compile_options; + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_pjrt_allow_auto_layout_in_hlo(true); + XlaComputation computation = m->ToProto(); + TF_ASSERT_OK_AND_ASSIGN(auto executable, + client->Compile(computation, compile_options)); + TF_ASSERT_OK_AND_ASSIGN(auto layouts, executable->GetParameterLayouts()); + // Check that the assigned layouts are not default. + EXPECT_NE(layouts[0]->ToString(), "{2,1,0}"); + EXPECT_NE(layouts[1]->ToString(), "{2,1,0}"); +} + } // namespace } // namespace xla diff --git a/xla/pjrt/utils.cc b/xla/pjrt/utils.cc index b50fa31b648db..3b4130f0f8c7e 100644 --- a/xla/pjrt/utils.cc +++ b/xla/pjrt/utils.cc @@ -683,11 +683,14 @@ absl::Status DetermineArgumentLayoutsFromCompileOptions( // Assign a default layout based on `sharded_shape` to any array subshapes in // `dst_shape` that are missing layouts. 
- auto assign_layouts = [&choose_compact_layout_for_shape_function]( - const Shape& sharded_shape, Shape* dst_shape) { + const bool allow_auto_layout = + build_options && build_options->has_debug_options() && + build_options->debug_options().xla_pjrt_allow_auto_layout_in_hlo(); + auto assign_layouts = [&](const Shape& sharded_shape, Shape* dst_shape) { return ShapeUtil::ForEachMutableSubshapeWithStatus( dst_shape, [&](Shape* subshape, const ShapeIndex& idx) { - if (subshape->IsArray() && !subshape->has_layout()) { + if (subshape->IsArray() && !subshape->has_layout() && + !allow_auto_layout) { CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx)); const Shape& sharded_subshape = ShapeUtil::GetSubshape(sharded_shape, idx); diff --git a/xla/xla.proto b/xla/xla.proto index 803a5af0764c3..1778db2085783 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1042,7 +1042,9 @@ message DebugOptions { } PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341; - // Next id: 344 + bool xla_pjrt_allow_auto_layout_in_hlo = 344; + + // Next id: 345 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.