-
Notifications
You must be signed in to change notification settings - Fork 434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVIDIA GPU] LHS enhancement for multiple collective resources #19026
Open
terryysun
wants to merge
2
commits into
openxla:main
Choose a base branch
from
terryysun:terryysun/overlapping_collectives
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+204
−30
Open
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,6 +114,27 @@ bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) { | |
return IsGpuAsyncStart(from) && IsGpuAsyncDone(target); | ||
} | ||
|
||
// Count the maximum overlapping count in subgroups of group and other | ||
size_t CountOverlappingRanks(const std::vector<ReplicaGroup>& group, | ||
const std::vector<ReplicaGroup>& other) { | ||
size_t overlapping_count = 0; | ||
for (const auto& curr_replica_group : group) { | ||
absl::flat_hash_set<int> curr_replica_ids; | ||
for (const auto curr_replica_id : curr_replica_group.replica_ids()) { | ||
curr_replica_ids.insert(curr_replica_id); | ||
} | ||
|
||
for (const auto& replica_group : other) { | ||
size_t subgroup_count = 0; | ||
for (const auto replica_id : replica_group.replica_ids()) { | ||
if (curr_replica_ids.contains(replica_id)) ++subgroup_count; | ||
} | ||
overlapping_count = std::max(overlapping_count, subgroup_count); | ||
} | ||
} | ||
return overlapping_count; | ||
} | ||
|
||
} // namespace | ||
|
||
int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { | ||
|
@@ -141,6 +162,69 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) { | |
} | ||
} | ||
|
||
bool GpuScheduleCrossesOverlapLimit( | ||
const DefaultSchedulerCore::SchedulingState& sched_state, | ||
const HloGraphNode* node) { | ||
for (const auto& [resource, limit] : sched_state.max_concurrent_resource) { | ||
// No resources in flight of this kind. Continue. | ||
auto it = sched_state.resource_occupiers_in_flight.find(resource); | ||
if (it == sched_state.resource_occupiers_in_flight.end() || | ||
it->second.size() == 0) { | ||
continue; | ||
} | ||
// Number of instances of 'resource' needed if this instruction was | ||
// to be scheduled. | ||
const int64_t num_resources_needed = | ||
sched_state.async_tracker->GetNumResourcesPerInstruction( | ||
resource, node->GetInstr()); | ||
if (limit < num_resources_needed) { | ||
return true; | ||
} | ||
} | ||
|
||
if (node->GetResources().size() == 0) return false; | ||
auto resource_type = node->GetResources().at(0).first; | ||
// If the candidate collective has more than 1 overlapping ranks with | ||
// in-flight collectives, they can form cyclic dependency and cannot be | ||
// overlapped | ||
if ((resource_type - AsyncTracker::GetFirstTargetDefinedResource()) == | ||
static_cast<int64_t>(GpuResourceType::kGpuAsyncStreamCollectives) && | ||
sched_state.resource_occupiers_in_flight.contains(resource_type) && | ||
sched_state.resource_occupiers_in_flight.at(resource_type).size() > 0) { | ||
const HloInstruction& curr_hlo_inst = node->GetInstr(); | ||
if (hlo_query::IsAsyncCollectiveDoneOp(&curr_hlo_inst, true)) { | ||
CHECK( | ||
hlo_query::IsAsyncCollectiveStartOp(curr_hlo_inst.operand(0), true)); | ||
const HloInstruction* curr_start_inst = | ||
curr_hlo_inst.operand(0)->async_wrapped_instruction(); | ||
|
||
// If candidate can be overlapped with in-flight collectives | ||
bool can_overlap = true; | ||
for (const auto occupier : | ||
sched_state.resource_occupiers_in_flight.at(resource_type)) { | ||
if (hlo_query::IsAsyncCollectiveStartOp(occupier, true)) { | ||
// Number of overlapping ranks between this occupier and candidate | ||
size_t overlapping_count = CountOverlappingRanks( | ||
curr_start_inst->replica_groups(), occupier->replica_groups()); | ||
// VLOG(0) << "overlapping_count: " << overlapping_count; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems to be forgotten. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch! fixed |
||
if (overlapping_count > 1) { | ||
can_overlap = false; | ||
VLOG(3) << "Collectives have " << overlapping_count | ||
<< "overlapping ranks and cannot be overlapped. Candidate " | ||
"collective: " | ||
<< curr_start_inst->ToString() | ||
<< ", in flight collective: " << occupier->ToString(); | ||
break; | ||
} | ||
} | ||
} | ||
if (!can_overlap) return true; | ||
} | ||
} | ||
|
||
return false; | ||
} | ||
|
||
//===--------------------------------------------------------------------===// | ||
// GpuAsyncTrackerBase | ||
//===--------------------------------------------------------------------===// | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we always prefer the following style:
@golechwierowicz I'm not sure how you check and communicate our style preferences to external contributors. I leave this to you. I'm also seeing different integer types (int, int64_t, size_t), just FYI.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a linter/formatter I can run to get warnings on these style preferences? I'd love to follow the preferred style.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw in case this can save people a minute, I looked into https://google.github.io/styleguide/cppguide.html and here's what I found:
regarding curly braces in branching statements:
The Google C++ Style Guide allows omitting curly braces for short statements. Nonetheless, I updated the code to use curly braces, as the cases above are said to be exceptions.
regarding integer types:
I believe the usages of int/int64_t/size_t in this PR comply with the style guide, but I'm open to suggestions in case you have a higher standard internally.