[XLA:GPU] Support auto layouts in StreamExecutor PjRT path when using HLO as input

For now this is done behind a flag, as a minimal change.

The flag is `xla_pjrt_allow_auto_layout_in_hlo`.

For the StableHLO path, and for non-PjRT and non-StreamExecutor paths, auto layout is already the default behavior.

PiperOrigin-RevId: 696793280
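
A quick usage sketch. Programmatically, the flag is set through the PjRT compile options, exactly as the new test in this commit does; the XLA_FLAGS route is the usual debug-options plumbing and is assumed here rather than exercised by this commit:

  // Via PjRT compile options (mirrors the new test below).
  CompileOptions compile_options;
  compile_options.executable_build_options.mutable_debug_options()
      ->set_xla_pjrt_allow_auto_layout_in_hlo(true);

  // Or via the environment, assuming standard XLA_FLAGS parsing:
  //   XLA_FLAGS=--xla_pjrt_allow_auto_layout_in_hlo=true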
mooskagh authored and Google-ML-Automation committed Nov 15, 2024
1 parent 521cd7b commit 10769ee
Showing 5 changed files with 51 additions and 4 deletions.
xla/debug_options_flags.cc (8 additions, 0 deletions)

@@ -292,6 +292,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_dot_merger_threshold_mb(32);
   opts.set_xla_enable_fast_math(false);
   opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
+  opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
   return opts;
 }

@@ -2056,6 +2057,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(),
       "This controls how many in-flight collectives "
       "latency hiding scheduler can schedule."));
+  flag_list->push_back(tsl::Flag(
+      "xla_pjrt_allow_auto_layout_in_hlo",
+      bool_setter_for(&DebugOptions::set_xla_pjrt_allow_auto_layout_in_hlo),
+      debug_options->xla_pjrt_allow_auto_layout_in_hlo(),
+      "Experimental: Make unset entry computation layout mean auto layout "
+      "instead of default layout in HLO when run through PjRT. In other cases "
+      "(StableHLO or non-PjRT) the auto layout is already used."));
 } // NOLINT(readability/fn_size)

 // Allocates flag_values and flag_objects; this function must not be called more
xla/pjrt/gpu/BUILD (1 addition, 0 deletions)

@@ -180,6 +180,7 @@ xla_cc_test(
         "//xla/pjrt:pjrt_future",
         "//xla/pjrt:pjrt_stream_executor_client",
         "//xla/pjrt/distributed:in_memory_key_value_store",
+        "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
         "//xla/service:gpu_plugin",
         "//xla/service:platform_util",
         "//xla/stream_executor:device_memory",
xla/pjrt/gpu/se_gpu_pjrt_client_test.cc (33 additions, 0 deletions)

@@ -53,6 +53,7 @@ limitations under the License.
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_future.h"
 #include "xla/pjrt/pjrt_stream_executor_client.h"
+#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
 #include "xla/service/platform_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"

@@ -1753,5 +1754,37 @@ TEST(StreamExecutorGpuClientTest, GetDefaultLayout) {
   EXPECT_EQ(layout.element_size_in_bits(), 4);
 }

+TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
+  const char* hlo_text = R"(
+  HloModule DotLayout,
+    entry_computation_layout={(f32[2,3,5],f32[3,4,5])->f32[5,2,4]{2,1,0}}
+  ENTRY dot {
+    p0 = f32[2,3,5]{2,1,0} parameter(0)
+    p1 = f32[3,4,5]{2,1,0} parameter(1)
+    ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
+      lhs_batch_dims={2}, lhs_contracting_dims={1},
+      rhs_batch_dims={2}, rhs_contracting_dims={0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> m,
+      ParseAndReturnUnverifiedModule(
+          hlo_text, {}, HloParserOptions().set_fill_missing_layouts(false)));
+
+  TF_ASSERT_OK_AND_ASSIGN(auto client,
+                          GetStreamExecutorGpuClient(GpuClientOptions()));
+  CompileOptions compile_options;
+  compile_options.executable_build_options.mutable_debug_options()
+      ->set_xla_pjrt_allow_auto_layout_in_hlo(true);
+  XlaComputation computation = m->ToProto();
+  TF_ASSERT_OK_AND_ASSIGN(auto executable,
+                          client->Compile(computation, compile_options));
+  TF_ASSERT_OK_AND_ASSIGN(auto layouts, executable->GetParameterLayouts());
+  // Check that the assigned layouts are not default.
+  EXPECT_NE(layouts[0]->ToString(), "{2,1,0}");
+  EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
+}
+
 } // namespace
 } // namespace xla
xla/pjrt/utils.cc (6 additions, 3 deletions)

@@ -683,11 +683,14 @@ absl::Status DetermineArgumentLayoutsFromCompileOptions(

   // Assign a default layout based on `sharded_shape` to any array subshapes in
   // `dst_shape` that are missing layouts.
-  auto assign_layouts = [&choose_compact_layout_for_shape_function](
-                            const Shape& sharded_shape, Shape* dst_shape) {
+  const bool allow_auto_layout =
+      build_options && build_options->has_debug_options() &&
+      build_options->debug_options().xla_pjrt_allow_auto_layout_in_hlo();
+  auto assign_layouts = [&](const Shape& sharded_shape, Shape* dst_shape) {
     return ShapeUtil::ForEachMutableSubshapeWithStatus(
         dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
-          if (subshape->IsArray() && !subshape->has_layout()) {
+          if (subshape->IsArray() && !subshape->has_layout() &&
+              !allow_auto_layout) {
             CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
             const Shape& sharded_subshape =
                 ShapeUtil::GetSubshape(sharded_shape, idx);
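
Reading the hunk above: when the flag is on, array subshapes that arrive without layouts are simply left unset, so the compiler's layout assignment later chooses them; when it is off (the default), the old behavior of stamping a default layout is unchanged. To make "unset entry computation layout" concrete, here is a hypothetical module in the same style as the new test (illustrative only, not from this commit): p0 carries no layout braces in the entry layout, so under the flag the compiler may pick p0's layout, while p1 and the result stay pinned to {1,0}.

  // Hypothetical HLO (not part of this commit): p0's entry layout is
  // unset, so it becomes auto layout under the new flag; p1 and the
  // result keep their explicit {1,0} layouts.
  const char* kAutoLayoutExample = R"(
  HloModule AutoLayoutExample,
    entry_computation_layout={(f32[2,3],f32[3,4]{1,0})->f32[2,4]{1,0}}
  ENTRY matmul {
    p0 = f32[2,3] parameter(0)
    p1 = f32[3,4] parameter(1)
    ROOT d = f32[2,4] dot(p0, p1),
      lhs_contracting_dims={1}, rhs_contracting_dims={0}
  })";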
xla/xla.proto (3 additions, 1 deletion)

@@ -1042,7 +1042,9 @@ message DebugOptions {
   }
   PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341;

-  // Next id: 344
+  bool xla_pjrt_allow_auto_layout_in_hlo = 344;
+
+  // Next id: 345

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
