From a0c4be2edaa7c21c4b583be6cbb823aa730f9a3f Mon Sep 17 00:00:00 2001
From: Ilia Sergachev <sergachev@google.com>
Date: Tue, 12 Nov 2024 00:49:08 -0800
Subject: [PATCH] PR #19237: [GPU] Fix passing of key-value store handle from
 client to compiler.

Imported from GitHub PR https://github.com/openxla/xla/pull/19237

Copybara import of the project:

--
8080bd9fb8d169bf2add2c079650e4cad2e9fdff by Ilia Sergachev <sergachev@google.com>:

[GPU] Fix passing of key-value store handle from client to compiler.

Merging this change closes #19237

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19237 from openxla:fix_kv_store 8080bd9fb8d169bf2add2c079650e4cad2e9fdff
PiperOrigin-RevId: 695629316
---
 xla/pjrt/gpu/BUILD                      |  6 +-
 xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 97 +++++++++++++++++++++++++
 xla/pjrt/pjrt_stream_executor_client.cc |  3 +
 3 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD
index 914a0ff7cbb453..b270bec7d33861 100644
--- a/xla/pjrt/gpu/BUILD
+++ b/xla/pjrt/gpu/BUILD
@@ -179,6 +179,8 @@ xla_cc_test(
         "//xla/pjrt:pjrt_executable",
         "//xla/pjrt:pjrt_future",
         "//xla/pjrt:pjrt_stream_executor_client",
+        "//xla/pjrt/distributed",
+        "//xla/pjrt/distributed:client",
         "//xla/pjrt/distributed:in_memory_key_value_store",
         "//xla/service:gpu_plugin",
         "//xla/service:platform_util",
@@ -186,6 +188,7 @@ xla_cc_test(
         "//xla/stream_executor:stream",
         "//xla/tests:literal_test_util",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/util:command_line_flags",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
@@ -193,6 +196,7 @@ xla_cc_test(
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
         "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
         "@llvm-project//mlir:IR",
         "@tsl//tsl/platform:casts",
         "@tsl//tsl/platform:env",
@@ -201,6 +205,6 @@ xla_cc_test(
         "@tsl//tsl/platform:status",
         "@tsl//tsl/platform:status_matchers",
         "@tsl//tsl/platform:statusor",
-        "@tsl//tsl/platform:test_main",
+        "@tsl//tsl/platform:subprocess",
     ],
 )
diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
index f65c32bee9f7aa..f10b93ce95fb38 100644
--- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
+++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -44,6 +44,8 @@ limitations under the License.
 #include "xla/layout.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
+#include "xla/pjrt/distributed/client.h"
+#include "xla/pjrt/distributed/distributed.h"
 #include "xla/pjrt/distributed/in_memory_key_value_store.h"
 #include "xla/pjrt/gpu/gpu_topology.h"
 #include "xla/pjrt/host_memory_spaces.h"
@@ -72,6 +74,7 @@ limitations under the License.
 #include "tsl/platform/status.h"
 #include "tsl/platform/status_matchers.h"
 #include "tsl/platform/statusor.h"
+#include "tsl/platform/subprocess.h"
 #include "tsl/platform/threadpool.h"
 
 namespace xla {
@@ -1753,5 +1756,99 @@ TEST(StreamExecutorGpuClientTest, GetDefaultLayout) {
   EXPECT_EQ(layout.element_size_in_bits(), 4);
 }
 
+static const char* test_binary_name;
+constexpr int kNumNodes = 2;
+
+TEST(StreamExecutorGpuClientTest, ShardedAutotuningWorks) {
+  tsl::SubProcess child[kNumNodes];
+  for (int node_id = 0; node_id < kNumNodes; ++node_id) {
+    std::vector<std::string> argv;
+    argv.push_back(test_binary_name);
+    argv.push_back(absl::StrFormat("--node_id=%d", node_id));
+    child[node_id].SetProgram(test_binary_name, argv);
+    child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
+    child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
+    ASSERT_TRUE(child[node_id].Start()) << "node " << node_id;
+  }
+  for (int node_id = 0; node_id < kNumNodes; ++node_id) {
+    std::string stdout_str;
+    std::string stderr_str;
+    int child_status =
+        child[node_id].Communicate(nullptr, &stdout_str, &stderr_str);
+    EXPECT_EQ(child_status, 0) << " node " << node_id << "\nstdout:\n"
+                               << stdout_str << "\nstderr:\n"
+                               << stderr_str;
+  }
+}
+
+absl::Status ShardedAutotuningWorksTestBody(const int node_id) {
+  tsl::setenv("CUDA_VISIBLE_DEVICES", std::to_string(node_id).data(),
+              /*overwrite=*/true);
+  std::unique_ptr<xla::DistributedRuntimeService> service;
+  if (node_id == 0) {
+    TF_ASSIGN_OR_RETURN(service,
+                        xla::GetDistributedRuntimeService(
+                            "[::]:12345", xla::CoordinationServiceImpl::Options{
+                                              .num_nodes = kNumNodes}));
+  }
+
+  xla::DistributedRuntimeClient::Options distributed_options;
+  distributed_options.node_id = node_id;
+  distributed_options.init_timeout = absl::Seconds(120);
+  auto distributed_client =
+      GetDistributedRuntimeClient("127.0.0.1:12345", distributed_options);
+  TF_QCHECK_OK(distributed_client->Connect());
+  GpuClientOptions options;
+  options.node_id = node_id;
+  options.num_nodes = kNumNodes;
+  options.kv_store = GetDistributedKeyValueStore(distributed_client,
+                                                 /*key_prefix=*/"gpu:");
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
+                      GetStreamExecutorGpuClient(options));
+  TF_RET_CHECK(client->platform_name() == "cuda");
+  TF_RET_CHECK(client->addressable_device_count() == 1);
+  TF_RET_CHECK(client->device_count() == 2);
+
+  mlir::MLIRContext context;
+  TF_ASSIGN_OR_RETURN(auto module, xla::ParseMlirModuleString(R"mlir(
+      func.func public @main(%arg0: tensor<2x2048x2048xf32>) -> (tensor<2x2048x2048xf32> {jax.result_info = ""}) {
+        %0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0], contracting_dims = [2] x [1]
+          : (tensor<2x2048x2048xf32>, tensor<2x2048x2048xf32>) -> tensor<2x2048x2048xf32>
+        return %0 : tensor<2x2048x2048xf32>
+      })mlir",
+                                                              context));
+
+  CompileOptions compile_options;
+  compile_options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_shard_autotuning(true);
+  compile_options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_triton_gemm_any(true);
+  compile_options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_cublas_fallback(false);
+  TF_ASSIGN_OR_RETURN(auto executable,
+                      client->Compile(*module, compile_options));
+  TF_RET_CHECK(absl::StrContains(
+      executable->GetHloModules()->front()->ToString(), "triton_gemm"));
+  return absl::OkStatus();
+}
+
 }  // namespace
 }  // namespace xla
+
+int main(int argc, char* argv[]) {
+  // Save name of binary so that it may invoke itself.
+  xla::test_binary_name = argv[0];
+  int node_id = -1;
+  std::vector<tsl::Flag> flag_list = {
+      tsl::Flag("node_id", &node_id,
+                "Node ID for ShardedAutotuningWorks test."),
+  };
+  xla::AppendDebugOptionsFlags(&flag_list);
+  std::string usage = tsl::Flags::Usage(argv[0], flag_list);
+  tsl::Flags::Parse(&argc, argv, flag_list);
+  testing::InitGoogleTest(&argc, argv);
+  if (node_id >= 0) {
+    return !xla::ShardedAutotuningWorksTestBody(node_id).ok();
+  }
+  return RUN_ALL_TESTS();
+}
diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc
index eb52d5093bf8d3..8be5221d478323 100644
--- a/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/xla/pjrt/pjrt_stream_executor_client.cc
@@ -3558,6 +3558,9 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
 PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
                                   CompileOptions options) {
   XlaComputation xla_computation;
+  if (!options.executable_build_options.key_value_store()) {
+    options.executable_build_options.set_key_value_store(*key_value_store());
+  }
   const ExecutableBuildOptions& exec_build_options =
       options.executable_build_options;
   TF_RETURN_IF_ERROR(MlirToXlaComputation(