Skip to content

Commit

Permalink
1. Fix asan, msan race by invoking Stop() within PassBarrier() as opp…
Browse files Browse the repository at this point in the history
…osed to repeatedly within done callbacks.

2. Abort stale connect requests to prevent 2 connected tasks with the same identity at once.

PiperOrigin-RevId: 689501074
  • Loading branch information
Google-ML-Automation committed Oct 24, 2024
1 parent aa25741 commit 93eb762
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
38 changes: 28 additions & 10 deletions xla/tsl/distributed_runtime/coordination/coordination_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ using tensorflow::CoordinationServiceError;
using tensorflow::DeviceInfo;
using tensorflow::KeyValueEntry;

constexpr char kClusterRegisterBarrierId[] =
"[Init]Wait_for_all_tasks_to_register";
constexpr absl::Duration kDevicePropagationTimeout = absl::Hours(1);
constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000; // 10 seconds
constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds
Expand Down Expand Up @@ -163,6 +165,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
StatusCallback ConnectAfterBarrierPasses(absl::string_view task_name,
uint64_t incarnation,
StatusCallback done);
// Checks if any task has stopped sending heartbeats.
void CheckHeartbeatTimeout();
Expand Down Expand Up @@ -704,18 +707,21 @@ absl::Status CoordinationServiceStandaloneImpl::RegisterTask(
}

StatusCallback CoordinationServiceStandaloneImpl::ConnectAfterBarrierPasses(
absl::string_view task_name, StatusCallback done) {
return [this, task = std::string(task_name),
absl::string_view task_name, uint64_t incarnation, StatusCallback done) {
return [this, task = std::string(task_name), incarnation,
done = std::move(done)](absl::Status s) mutable {
state_mu_.AssertHeld();
if (s.ok()) {
if (!s.ok()) {
done(s);
} else if (incarnation == cluster_state_[task]->GetTaskIncarnation()) {
// Connect task to service.
cluster_state_[task]->Connect();
done(s);
done(absl::OkStatus());
} else {
done(s);
// Initialization failed, stop service now.
Stop();
// Avoid using `AbortedError` which typically has retry semantics.
done(MakeCoordinationError(
absl::AlreadyExistsError("Aborted connect attempt as there is a "
"request from a newer incarnation.")));
}
};
}
Expand Down Expand Up @@ -768,9 +774,9 @@ void CoordinationServiceStandaloneImpl::RegisterTaskAsync(
// and the barrier has not succeeded yet.
// There is no state that needs to be cleaned up.
task_cluster_state->SetTaskIncarnation(incarnation);
BarrierAsyncLocked("[Init]Wait_for_all_tasks_to_register",
cluster_register_timeout_, task, {},
ConnectAfterBarrierPasses(task_name, std::move(done)));
BarrierAsyncLocked(
kClusterRegisterBarrierId, cluster_register_timeout_, task, {},
ConnectAfterBarrierPasses(task_name, incarnation, std::move(done)));
return;
}
task_cluster_state->SetTaskIncarnation(incarnation);
Expand Down Expand Up @@ -1514,6 +1520,18 @@ void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id,
callback(result);
}
barrier->done_callbacks.clear();
if (barrier_id == kClusterRegisterBarrierId && !result.ok()) {
// Stop service if register failed.
LOG(ERROR)
<< "Stopping coordination service as cluster registration failed. This "
"may be due to 1) some tasks crashed earlier before connecting, 2) "
"some tasks were never scheduled, or 3) scheduling delays. Consider "
"setting a longer initialization timeout if such delays are "
"expected, the timeout is currently set to: "
<< cluster_register_timeout_ << ".\n\nOriginal error: " << result;
Stop();
return;
}
// Special hook for shutdown barrier to disconnect tasks at the barrier and
// propagate errors to those that have not.
if (barrier_id == shutdown_barrier_id_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2012,17 +2012,21 @@ TEST_F(CoordinateTwoTasksTest,
/*enable_shutdown_barrier=*/false,
/*enable_register_barrier=*/true);
absl::Status task0_status = absl::InternalError("uninitialized_status");
absl::Status restarted_task0_status =
absl::InternalError("uninitialized_status");
// Task 0 registers first.
coord_service_->RegisterTaskAsync(task_0_, incarnation_0_,
[](const absl::Status& s) {});
coord_service_->RegisterTaskAsync(
task_0_, incarnation_0_,
[&](const absl::Status& s) { task0_status = s; });
// Task 0 restarts with a new incarnation, and registers again.
// This should be allowed since all tasks have not joined the cluster yet.
coord_service_->RegisterTaskAsync(
task_0_, /*incarnation=*/incarnation_0_ + 1,
[&](const absl::Status& s) { task0_status = s; });
[&](const absl::Status& s) { restarted_task0_status = s; });
// Now all tasks will register in a synchronized fashion due to the barrier.
ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_));
ASSERT_OK(task0_status);
ASSERT_THAT(task0_status, StatusIs(absl::StatusCode::kAlreadyExists));
ASSERT_OK(restarted_task0_status);
// Task 0 joins again with the same incarnation.
// This is okay, it didn't restart, probably sent RPC twice due to network
// retries.
Expand Down Expand Up @@ -2064,7 +2068,8 @@ TEST_F(CoordinateTwoTasksTest, RegisterWithBarrier_Timeout) {
EnableCoordinationService(/*has_service_to_client_connection=*/false,
/*enable_shutdown_barrier=*/false,
/*enable_register_barrier=*/true);
// Task 0 joins without task 1. Times out eventually.
// Task 0 joins without task 1. Times out eventually as this function is
// blocking.
EXPECT_THAT(coord_service_->RegisterTask(task_0_, incarnation_0_),
StatusIs(absl::StatusCode::kDeadlineExceeded));
}
Expand Down

0 comments on commit 93eb762

Please sign in to comment.