Skip to content

Commit

Permalink
LHS deadlock avoidance
Browse files Browse the repository at this point in the history
  • Loading branch information
terryysun committed Nov 14, 2024
1 parent 6e21890 commit 14362ea
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 30 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_hlo_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,8 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
return GetSizeOfShape(shape, pointer_size);
};
auto scheduler_core = std::make_unique<DefaultSchedulerCore>(
shape_size_in_bytes, async_tracker.get(), latency_estimator.get(),
config);
shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config,
GpuScheduleCrossesOverlapLimit);
pipeline.AddPass<SchedulingInstructionAnnotator>();
pipeline.AddPass<LatencyHidingScheduler>(
std::move(latency_estimator), std::move(async_tracker),
Expand Down
84 changes: 84 additions & 0 deletions xla/service/gpu/gpu_latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) {
return IsGpuAsyncStart(from) && IsGpuAsyncDone(target);
}

// Count the maximum overlapping count in subgroups of group and other
size_t CountOverlappingRanks(const std::vector<ReplicaGroup>& group,
const std::vector<ReplicaGroup>& other) {
size_t overlapping_count = 0;
for (const auto& curr_replica_group : group) {
absl::flat_hash_set<int> curr_replica_ids;
for (const auto curr_replica_id : curr_replica_group.replica_ids()) {
curr_replica_ids.insert(curr_replica_id);
}

for (const auto& replica_group : other) {
size_t subgroup_count = 0;
for (const auto replica_id : replica_group.replica_ids()) {
if (curr_replica_ids.contains(replica_id)) ++subgroup_count;
}
overlapping_count = std::max(overlapping_count, subgroup_count);
}
}
return overlapping_count;
}

} // namespace

int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
Expand Down Expand Up @@ -141,6 +162,69 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) {
}
}

// Returns true if scheduling `node` now would exceed an overlap limit:
// either a per-resource concurrency limit from `sched_state`, or the GPU
// collective rule that forbids overlapping two collectives sharing more than
// one rank (NCCL kernel launch order is not controllable, so such overlap can
// deadlock).
bool GpuScheduleCrossesOverlapLimit(
    const DefaultSchedulerCore::SchedulingState& sched_state,
    const HloGraphNode* node) {
  for (const auto& [resource, limit] : sched_state.max_concurrent_resource) {
    // No resources in flight of this kind. Continue.
    auto it = sched_state.resource_occupiers_in_flight.find(resource);
    if (it == sched_state.resource_occupiers_in_flight.end() ||
        it->second.empty()) {
      continue;
    }
    // Number of instances of 'resource' needed if this instruction was
    // to be scheduled.
    const int64_t num_resources_needed =
        sched_state.async_tracker->GetNumResourcesPerInstruction(
            resource, node->GetInstr());
    if (limit < num_resources_needed) {
      return true;
    }
  }

  if (node->GetResources().empty()) return false;
  auto resource_type = node->GetResources().at(0).first;
  // If the candidate collective has more than 1 overlapping rank with
  // in-flight collectives, they can form a cyclic dependency and cannot be
  // overlapped.
  if ((resource_type - AsyncTracker::GetFirstTargetDefinedResource()) ==
          static_cast<int64_t>(GpuResourceType::kGpuAsyncStreamCollectives) &&
      sched_state.resource_occupiers_in_flight.contains(resource_type) &&
      !sched_state.resource_occupiers_in_flight.at(resource_type).empty()) {
    const HloInstruction& curr_hlo_inst = node->GetInstr();
    if (hlo_query::IsAsyncCollectiveDoneOp(&curr_hlo_inst, true)) {
      CHECK(
          hlo_query::IsAsyncCollectiveStartOp(curr_hlo_inst.operand(0), true));
      const HloInstruction* curr_start_inst =
          curr_hlo_inst.operand(0)->async_wrapped_instruction();

      // Check whether the candidate can be overlapped with every in-flight
      // collective occupying the same resource.
      for (const auto occupier :
           sched_state.resource_occupiers_in_flight.at(resource_type)) {
        if (hlo_query::IsAsyncCollectiveStartOp(occupier, true)) {
          // Number of overlapping ranks between this occupier and candidate.
          const size_t overlapping_count = CountOverlappingRanks(
              curr_start_inst->replica_groups(), occupier->replica_groups());
          if (overlapping_count > 1) {
            VLOG(3) << "Collectives have " << overlapping_count
                    << " overlapping ranks and cannot be overlapped. "
                       "Candidate collective: "
                    << curr_start_inst->ToString()
                    << ", in flight collective: " << occupier->ToString();
            return true;
          }
        }
      }
    }
  }

  return false;
}

//===--------------------------------------------------------------------===//
// GpuAsyncTrackerBase
//===--------------------------------------------------------------------===//
Expand Down
9 changes: 9 additions & 0 deletions xla/service/gpu/gpu_latency_hiding_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo);
// Returns size of the `shape` given the `pointer_size`.
int64_t GetSizeOfShape(const Shape& shape, int pointer_size);

// GPU overlap limit rule for scheduling candidate.
// On top of the default rule, we do not allow collectives with more than 1
// overlapping ranks to overlap. This is because the execution order of NCCL
// kernels is not deterministic and cannot be controlled by launch order at the
// moment. A cyclic dependency can be formed with at least 2 overlapping ranks.
bool GpuScheduleCrossesOverlapLimit(
const DefaultSchedulerCore::SchedulingState& sched_state,
const HloGraphNode* node);

// GPU specific resources for latency hiding scheduler.
//
// We use two different set of resources to model the scheduling of asynchronous
Expand Down
56 changes: 55 additions & 1 deletion xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest,
std::vector<HloInstruction*> instruction_sequence =
schedule.sequence(module->entry_computation()).instructions();
// Since we allow 2 collectives in-flight, we should expect this pattern:
// ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> ar(rs)-done
// ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> rs(ar)-done
EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_0") <
GetIndexByName(instruction_sequence, "rs_1") &&
GetIndexByName(instruction_sequence, "rs_0") <
Expand All @@ -386,5 +386,59 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest,
GetIndexByName(instruction_sequence, "rs_1"));
}

// Verifies that two collectives whose replica groups share more than one rank
// are never scheduled to overlap, even when the parallel-resource budget would
// otherwise allow it.
TEST_F(GpuLatencyHidingSchedulerBaseTest,
       OverlappingRanksPreventOverlappingCollectives) {
  absl::string_view kFdoProfile = R"pb(
    costs { name: "add_0" cost_us: 100000.0 }
    costs { name: "ar_0" cost_us: 10.0 }
    costs { name: "rs_0" cost_us: 10.0 }
  )pb";
  absl::string_view kHloModule = R"(
    HloModule m
    reduce {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT _ = f32[] add(x, y)
    }
    ENTRY main {
      p0 = f32[] parameter(0)
      p1 = f32[2] parameter(1)
      p2 = f32[2] parameter(2)
      ar_0 = f32[] all-reduce-start(p0), to_apply=reduce, replica_groups={{0,1}}
      ar_1 = f32[] all-reduce-done(ar_0)
      rs_0 = ((f32[2]), f32[1]) reduce-scatter-start(p1), to_apply=reduce, dimensions={0}, replica_groups={{0, 1}}
      rs_1 = f32[1] reduce-scatter-done(rs_0)
      add_0 = f32[2] add(p1, p2)
      ROOT _ = (f32[], f32[1], f32[2]) tuple(ar_1, rs_1, add_0)
    }
  )";

  auto config = GetModuleConfig(kFdoProfile);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloModule, config));

  TF_EXPECT_OK(ScheduleModule(module.get(), /*num_parallel_resources=*/2));
  auto schedule = module->schedule();
  std::vector<HloInstruction*> instruction_sequence =
      schedule.sequence(module->entry_computation()).instructions();
  // AR and RS have two ranks in common so cannot be overlapped, expect pattern:
  // rs(ar)-start -> add -> rs(ar)-done -> ar(rs)-start -> ar(rs)-done
  EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_1") <
                  GetIndexByName(instruction_sequence, "rs_0") ||
              GetIndexByName(instruction_sequence, "rs_1") <
                  GetIndexByName(instruction_sequence, "ar_0"));
  EXPECT_TRUE((GetIndexByName(instruction_sequence, "ar_0") <
                   GetIndexByName(instruction_sequence, "add_0") &&
               GetIndexByName(instruction_sequence, "add_0") <
                   GetIndexByName(instruction_sequence, "ar_1")) ||
              (GetIndexByName(instruction_sequence, "rs_0") <
                   GetIndexByName(instruction_sequence, "add_0") &&
               GetIndexByName(instruction_sequence, "add_0") <
                   GetIndexByName(instruction_sequence, "rs_1")));
}

} // namespace
} // namespace xla::gpu
67 changes: 42 additions & 25 deletions xla/service/latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/map_util.h"
#include "xla/service/dump.h"
#include "xla/service/hlo_buffer.h"
Expand Down Expand Up @@ -1142,6 +1143,8 @@ class ReadySetLt {
const DefaultSchedulerCore::SchedulingState& sched_state_;
DefaultSchedulerCore::TargetSchedulingRule target_scheduling_rule_;
DefaultSchedulerCore::TargetSchedulingRule early_target_scheduling_rule_;
DefaultSchedulerCore::OverlapLimitRule
scheduling_instruction_crosses_overlap_limit_;

int ReadyIfScheduled(const HloGraphNode& gn) const {
int ready_nodes_if_scheduled = 0;
Expand Down Expand Up @@ -1275,9 +1278,9 @@ class ReadySetLt {
cand.node->GetResources());
int64_t num_conflicting_resources = 0;
for (int64_t resource : resources) {
if (!sched_state_.resources_in_flight.contains(resource)) continue;
if (!sched_state_.resource_occupiers_in_flight.count(resource)) continue;
num_conflicting_resources +=
sched_state_.resources_in_flight.at(resource);
sched_state_.resource_occupiers_in_flight.at(resource).size();
}
return num_conflicting_resources;
}
Expand Down Expand Up @@ -1307,26 +1310,29 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable(
DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node) {
absl::InlinedVector<std::pair<HloGraphNode*, SkipNodeReason>, 2>
skipped_nodes_and_reasons;
auto scheduling_instruction_crosses_overlap_limit =
[&sched_state](const HloInstruction& instr) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resources_in_flight.find(resource);
if (it == sched_state.resources_in_flight.end() || it->second == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was to
// be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(resource,
instr);
if (limit < num_resources_needed) {
return true;
if (!scheduling_instruction_crosses_overlap_limit_) {
scheduling_instruction_crosses_overlap_limit_ =
[](const SchedulingState& sched_state, const HloGraphNode* node) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.size() == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}
}
return false;
};
return false;
};
}
VLOG(2) << "Current time: " << sched_state.current_time;
ReadySetLt ready_lt{&sched_state, target_scheduling_rule_,
early_target_scheduling_rule_};
Expand All @@ -1347,8 +1353,8 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable(
}
// If this node would cause the max_concurrent_resource count to go beyond
// the limit do not schedule it and pass to the next node.
if (scheduling_instruction_crosses_overlap_limit(
(*ready_node_it)->GetInstr())) {
if (scheduling_instruction_crosses_overlap_limit_(sched_state,
*ready_node_it)) {
if (ready_chosen.node == nullptr) {
skipped_nodes_and_reasons.push_back(
{*ready_node_it, SkipNodeReason::kExceedsOverlapLimit});
Expand Down Expand Up @@ -1722,9 +1728,20 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
++sched_state->scheduled_count;
for (auto& resource : n->GetResources()) {
if (resource.second == ResourceUsageType::kResourceRelease) {
--sched_state->resources_in_flight[resource.first];
sched_state->resource_occupiers_in_flight.at(resource.first)
.erase(&n->GetInstr());
} else if (resource.second == ResourceUsageType::kResourceOccupy) {
++sched_state->resources_in_flight[resource.first];
// For async collective done ops, save their corresponding start ops to
// the map
if (hlo_query::IsAsyncCollectiveDoneOp(&n->GetInstr(), true)) {
CHECK(hlo_query::IsAsyncCollectiveStartOp(n->GetInstr().operand(0),
true));
sched_state->resource_occupiers_in_flight[resource.first].insert(
n->GetInstr().operand(0));
} else {
sched_state->resource_occupiers_in_flight[resource.first].insert(
&n->GetInstr());
}
}
}
VLOG(10) << "Memory pressure before schedule: "
Expand Down
12 changes: 10 additions & 2 deletions xla/service/latency_hiding_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/container/node_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -857,8 +858,9 @@ class DefaultSchedulerCore : public SchedulerCore {
std::vector<HloInstruction*> new_sequence_reversed;
// Units of time passed in the schedule. To keep track of latency hiding.
HloGraphNode::TimeCost current_time = 0;
// Number of resources in flight.
ResourceMap resources_in_flight;
// Resources and corresponding occupiers in flight.
absl::flat_hash_map<int64_t, absl::node_hash_set<const HloInstruction*>>
resource_occupiers_in_flight;
// Number of instructions using the key resource type in the set waiting to
// be scheduled.
ResourceMap resource_users_in_queue;
Expand Down Expand Up @@ -899,19 +901,24 @@ class DefaultSchedulerCore : public SchedulerCore {
config(config) {}
};

using OverlapLimitRule =
std::function<bool(const SchedulingState&, const HloGraphNode*)>;
using PostProcessingFn = std::function<void(SchedulingState&)>;

DefaultSchedulerCore(
HloCostAnalysis::ShapeSizeFunction shape_size_bytes,
const AsyncTracker* async_tracker,
const LatencyEstimator* latency_estimator, const SchedulerConfig& config,
OverlapLimitRule scheduling_instruction_crosses_overlap_limit = nullptr,
TargetSchedulingRule target_scheduling_rule = nullptr,
TargetSchedulingRule early_target_scheduling_rule = nullptr,
PostProcessingFn post_processing_fn = nullptr)
: shape_size_bytes_(shape_size_bytes),
async_tracker_(async_tracker),
latency_estimator_(latency_estimator),
config_(config),
scheduling_instruction_crosses_overlap_limit_(
scheduling_instruction_crosses_overlap_limit),
target_scheduling_rule_(target_scheduling_rule),
early_target_scheduling_rule_(early_target_scheduling_rule),
post_processing_fn_(post_processing_fn) {}
Expand Down Expand Up @@ -958,6 +965,7 @@ class DefaultSchedulerCore : public SchedulerCore {
SchedulerConfig config_;
TargetSchedulingRule target_scheduling_rule_ = nullptr;
TargetSchedulingRule early_target_scheduling_rule_ = nullptr;
OverlapLimitRule scheduling_instruction_crosses_overlap_limit_ = nullptr;
PostProcessingFn post_processing_fn_ = nullptr;
};

Expand Down

0 comments on commit 14362ea

Please sign in to comment.