From 9b9c1b485f02d180af9095d7e1a6b9b5330117f6 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 12 Nov 2024 00:49:08 -0800 Subject: [PATCH] PR #19237: [GPU] Fix passing of key-value store handle from client to compiler. Imported from GitHub PR https://github.com/openxla/xla/pull/19237 Copybara import of the project: -- 99336a672505d7efa9aacfea492585afb22a5dd2 by Ilia Sergachev : [GPU] Fix passing of key-value store handle from client to compiler. Merging this change closes #19237 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19237 from openxla:fix_kv_store 99336a672505d7efa9aacfea492585afb22a5dd2 PiperOrigin-RevId: 695629316 --- xla/pjrt/gpu/BUILD | 6 +- xla/pjrt/gpu/se_gpu_pjrt_client.cc | 1 - xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 127 ++++++++++++++++++++++++ xla/pjrt/pjrt_stream_executor_client.cc | 3 + 4 files changed, 135 insertions(+), 2 deletions(-) diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD index ae7cdd94e9dd8..c3e6062d9a855 100644 --- a/xla/pjrt/gpu/BUILD +++ b/xla/pjrt/gpu/BUILD @@ -179,6 +179,8 @@ xla_cc_test( "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", + "//xla/pjrt/distributed", + "//xla/pjrt/distributed:client", "//xla/pjrt/distributed:in_memory_key_value_store", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:gpu_plugin", @@ -187,6 +189,7 @@ xla_cc_test( "//xla/stream_executor:stream", "//xla/tests:literal_test_util", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -194,6 +197,7 @@ xla_cc_test( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", "@tsl//tsl/platform:casts", "@tsl//tsl/platform:env", @@ -202,7 +206,7 @@ xla_cc_test( "@tsl//tsl/platform:status", 
"@tsl//tsl/platform:status_matchers", "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test_main", + "@tsl//tsl/platform:subprocess", ], ) diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index cd59543000bbf..a1859318d5f7a 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -779,7 +779,6 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost( absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> StreamExecutorGpuClient::Compile(const XlaComputation& computation, CompileOptions options) { - options.executable_build_options.set_key_value_store(kv_store_); auto executable = PjRtStreamExecutorClient::Compile(computation, options); #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index f75452299b580..1081896f70408 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -44,6 +44,8 @@ limitations under the License. #include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/host_memory_spaces.h" @@ -73,6 +75,7 @@ limitations under the License. 
#include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/subprocess.h" #include "tsl/platform/threadpool.h" namespace xla { @@ -1786,5 +1789,129 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) { EXPECT_NE(layouts[1]->ToString(), "{2,1,0}"); } +class ShardedAutotuningTest : public ::testing::TestWithParam<bool> { + public: + static constexpr int kNumNodes = 2; +}; + +static const char* test_binary_name; + +TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) { + tsl::SubProcess child[ShardedAutotuningTest::kNumNodes]; + for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes; ++node_id) { + std::vector<std::string> argv; + argv.push_back(test_binary_name); + argv.push_back(absl::StrFormat("--node_id=%d", node_id)); + argv.push_back(absl::StrFormat("--use_xla_computation=%d", GetParam())); + child[node_id].SetProgram(test_binary_name, argv); + child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE); + child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); + ASSERT_TRUE(child[node_id].Start()) << "node " << node_id; + } + for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes; ++node_id) { + std::string stdout_str; + std::string stderr_str; + int child_status = + child[node_id].Communicate(nullptr, &stdout_str, &stderr_str); + EXPECT_EQ(child_status, 0) << " node " << node_id << "\nstdout:\n" + << stdout_str << "\nstderr:\n" + << stderr_str; + } +} + +absl::Status ShardedAutotuningWorksTestBody(const int node_id, + bool use_xla_computation) { + tsl::setenv("CUDA_VISIBLE_DEVICES", std::to_string(node_id).data(), + /*overwrite=*/true); + std::unique_ptr<xla::DistributedRuntimeService> service; + if (node_id == 0) { + TF_ASSIGN_OR_RETURN( + service, + xla::GetDistributedRuntimeService( + "[::]:12345", xla::CoordinationServiceImpl::Options{ + .num_nodes = ShardedAutotuningTest::kNumNodes})); + } + + xla::DistributedRuntimeClient::Options distributed_options; + distributed_options.node_id = 
node_id; + distributed_options.init_timeout = absl::Seconds(120); + auto distributed_client = + GetDistributedRuntimeClient("127.0.0.1:12345", distributed_options); + TF_QCHECK_OK(distributed_client->Connect()); + GpuClientOptions options; + options.node_id = node_id; + options.num_nodes = ShardedAutotuningTest::kNumNodes; + options.kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"gpu:"); + TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client, + GetStreamExecutorGpuClient(options)); + TF_RET_CHECK(client->platform_name() == "cuda"); + TF_RET_CHECK(client->addressable_device_count() == 1); + TF_RET_CHECK(client->device_count() == 2); + + CompileOptions compile_options; + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_gpu_shard_autotuning(true); + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_gpu_triton_gemm_any(true); + compile_options.executable_build_options.mutable_debug_options() + ->set_xla_gpu_cublas_fallback(false); + + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module, + ParseMlirModuleString(R"mlir( + func.func public @main(%arg0: tensor<2x2048x2048xf32>) -> + (tensor<2x2048x2048xf32> {jax.result_info = ""}) { + %0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0], + contracting_dims = [2] x [1] + : (tensor<2x2048x2048xf32>, tensor<2x2048x2048xf32>) -> + tensor<2x2048x2048xf32> + return %0 : tensor<2x2048x2048xf32> + })mlir", + context)); + std::unique_ptr<PjRtLoadedExecutable> executable; + if (use_xla_computation) { + XlaComputation computation; + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, + /*use_tuple_args=*/false, + /*return_tuple=*/false, + /*use_shardy=*/false)); + TF_ASSIGN_OR_RETURN(executable, + client->Compile(computation, compile_options)); + } else { + TF_ASSIGN_OR_RETURN(executable, client->Compile(*module, compile_options)); + } + + TF_RET_CHECK(absl::StrContains( + executable->GetHloModules()->front()->ToString(), 
"triton_gemm")); + + return absl::OkStatus(); +} + +INSTANTIATE_TEST_SUITE_P(ShardedAutotuningTest, ShardedAutotuningTest, + ::testing::Values(false, true)); + } // namespace } // namespace xla + +int main(int argc, char* argv[]) { + // Save name of binary so that it may invoke itself. + xla::test_binary_name = argv[0]; + int node_id = -1; + bool use_xla_computation = false; + std::vector<tsl::Flag> flag_list = { + tsl::Flag("node_id", &node_id, + "Node ID for ShardedAutotuningWorks test."), + tsl::Flag("use_xla_computation", &use_xla_computation, + "Test parameter for ShardedAutotuningWorks."), + }; + xla::AppendDebugOptionsFlags(&flag_list); + std::string usage = tsl::Flags::Usage(argv[0], flag_list); + tsl::Flags::Parse(&argc, argv, flag_list); + testing::InitGoogleTest(&argc, argv); + if (node_id >= 0) { + return !xla::ShardedAutotuningWorksTestBody(node_id, use_xla_computation) + .ok(); + } + return RUN_ALL_TESTS(); +} diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index eb52d5093bf8d..f3e6b847bf01e 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -3514,6 +3514,9 @@ PjRtStreamExecutorClient::CompileInternal( CompileOptions options) { tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile"); VLOG(1) << "PjRtStreamExecutorClient::Compile"; + if (!options.executable_build_options.key_value_store()) { + options.executable_build_options.set_key_value_store(*key_value_store()); + } options.executable_build_options.set_process_index(process_index()); TF_RET_CHECK(device_count() % addressable_device_count() == 0) << "Each process is expected to have the same number of devices";