[NVIDIA GPU] LHS enhancement for multiple collective resources #19026

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions xla/service/BUILD
@@ -1140,6 +1140,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_hlo_schedule.cc
@@ -536,8 +536,8 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
return GetSizeOfShape(shape, pointer_size);
};
auto scheduler_core = std::make_unique<DefaultSchedulerCore>(
shape_size_in_bytes, async_tracker.get(), latency_estimator.get(),
config);
shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config,
GpuScheduleCrossesOverlapLimit);
pipeline.AddPass<SchedulingInstructionAnnotator>();
pipeline.AddPass<LatencyHidingScheduler>(
std::move(latency_estimator), std::move(async_tracker),
84 changes: 84 additions & 0 deletions xla/service/gpu/gpu_latency_hiding_scheduler.cc
@@ -114,6 +114,27 @@ bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) {
return IsGpuAsyncStart(from) && IsGpuAsyncDone(target);
}

// Returns the maximum number of ranks shared between any subgroup of `group`
// and any subgroup of `other`.
size_t CountOverlappingRanks(const std::vector<ReplicaGroup>& group,
const std::vector<ReplicaGroup>& other) {
size_t overlapping_count = 0;
for (const auto& curr_replica_group : group) {
absl::flat_hash_set<int> curr_replica_ids;
for (const auto curr_replica_id : curr_replica_group.replica_ids()) {
curr_replica_ids.insert(curr_replica_id);
}

for (const auto& replica_group : other) {
size_t subgroup_count = 0;
for (const auto replica_id : replica_group.replica_ids()) {
if (curr_replica_ids.contains(replica_id)) ++subgroup_count;
}
overlapping_count = std::max(overlapping_count, subgroup_count);
}
}
return overlapping_count;
}

} // namespace

int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
@@ -141,6 +162,69 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) {
}
}

bool GpuScheduleCrossesOverlapLimit(
const DefaultSchedulerCore::SchedulingState& sched_state,
const HloGraphNode* node) {
for (const auto& [resource, limit] : sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.size() == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}

if (node->GetResources().size() == 0) return false;
[Review comment (Member)]
I think we always prefer the following style:

  if (condition) {
    return false;
  }

@golechwierowicz I'm not sure how you check and communicate our style preferences to external contributors; I leave this to you. I'm also seeing different integer types (int, int64_t, size_t), just FYI.

[Reply (@terryysun, Nov 15, 2024)]
Is there a linter/formatter I can run to get warnings about these style preferences? I'd love to follow the preferred style.

[Reply (@terryysun)]
In case this saves people a minute: I looked into https://google.github.io/styleguide/cppguide.html and here is what I found.

Regarding curly braces in branching statements:

  // OK - braces are optional in this case.
  if (x == kFoo) return new Foo();

  // OK - condition fits on one line, body fits on another.
  if (x == kBar)
    Bar(arg1, arg2, arg3);

The Google C++ Style Guide allows omitting curly braces for short statements. Nonetheless, I updated the code to use curly braces, since the cases above are described as exceptions.

Regarding integer types:

  "The standard library header <cstdint> defines types like int16_t, uint32_t, int64_t, etc. You should always use those in preference to short, unsigned long long and the like, when you need a guarantee on the size of an integer. Prefer to omit the std:: prefix for these types, as the extra 5 characters do not merit the added clutter. Of the built-in integer types, only int should be used. When appropriate, you are welcome to use standard type aliases like size_t and ptrdiff_t."

  "We use int very often, for integers we know are not going to be too big, e.g., loop counters. Use plain old int for such things. You should assume that an int is at least 32 bits, but don't assume that it has more than 32 bits. If you need a 64-bit integer type, use int64_t or uint64_t."

I believe the uses of int/int64_t/size_t in this PR comply with the style guide, but I'm open to suggestions in case you have a higher standard internally.

auto resource_type = node->GetResources().at(0).first;
// If the candidate collective shares more than one rank with any in-flight
// collective, the two can form a cyclic dependency and must not be
// overlapped.
if ((resource_type - AsyncTracker::GetFirstTargetDefinedResource()) ==
static_cast<int64_t>(GpuResourceType::kGpuAsyncStreamCollectives) &&
sched_state.resource_occupiers_in_flight.contains(resource_type) &&
sched_state.resource_occupiers_in_flight.at(resource_type).size() > 0) {
const HloInstruction& curr_hlo_inst = node->GetInstr();
if (hlo_query::IsAsyncCollectiveDoneOp(&curr_hlo_inst, true)) {
CHECK(
hlo_query::IsAsyncCollectiveStartOp(curr_hlo_inst.operand(0), true));
const HloInstruction* curr_start_inst =
curr_hlo_inst.operand(0)->async_wrapped_instruction();

// Whether the candidate can be overlapped with all in-flight collectives.
bool can_overlap = true;
for (const auto occupier :
sched_state.resource_occupiers_in_flight.at(resource_type)) {
if (hlo_query::IsAsyncCollectiveStartOp(occupier, true)) {
// Number of overlapping ranks between this occupier and candidate
size_t overlapping_count = CountOverlappingRanks(
curr_start_inst->replica_groups(), occupier->replica_groups());
// VLOG(0) << "overlapping_count: " << overlapping_count;
[Review comment (Member)]
Seems to be forgotten.

[Reply (@terryysun, Nov 15, 2024)]
Good catch! Fixed.

if (overlapping_count > 1) {
can_overlap = false;
VLOG(3) << "Collectives have " << overlapping_count
<< " overlapping ranks and cannot be overlapped. Candidate "
"collective: "
<< curr_start_inst->ToString()
<< ", in flight collective: " << occupier->ToString();
break;
}
}
}
if (!can_overlap) return true;
}
}

return false;
}

//===--------------------------------------------------------------------===//
// GpuAsyncTrackerBase
//===--------------------------------------------------------------------===//
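The new GpuScheduleCrossesOverlapLimit refuses to overlap a candidate collective with an in-flight one when they share more than one rank. A minimal standalone sketch of the core of that check, with replica groups modeled as plain vectors of ints (ReplicaGroup and the scheduler's state are XLA types not reproduced here):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

using ReplicaGroups = std::vector<std::vector<int>>;

// Maximum number of ranks that any subgroup of `group` shares with any
// subgroup of `other` (mirrors the PR's CountOverlappingRanks).
std::size_t CountOverlappingRanks(const ReplicaGroups& group,
                                  const ReplicaGroups& other) {
  std::size_t overlapping_count = 0;
  for (const auto& curr : group) {
    std::unordered_set<int> curr_ids(curr.begin(), curr.end());
    for (const auto& candidate : other) {
      std::size_t subgroup_count = 0;
      for (int id : candidate) {
        if (curr_ids.count(id) > 0) ++subgroup_count;
      }
      overlapping_count = std::max(overlapping_count, subgroup_count);
    }
  }
  return overlapping_count;
}

// A candidate may overlap the in-flight collectives only if it shares at most
// one rank with each of them; with two or more shared ranks, the
// nondeterministic ordering of NCCL kernels can create a cyclic dependency.
bool CanOverlap(const ReplicaGroups& candidate,
                const std::vector<ReplicaGroups>& in_flight) {
  for (const auto& occupier : in_flight) {
    if (CountOverlappingRanks(candidate, occupier) > 1) return false;
  }
  return true;
}
```

For example, two collectives over replica groups {{0,1}} share both ranks 0 and 1, so they are kept serialized, while collectives over {{0,1}} and {{1,2}} share only rank 1 and may overlap.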
9 changes: 9 additions & 0 deletions xla/service/gpu/gpu_latency_hiding_scheduler.h
@@ -34,6 +34,15 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo);
// Returns size of the `shape` given the `pointer_size`.
int64_t GetSizeOfShape(const Shape& shape, int pointer_size);

// GPU overlap limit rule for a scheduling candidate.
// On top of the default rule, we do not allow collectives with more than one
// overlapping rank to overlap. This is because the execution order of NCCL
// kernels is not deterministic and cannot be controlled by launch order at the
// moment. A cyclic dependency can be formed with at least 2 overlapping ranks.
bool GpuScheduleCrossesOverlapLimit(
const DefaultSchedulerCore::SchedulingState& sched_state,
const HloGraphNode* node);

// GPU specific resources for latency hiding scheduler.
//
// We use two different set of resources to model the scheduling of asynchronous
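The header change makes the overlap-limit check pluggable: DefaultSchedulerCore now accepts an optional OverlapLimitRule and falls back to the default limit-counting check when none is supplied, which is how gpu_hlo_schedule.cc injects GpuScheduleCrossesOverlapLimit. A simplified sketch of this injection pattern, using hypothetical State/Node stand-ins rather than XLA's SchedulingState/HloGraphNode:

```cpp
#include <cassert>
#include <functional>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for the scheduler's state and graph node.
struct State {
  std::unordered_map<int, int> in_flight;  // resource -> occupier count
  std::unordered_map<int, int> limits;     // resource -> concurrency limit
};
struct Node {
  int resource;
  int needed;  // instances of `resource` this node would consume
};

using OverlapLimitRule = std::function<bool(const State&, const Node&)>;

class SchedulerCore {
 public:
  // A backend may pass a custom rule; otherwise the default rule, which only
  // enforces the per-resource concurrency limit, is installed.
  explicit SchedulerCore(OverlapLimitRule rule = nullptr)
      : rule_(std::move(rule)) {
    if (!rule_) {
      rule_ = [](const State& s, const Node& n) {
        auto it = s.in_flight.find(n.resource);
        if (it == s.in_flight.end() || it->second == 0) return false;
        return s.limits.at(n.resource) < n.needed;
      };
    }
  }
  bool CrossesOverlapLimit(const State& s, const Node& n) const {
    return rule_(s, n);
  }

 private:
  OverlapLimitRule rule_;
};
```

The design keeps the target-independent scheduler unchanged while letting the GPU backend layer an extra constraint (the overlapping-ranks check) on top of the default limit check.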
56 changes: 55 additions & 1 deletion xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
@@ -371,7 +371,7 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest,
std::vector<HloInstruction*> instruction_sequence =
schedule.sequence(module->entry_computation()).instructions();
// Since we allow 2 collectives in-flight, we should expect this pattern:
// ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> ar(rs)-done
// ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> rs(ar)-done
EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_0") <
GetIndexByName(instruction_sequence, "rs_1") &&
GetIndexByName(instruction_sequence, "rs_0") <
@@ -386,5 +386,59 @@
GetIndexByName(instruction_sequence, "rs_1"));
}

TEST_F(GpuLatencyHidingSchedulerBaseTest,
OverlappingRanksPreventOverlappingCollectives) {
absl::string_view kFdoProfile = R"pb(
costs { name: "add_0" cost_us: 100000.0 }
costs { name: "ar_0" cost_us: 10.0 }
costs { name: "rs_0" cost_us: 10.0 }
)pb";
absl::string_view kHloModule = R"(
HloModule m

reduce {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT _ = f32[] add(x, y)
}

ENTRY main {
p0 = f32[] parameter(0)
p1 = f32[2] parameter(1)
p2 = f32[2] parameter(2)
ar_0 = f32[] all-reduce-start(p0), to_apply=reduce, replica_groups={{0,1}}
ar_1 = f32[] all-reduce-done(ar_0)
rs_0 = ((f32[2]), f32[1]) reduce-scatter-start(p1), to_apply=reduce, dimensions={0}, replica_groups={{0, 1}}
rs_1 = f32[1] reduce-scatter-done(rs_0)
add_0 = f32[2] add(p1, p2)
ROOT _ = (f32[], f32[1], f32[2]) tuple(ar_1, rs_1, add_0)
}
)";

auto config = GetModuleConfig(kFdoProfile);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloModule, config));

TF_EXPECT_OK(ScheduleModule(module.get(), /*num_parallel_resources=*/2));
auto schedule = module->schedule();
std::vector<HloInstruction*> instruction_sequence =
schedule.sequence(module->entry_computation()).instructions();
// AR and RS have two ranks in common so cannot be overlapped, expect pattern:
// rs(ar)-start -> add -> rs(ar)-done -> ar(rs)-start -> ar(rs)-done
EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_1") <
GetIndexByName(instruction_sequence, "rs_0") ||
GetIndexByName(instruction_sequence, "rs_1") <
GetIndexByName(instruction_sequence, "ar_0"));
EXPECT_TRUE((GetIndexByName(instruction_sequence, "ar_0") <
GetIndexByName(instruction_sequence, "add_0") &&
GetIndexByName(instruction_sequence, "add_0") <
GetIndexByName(instruction_sequence, "ar_1")) ||
(GetIndexByName(instruction_sequence, "rs_0") <
GetIndexByName(instruction_sequence, "add_0") &&
GetIndexByName(instruction_sequence, "add_0") <
GetIndexByName(instruction_sequence, "rs_1")));
}

} // namespace
} // namespace xla::gpu
67 changes: 42 additions & 25 deletions xla/service/latency_hiding_scheduler.cc
@@ -47,6 +47,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/map_util.h"
#include "xla/service/dump.h"
#include "xla/service/hlo_buffer.h"
@@ -1142,6 +1143,8 @@ class ReadySetLt {
const DefaultSchedulerCore::SchedulingState& sched_state_;
DefaultSchedulerCore::TargetSchedulingRule target_scheduling_rule_;
DefaultSchedulerCore::TargetSchedulingRule early_target_scheduling_rule_;
DefaultSchedulerCore::OverlapLimitRule
scheduling_instruction_crosses_overlap_limit_;

int ReadyIfScheduled(const HloGraphNode& gn) const {
int ready_nodes_if_scheduled = 0;
@@ -1275,9 +1278,9 @@
cand.node->GetResources());
int64_t num_conflicting_resources = 0;
for (int64_t resource : resources) {
if (!sched_state_.resources_in_flight.contains(resource)) continue;
if (!sched_state_.resource_occupiers_in_flight.count(resource)) continue;
num_conflicting_resources +=
sched_state_.resources_in_flight.at(resource);
sched_state_.resource_occupiers_in_flight.at(resource).size();
}
return num_conflicting_resources;
}
@@ -1307,26 +1310,29 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable(
DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node) {
absl::InlinedVector<std::pair<HloGraphNode*, SkipNodeReason>, 2>
skipped_nodes_and_reasons;
auto scheduling_instruction_crosses_overlap_limit =
[&sched_state](const HloInstruction& instr) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resources_in_flight.find(resource);
if (it == sched_state.resources_in_flight.end() || it->second == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was to
// be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(resource,
instr);
if (limit < num_resources_needed) {
return true;
if (!scheduling_instruction_crosses_overlap_limit_) {
scheduling_instruction_crosses_overlap_limit_ =
[](const SchedulingState& sched_state, const HloGraphNode* node) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.size() == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}
}
return false;
};
return false;
};
}
VLOG(2) << "Current time: " << sched_state.current_time;
ReadySetLt ready_lt{&sched_state, target_scheduling_rule_,
early_target_scheduling_rule_};
@@ -1347,8 +1353,8 @@ }
}
// If this node would cause the max_concurrent_resource count to go beyond
// the limit do not schedule it and pass to the next node.
if (scheduling_instruction_crosses_overlap_limit(
(*ready_node_it)->GetInstr())) {
if (scheduling_instruction_crosses_overlap_limit_(sched_state,
*ready_node_it)) {
if (ready_chosen.node == nullptr) {
skipped_nodes_and_reasons.push_back(
{*ready_node_it, SkipNodeReason::kExceedsOverlapLimit});
@@ -1722,9 +1728,20 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
++sched_state->scheduled_count;
for (auto& resource : n->GetResources()) {
if (resource.second == ResourceUsageType::kResourceRelease) {
--sched_state->resources_in_flight[resource.first];
sched_state->resource_occupiers_in_flight.at(resource.first)
.erase(&n->GetInstr());
} else if (resource.second == ResourceUsageType::kResourceOccupy) {
++sched_state->resources_in_flight[resource.first];
// For async collective done ops, save their corresponding start ops to
// the map
if (hlo_query::IsAsyncCollectiveDoneOp(&n->GetInstr(), true)) {
CHECK(hlo_query::IsAsyncCollectiveStartOp(n->GetInstr().operand(0),
true));
sched_state->resource_occupiers_in_flight[resource.first].insert(
n->GetInstr().operand(0));
} else {
sched_state->resource_occupiers_in_flight[resource.first].insert(
&n->GetInstr());
}
}
}
VLOG(10) << "Memory pressure before schedule: "
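The ScheduleNode diff above replaces the per-resource counter (resources_in_flight) with a set of occupying instructions, so an overlap-limit rule can inspect which collectives are in flight rather than just how many. A minimal sketch of that bookkeeping, with a hypothetical Instr type standing in for HloInstruction:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>

enum class ResourceUsageType { kResourceOccupy, kResourceRelease };

struct Instr {
  std::string name;
};

// Hypothetical stand-in for SchedulingState: resource id -> set of occupiers.
struct SchedState {
  std::unordered_map<int64_t, std::unordered_set<const Instr*>>
      resource_occupiers_in_flight;
};

// On occupy, record the instruction as an occupier of the resource; on
// release, remove it. (The PR additionally records the matching *start* op
// instead when the scheduled instruction is an async collective done op, so
// that replica groups can later be read off the start.)
void UpdateOccupiers(SchedState& state, int64_t resource,
                     ResourceUsageType usage, const Instr* instr) {
  if (usage == ResourceUsageType::kResourceOccupy) {
    state.resource_occupiers_in_flight[resource].insert(instr);
  } else {
    state.resource_occupiers_in_flight[resource].erase(instr);
  }
}
```

The occupier count is still recoverable as the set's size, which is how the default rule in FindAndExtractBestNodeAvailable keeps its old behavior.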
12 changes: 10 additions & 2 deletions xla/service/latency_hiding_scheduler.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/container/node_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -857,8 +858,9 @@ class DefaultSchedulerCore : public SchedulerCore {
std::vector<HloInstruction*> new_sequence_reversed;
// Units of time passed in the schedule. To keep track of latency hiding.
HloGraphNode::TimeCost current_time = 0;
// Number of resources in flight.
ResourceMap resources_in_flight;
// Resources and corresponding occupiers in flight.
absl::flat_hash_map<int64_t, absl::node_hash_set<const HloInstruction*>>
resource_occupiers_in_flight;
// Number of instructions using the key resource type in the set waiting to
// be scheduled.
ResourceMap resource_users_in_queue;
@@ -899,19 +901,24 @@
config(config) {}
};

using OverlapLimitRule =
std::function<bool(const SchedulingState&, const HloGraphNode*)>;
using PostProcessingFn = std::function<void(SchedulingState&)>;

DefaultSchedulerCore(
HloCostAnalysis::ShapeSizeFunction shape_size_bytes,
const AsyncTracker* async_tracker,
const LatencyEstimator* latency_estimator, const SchedulerConfig& config,
OverlapLimitRule scheduling_instruction_crosses_overlap_limit = nullptr,
TargetSchedulingRule target_scheduling_rule = nullptr,
TargetSchedulingRule early_target_scheduling_rule = nullptr,
PostProcessingFn post_processing_fn = nullptr)
: shape_size_bytes_(shape_size_bytes),
async_tracker_(async_tracker),
latency_estimator_(latency_estimator),
config_(config),
scheduling_instruction_crosses_overlap_limit_(
scheduling_instruction_crosses_overlap_limit),
target_scheduling_rule_(target_scheduling_rule),
early_target_scheduling_rule_(early_target_scheduling_rule),
post_processing_fn_(post_processing_fn) {}
@@ -958,6 +965,7 @@ class DefaultSchedulerCore : public SchedulerCore {
SchedulerConfig config_;
TargetSchedulingRule target_scheduling_rule_ = nullptr;
TargetSchedulingRule early_target_scheduling_rule_ = nullptr;
OverlapLimitRule scheduling_instruction_crosses_overlap_limit_ = nullptr;
PostProcessingFn post_processing_fn_ = nullptr;
};
