[GPU] Fix passing of key-value store handle from client to compiler.
sergachev committed Nov 12, 2024
1 parent d821c2a commit 5b731fc
Showing 4 changed files with 135 additions and 2 deletions.
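In short: StreamExecutorGpuClient::Compile(const XlaComputation&, CompileOptions) used to be the only place that copied the client's key-value store handle into ExecutableBuildOptions, so compilations entering through any other path (for example the mlir::ModuleOp overload) reached the GPU compiler without a store, and multi-process features such as sharded autotuning had no channel to exchange results. The handle is now attached as a default in the shared PjRtStreamExecutorClient::CompileInternal() path. A simplified sketch of the new behavior (the real method, shown in the last diff below, takes additional layout arguments):

absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
PjRtStreamExecutorClient::CompileInternal(const XlaComputation& computation,
                                          CompileOptions options) {
  // Fall back to the store the client was constructed with (e.g. a
  // distributed store shared by all nodes of a multi-process job); a store
  // already set by the caller wins.
  if (!options.executable_build_options.key_value_store()) {
    options.executable_build_options.set_key_value_store(*key_value_store());
  }
  // ... compilation proper is unchanged ...
}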
6 changes: 5 additions & 1 deletion xla/pjrt/gpu/BUILD
@@ -179,20 +179,24 @@ xla_cc_test(
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/pjrt:pjrt_stream_executor_client",
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/service:gpu_plugin",
"//xla/service:platform_util",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:stream",
"//xla/tests:literal_test_util",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/util:command_line_flags",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:env",
@@ -201,7 +205,7 @@ xla_cc_test(
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
"@tsl//tsl/platform:subprocess",
],
)

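The dependency changes track the new test below: @tsl//tsl/platform:test_main is replaced by a hand-written main() so the binary can re-invoke itself once per node via @tsl//tsl/platform:subprocess. The //xla/pjrt/distributed targets provide the coordination service and client, //xla/tsl/util:command_line_flags backs the --node_id and --use_xla_computation flags, and @llvm-project//mlir:IR is needed for the mlir::ModuleOp compile path.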
1 change: 0 additions & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -779,7 +779,6 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
StreamExecutorGpuClient::Compile(const XlaComputation& computation,
                                 CompileOptions options) {
-  options.executable_build_options.set_key_value_store(kv_store_);
  auto executable = PjRtStreamExecutorClient::Compile(computation, options);

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
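This deletion is the crux of the fix: the GPU client no longer sets the key-value store only in its XlaComputation overload. The store is instead attached in PjRtStreamExecutorClient::CompileInternal() (last diff below), which every compile path funnels through and which also leaves a caller-provided store untouched.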
127 changes: 127 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -44,6 +44,8 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/host_memory_spaces.h"
@@ -72,6 +74,7 @@ limitations under the License.
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/subprocess.h"
#include "tsl/platform/threadpool.h"

namespace xla {
@@ -1753,5 +1756,129 @@ TEST(StreamExecutorGpuClientTest, GetDefaultLayout) {
  EXPECT_EQ(layout.element_size_in_bits(), 4);
}

class ShardedAutotuningTest : public ::testing::TestWithParam<bool> {
 public:
  static constexpr int kNumNodes = 2;
};

static const char* test_binary_name;
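// The TEST_P below is the parent driver: it re-runs this binary once per
// node with --node_id set; each child then executes
// ShardedAutotuningWorksTestBody() instead of the test suite (see main()).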

TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
  tsl::SubProcess child[ShardedAutotuningTest::kNumNodes];
  for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes;
       ++node_id) {
    std::vector<std::string> argv;
    argv.push_back(test_binary_name);
    argv.push_back(absl::StrFormat("--node_id=%d", node_id));
    argv.push_back(absl::StrFormat("--use_xla_computation=%d", GetParam()));
    child[node_id].SetProgram(test_binary_name, argv);
    child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
    child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
    ASSERT_TRUE(child[node_id].Start()) << "node " << node_id;
  }
  for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes;
       ++node_id) {
    std::string stdout_str;
    std::string stderr_str;
    int child_status =
        child[node_id].Communicate(nullptr, &stdout_str, &stderr_str);
    EXPECT_EQ(child_status, 0) << " node " << node_id << "\nstdout:\n"
                               << stdout_str << "\nstderr:\n"
                               << stderr_str;
  }
}

absl::Status ShardedAutotuningWorksTestBody(const int node_id,
                                            bool use_xla_computation) {
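  // Each child simulates one node: restricting CUDA_VISIBLE_DEVICES makes
  // this process's client see exactly one GPU.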
tsl::setenv("CUDA_VISIBLE_DEVICES", std::to_string(node_id).data(),
/*overwrite=*/true);
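  // Node 0 hosts the coordination service that backs the shared key-value
  // store; every node, including node 0, connects to it as a client below.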
  std::unique_ptr<xla::DistributedRuntimeService> service;
  if (node_id == 0) {
    TF_ASSIGN_OR_RETURN(
        service,
        xla::GetDistributedRuntimeService(
            "[::]:12345", xla::CoordinationServiceImpl::Options{
                              .num_nodes = ShardedAutotuningTest::kNumNodes}));
  }

  xla::DistributedRuntimeClient::Options distributed_options;
  distributed_options.node_id = node_id;
  distributed_options.init_timeout = absl::Seconds(120);
  auto distributed_client =
      GetDistributedRuntimeClient("127.0.0.1:12345", distributed_options);
  TF_QCHECK_OK(distributed_client->Connect());
  GpuClientOptions options;
  options.node_id = node_id;
  options.num_nodes = ShardedAutotuningTest::kNumNodes;
  options.kv_store = GetDistributedKeyValueStore(distributed_client,
                                                 /*key_prefix=*/"gpu:");
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                      GetStreamExecutorGpuClient(options));
  TF_RET_CHECK(client->platform_name() == "cuda");
  TF_RET_CHECK(client->addressable_device_count() == 1);
  TF_RET_CHECK(client->device_count() == 2);

  CompileOptions compile_options;
  compile_options.executable_build_options.mutable_debug_options()
      ->set_xla_gpu_shard_autotuning(true);
  compile_options.executable_build_options.mutable_debug_options()
      ->set_xla_gpu_triton_gemm_any(true);
  compile_options.executable_build_options.mutable_debug_options()
      ->set_xla_gpu_cublas_fallback(false);

  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                      ParseMlirModuleString(R"mlir(
      func.func public @main(%arg0: tensor<2x2048x2048xf32>) ->
          (tensor<2x2048x2048xf32> {jax.result_info = ""}) {
        %0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0],
               contracting_dims = [2] x [1]
             : (tensor<2x2048x2048xf32>, tensor<2x2048x2048xf32>) ->
               tensor<2x2048x2048xf32>
        return %0 : tensor<2x2048x2048xf32>
      })mlir",
                                            context));
  std::unique_ptr<PjRtLoadedExecutable> executable;
  if (use_xla_computation) {
    XlaComputation computation;
    TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation,
                                            /*use_tuple_args=*/false,
                                            /*return_tuple=*/false,
                                            /*use_shardy=*/false));
    TF_ASSIGN_OR_RETURN(executable,
                        client->Compile(computation, compile_options));
  } else {
    TF_ASSIGN_OR_RETURN(executable, client->Compile(*module, compile_options));
  }

  TF_RET_CHECK(absl::StrContains(
      executable->GetHloModules()->front()->ToString(), "triton_gemm"));
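  // With Triton GEMMs forced on and the cuBLAS fallback disabled, the
  // optimized module must contain a triton_gemm fusion, so this check only
  // passes if compilation (and with it autotuning) succeeded on this node.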

  return absl::OkStatus();
}

INSTANTIATE_TEST_SUITE_P(ShardedAutotuningTest, ShardedAutotuningTest,
                         ::testing::Values(false, true));

} // namespace
} // namespace xla

int main(int argc, char* argv[]) {
  // Save name of binary so that it may invoke itself.
  xla::test_binary_name = argv[0];
  int node_id = -1;
  bool use_xla_computation = false;
  std::vector<tsl::Flag> flag_list = {
      tsl::Flag("node_id", &node_id,
                "Node ID for ShardedAutotuningWorks test."),
      tsl::Flag("use_xla_computation", &use_xla_computation,
                "Test parameter for ShardedAutotuningWorks."),
  };
  xla::AppendDebugOptionsFlags(&flag_list);
  std::string usage = tsl::Flags::Usage(argv[0], flag_list);
  tsl::Flags::Parse(&argc, argv, flag_list);
  testing::InitGoogleTest(&argc, argv);
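  // A non-negative --node_id marks this invocation as a child spawned by the
  // parent test: run the per-node body and map its status to the process exit
  // code instead of running the test suite.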
  if (node_id >= 0) {
    return !xla::ShardedAutotuningWorksTestBody(node_id, use_xla_computation)
                .ok();
  }
  return RUN_ALL_TESTS();
}
3 changes: 3 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
@@ -3514,6 +3514,9 @@ PjRtStreamExecutorClient::CompileInternal(
                                          CompileOptions options) {
  tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
  VLOG(1) << "PjRtStreamExecutorClient::Compile";
+ if (!options.executable_build_options.key_value_store()) {
+   options.executable_build_options.set_key_value_store(*key_value_store());
+ }
  options.executable_build_options.set_process_index(process_index());
  TF_RET_CHECK(device_count() % addressable_device_count() == 0)
      << "Each process is expected to have the same number of devices";
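Downstream, the visible effect is that a key-value store handed to the client at construction now reaches the compiler on every compile path. A sketch of the wiring, mirroring the test above (node_id, num_nodes, and distributed_client as set up there):

  GpuClientOptions options;
  options.node_id = node_id;      // this process's node
  options.num_nodes = num_nodes;  // total number of processes
  options.kv_store = GetDistributedKeyValueStore(distributed_client,
                                                 /*key_prefix=*/"gpu:");
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                      GetStreamExecutorGpuClient(options));

  CompileOptions compile_options;
  compile_options.executable_build_options.mutable_debug_options()
      ->set_xla_gpu_shard_autotuning(true);
  // After this commit, both entry points carry the store to the compiler:
  //   client->Compile(computation, compile_options);   // XlaComputation
  //   client->Compile(*mlir_module, compile_options);  // mlir::ModuleOp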
