Skip to content

Commit

Permalink
PR #19237: [GPU] Fix passing of key-value store handle from client to…
Browse files Browse the repository at this point in the history
… compiler.

Imported from GitHub PR #19237

Copybara import of the project:

--
8080bd9 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Fix passing of key-value store handle from client to compiler.

Merging this change closes #19237

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19237 from openxla:fix_kv_store 8080bd9
PiperOrigin-RevId: 695629316
  • Loading branch information
sergachev authored and Google-ML-Automation committed Nov 12, 2024
1 parent 61d921c commit a6fa889
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 1 deletion.
6 changes: 5 additions & 1 deletion xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -179,20 +179,24 @@ xla_cc_test(
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/pjrt:pjrt_stream_executor_client",
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/service:gpu_plugin",
"//xla/service:platform_util",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:stream",
"//xla/tests:literal_test_util",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/util:command_line_flags",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:env",
Expand All @@ -201,7 +205,7 @@ xla_cc_test(
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
"@tsl//tsl/platform:subprocess",
],
)

Expand Down
97 changes: 97 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/host_memory_spaces.h"
Expand Down Expand Up @@ -72,6 +74,7 @@ limitations under the License.
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/subprocess.h"
#include "tsl/platform/threadpool.h"

namespace xla {
Expand Down Expand Up @@ -1753,5 +1756,99 @@ TEST(StreamExecutorGpuClientTest, GetDefaultLayout) {
EXPECT_EQ(layout.element_size_in_bits(), 4);
}

static const char* test_binary_name;
constexpr int kNumNodes = 2;

TEST(StreamExecutorGpuClientTest, ShardedAutotuningWorks) {
tsl::SubProcess child[kNumNodes];
for (int node_id = 0; node_id < kNumNodes; ++node_id) {
std::vector<std::string> argv;
argv.push_back(test_binary_name);
argv.push_back(absl::StrFormat("--node_id=%d", node_id));
child[node_id].SetProgram(test_binary_name, argv);
child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
ASSERT_TRUE(child[node_id].Start()) << "node " << node_id;
}
for (int node_id = 0; node_id < kNumNodes; ++node_id) {
std::string stdout_str;
std::string stderr_str;
int child_status =
child[node_id].Communicate(nullptr, &stdout_str, &stderr_str);
EXPECT_EQ(child_status, 0) << " node " << node_id << "\nstdout:\n"
<< stdout_str << "\nstderr:\n"
<< stderr_str;
}
}

absl::Status ShardedAutotuningWorksTestBody(const int node_id) {
tsl::setenv("CUDA_VISIBLE_DEVICES", std::to_string(node_id).data(),
/*overwrite=*/true);
std::unique_ptr<xla::DistributedRuntimeService> service;
if (node_id == 0) {
TF_ASSIGN_OR_RETURN(service,
xla::GetDistributedRuntimeService(
"[::]:12345", xla::CoordinationServiceImpl::Options{
.num_nodes = kNumNodes}));
}

xla::DistributedRuntimeClient::Options distributed_options;
distributed_options.node_id = node_id;
distributed_options.init_timeout = absl::Seconds(120);
auto distributed_client =
GetDistributedRuntimeClient("127.0.0.1:12345", distributed_options);
TF_QCHECK_OK(distributed_client->Connect());
GpuClientOptions options;
options.node_id = node_id;
options.num_nodes = kNumNodes;
options.kv_store = GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"gpu:");
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
GetStreamExecutorGpuClient(options));
TF_RET_CHECK(client->platform_name() == "cuda");
TF_RET_CHECK(client->addressable_device_count() == 1);
TF_RET_CHECK(client->device_count() == 2);

mlir::MLIRContext context;
TF_ASSIGN_OR_RETURN(auto module, xla::ParseMlirModuleString(R"mlir(
func.func public @main(%arg0: tensor<2x2048x2048xf32>) -> (tensor<2x2048x2048xf32> {jax.result_info = ""}) {
%0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0], contracting_dims = [2] x [1]
: (tensor<2x2048x2048xf32>, tensor<2x2048x2048xf32>) -> tensor<2x2048x2048xf32>
return %0 : tensor<2x2048x2048xf32>
})mlir",
context));

CompileOptions compile_options;
compile_options.executable_build_options.mutable_debug_options()
->set_xla_gpu_shard_autotuning(true);
compile_options.executable_build_options.mutable_debug_options()
->set_xla_gpu_triton_gemm_any(true);
compile_options.executable_build_options.mutable_debug_options()
->set_xla_gpu_cublas_fallback(false);
TF_ASSIGN_OR_RETURN(auto executable,
client->Compile(*module, compile_options));
TF_RET_CHECK(absl::StrContains(
executable->GetHloModules()->front()->ToString(), "triton_gemm"));
return absl::OkStatus();
}

} // namespace
} // namespace xla

int main(int argc, char* argv[]) {
// Save name of binary so that it may invoke itself.
xla::test_binary_name = argv[0];
int node_id = -1;
std::vector<tsl::Flag> flag_list = {
tsl::Flag("node_id", &node_id,
"Node ID for ShardedAutotuningWorks test."),
};
xla::AppendDebugOptionsFlags(&flag_list);
std::string usage = tsl::Flags::Usage(argv[0], flag_list);
tsl::Flags::Parse(&argc, argv, flag_list);
testing::InitGoogleTest(&argc, argv);
if (node_id >= 0) {
return !xla::ShardedAutotuningWorksTestBody(node_id).ok();
}
return RUN_ALL_TESTS();
}
3 changes: 3 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3558,6 +3558,9 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
PjRtStreamExecutorClient::Compile(mlir::ModuleOp module,
CompileOptions options) {
XlaComputation xla_computation;
if (!options.executable_build_options.key_value_store()) {
options.executable_build_options.set_key_value_store(*key_value_store());
}
const ExecutableBuildOptions& exec_build_options =
options.executable_build_options;
TF_RETURN_IF_ERROR(MlirToXlaComputation(
Expand Down

0 comments on commit a6fa889

Please sign in to comment.