diff --git a/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/xla/service/gpu/transforms/horizontal_loop_fusion.cc index c02ab1da5e5e2..0eed83f90eb74 100644 --- a/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -141,6 +141,11 @@ class HorizontalLoopFusionImpl { std::string prefix_; }; // HorizontalLoopFusionImpl +bool IsConcatenationInputFusion(const HloInstruction& instr) { + return instr.IsInputFusion() && + instr.fused_expression_root()->opcode() == HloOpcode::kConcatenate; +} + bool IsFusibleCandidate(const HloInstruction& instr) { // For now, we do not support fusing instruction with control flow. if (!instr.control_successors().empty() || @@ -158,8 +163,7 @@ bool IsFusibleCandidate(const HloInstruction& instr) { return true; } - // Exclude fusions other than kLoop. - if (!instr.IsLoopFusion()) { + if (!(instr.IsLoopFusion() || IsConcatenationInputFusion(instr))) { return false; } @@ -196,7 +200,8 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr, // GPU thread can only process 1 element. From experience, we enable larger // tensor size threshold for kLoop fusion. const int64_t kShapeThreshold = - sliced_input_fusion ? 128 * 2048 : 8192 * 8192; + (sliced_input_fusion || IsConcatenationInputFusion(instr)) ? 128 * 2048 + : 8192 * 8192; const int64_t kInstrCountThreshold = sliced_input_fusion ? 30 : 128; const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion) ? instr.fused_expression_root() @@ -253,8 +258,6 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( std::vector ordered_fusible_candidates; for (HloInstruction* opnd : consumer->operands()) { HloInstruction* predecessor = opnd->LatestNonGteAncestor(); - // We support kLoop fusion and element-wise HLOs now. We may extend the - // support list if needs arise. if (IsFusibleCandidate(*predecessor)) { if (fusible_candidates.insert(predecessor).second) { // Add unseen fusion to ordered list. diff --git a/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index 0e168898b35b7..19ac0c32c5fea 100644 --- a/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc +++ b/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -624,6 +624,60 @@ e { EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } +TEST_F(HorizontalLoopFusionTest, FuseSmallConcatenationInputFusions) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +a { + p = s4[1] parameter(0) + q = s4[2] parameter(1) + c = s4[3] concatenate(p, q), dimensions={0} +} + +b { + p = s4[4] parameter(0) + q = s4[5] parameter(1) + c = s4[9] concatenate(p, q), dimensions={0} +} + +e { + p = s4[1] constant({...}) + q = s4[2] constant({...}) + x = s4[3] fusion(p, q), kind=kInput, calls=a + r = s4[4] constant({...}) + s = s4[5] constant({...}) + y = s4[9] fusion(r, s), kind=kInput, calls=b + t = tuple(x, y) +})")); + + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); +} + +TEST_F(HorizontalLoopFusionTest, DoNotFuseLargerConcatenationInputFusions) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +a { + p = s4[100000] parameter(0) + q = s4[200000] parameter(1) + c = s4[300000] concatenate(p, q), dimensions={0} +} + +b { + p = s4[200000] parameter(0) + q = s4[100000] parameter(1) + c = s4[300000] concatenate(p, q), dimensions={0} +} + +e { + p = s4[100000] constant({...}) + q = s4[200000] constant({...}) + x = s4[300000] fusion(p, q), kind=kInput, calls=a + r = s4[200000] constant({...}) + s = s4[100000] constant({...}) + y = s4[300000] fusion(r, s), kind=kInput, calls=b + t = tuple(x, y) +})")); + + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); +} + TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { auto module = ParseAndReturnVerifiedModule(R"( HloModule NonfusionInstrs