Skip to content

Commit

Permalink
remove the bool var copy_start from the StreamAttributeAnnotator class
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenying-liu committed Mar 18, 2024
1 parent 31a575c commit 47a2ebd
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 28 deletions.
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
/*host_memory_offload_config=*/std::nullopt);
HloRematerialization::RematerializationSizes sizes;
pipeline.AddPass<HloRematerialization>(options, sizes);
pipeline.AddPass<StreamAttributeAnnotator>(/*copy_start=*/true);
pipeline.AddPass<StreamAttributeAnnotator>();
pipeline.AddPass<OptimizationBarrierExpander>();

TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
Expand Down
37 changes: 18 additions & 19 deletions xla/service/gpu/stream_attribute_annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ namespace {
bool IsOnlyRootNonDefaultStream(HloComputation* computation) {
HloInstruction* root = computation->root_instruction();
auto root_gpu_config = root->backend_config<GpuBackendConfig>();
if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple) {
// Disable the annotation if its root is copy-start
if (!root_gpu_config.ok() || root->opcode() == HloOpcode::kTuple ||
root->opcode() == HloOpcode::kCopyStart) {
return false;
}
int64_t root_stream_id = root_gpu_config->operation_queue_id();
Expand Down Expand Up @@ -142,24 +144,21 @@ absl::StatusOr<bool> StreamAttributeAnnotator::Run(
if (!instr_gpu_config.ok()) {
continue;
}
if (!copy_start_) {
// For fusion instruction, only annotate
// when the root of fusion is a single instruction
// running on non-default stream.
if (instr->opcode() == HloOpcode::kFusion) {
TF_ASSIGN_OR_RETURN(bool comp_result,
AnnotateStreamAttributesForInstruction(
instr, instr_gpu_config.value()));
changed |= comp_result;
}
} else {
if (instr->opcode() == HloOpcode::kCopyStart) {
TF_ASSIGN_OR_RETURN(bool comp_result,
AnnotateStreamAttributesForCopyStart(
instr, channel_id));
changed |= comp_result;
continue;
}
// For fusion instruction, only annotate
// when the root of fusion is a single instruction
// running on non-default stream.
if (instr->opcode() == HloOpcode::kFusion) {
TF_ASSIGN_OR_RETURN(bool comp_result,
AnnotateStreamAttributesForInstruction(
instr, instr_gpu_config.value()));
changed |= comp_result;
}
if (instr->opcode() == HloOpcode::kCopyStart) {
TF_ASSIGN_OR_RETURN(bool comp_result,
AnnotateStreamAttributesForCopyStart(
instr, channel_id));
changed |= comp_result;
continue;
}

TF_ASSIGN_OR_RETURN(
Expand Down
7 changes: 0 additions & 7 deletions xla/service/gpu/stream_attribute_annotator.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ namespace xla::gpu {

class StreamAttributeAnnotator : public HloModulePass {
public:
// copy_start_ is used to differentiate copy-start from
// other instructions as it doesn't need the annotations to split into
// AsyncStart and AsyncDone: copy-done is available already.
explicit StreamAttributeAnnotator(bool copy_start = false)
: HloModulePass(), copy_start_(copy_start) {}
absl::string_view name() const override {
return "stream-attribute-annotator";
}
Expand All @@ -58,8 +53,6 @@ class StreamAttributeAnnotator : public HloModulePass {
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
private:
bool copy_start_;
};

} // namespace xla::gpu
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/stream_attribute_annotator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloString));

StreamAttributeAnnotator attr_annotator(/*copy_start_done=*/true);
StreamAttributeAnnotator attr_annotator;
bool changed;
TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get()));
EXPECT_TRUE(changed);
Expand Down

0 comments on commit 47a2ebd

Please sign in to comment.