From 1ffcdaa26873c7aa6648513b9424f31e38f483f2 Mon Sep 17 00:00:00 2001 From: Eunjae Kim Date: Thu, 10 Oct 2024 15:32:51 -0700 Subject: [PATCH] Add an algebraic simplification pattern for add(broadcast(const_0), add(broadcast(const_1), conv(...))) -> add(broadcast(add(const_0, const_1)), conv(...)) PiperOrigin-RevId: 684600060 --- xla/service/gpu/transforms/BUILD | 7 ++ .../gpu/transforms/algebraic_simplifier.cc | 54 ++++++++ .../gpu/transforms/algebraic_simplifier.h | 12 ++ .../transforms/algebraic_simplifier_test.cc | 117 ++++++++++++++++++ 4 files changed, 190 insertions(+) diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index f28812f3a8cdb..fb2fb8b7ec6a9 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -25,18 +25,23 @@ cc_library( "algebraic_simplifier.h", ], deps = [ + "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:algebraic_simplifier", + "//xla/service:pattern_matcher", "//xla/service/gpu:matmul_utils", "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", ], ) @@ -47,6 +52,8 @@ xla_cc_test( ":algebraic_simplifier", "//xla/hlo/ir:hlo", "//xla/service:algebraic_simplifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/xla/service/gpu/transforms/algebraic_simplifier.cc b/xla/service/gpu/transforms/algebraic_simplifier.cc index d59ae2b6a1d03..954017fba4ede 100644 --- a/xla/service/gpu/transforms/algebraic_simplifier.cc +++ 
b/xla/service/gpu/transforms/algebraic_simplifier.cc @@ -16,15 +16,69 @@ limitations under the License. #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/algebraic_simplifier.h" #include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/pattern_matcher.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { +namespace m = ::xla::match; + +absl::StatusOr +GpuAlgebraicSimplifierVisitor::TryToSinkBroadcastOperandsOfChainedAdds( + HloInstruction* add) { + if (!options_.enable_sink_broadcast()) { + return false; + } + + HloInstruction *conv, *constant_0, *broadcast_0, *add_0, *constant_1, + *broadcast_1; + if (!Match(add, m::AddAnyOrder( + m::AddAnyOrder( + &add_0, m::Convolution(&conv, m::Op(), m::Op()), + m::Broadcast(&broadcast_0, m::Constant(&constant_0))), + m::Broadcast(&broadcast_1, m::Constant(&constant_1))))) { + return false; + } + + // Skip when the broadcast shapes and dimensions don't match. 
+ if (!ShapeUtil::Equal(constant_0->shape(), constant_1->shape()) || + broadcast_0->dimensions() != broadcast_1->dimensions()) { + return false; + } + + HloInstruction* new_constant_add = + add->AddInstruction(HloInstruction::CreateBinary( + constant_0->shape(), HloOpcode::kAdd, constant_0, constant_1)); + HloInstruction* new_bcast = + add->AddInstruction(HloInstruction::CreateBroadcast( + broadcast_0->shape(), new_constant_add, broadcast_0->dimensions())); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, + new_bcast, conv))); + return true; +} + +absl::Status GpuAlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { + TF_ASSIGN_OR_RETURN(bool replaced, + TryToSinkBroadcastOperandsOfChainedAdds(add)); + if (replaced) { + return absl::OkStatus(); + } + + return AlgebraicSimplifierVisitor::HandleAdd(add); +} + bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce( const HloInstruction* hlo) { if (!options_.enable_dot_strength_reduction()) { diff --git a/xla/service/gpu/transforms/algebraic_simplifier.h b/xla/service/gpu/transforms/algebraic_simplifier.h index c63cf1e51a4c7..56f3aa814f759 100644 --- a/xla/service/gpu/transforms/algebraic_simplifier.h +++ b/xla/service/gpu/transforms/algebraic_simplifier.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -38,9 +39,20 @@ class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor { : AlgebraicSimplifierVisitor(options, simplifier), compute_capability_(std::move(compute_capability)) {} + absl::Status HandleAdd(HloInstruction* add) override; + bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override; private: + // Try to convert add(broadcast(const_0), add(broadcast(const_1), conv(...))) + // into add(broadcast(add(const_0, const_1)), conv(...)) and return true if + // successful. The particular sink happens only when enable_sink_broadcast is + // true and the broadcast shapes and dimensions match. The sink only happens + // when following a convolution to avoid having a side input when the + // instructions are fused to cudnnConvolutionBiasActivationForward later. + absl::StatusOr TryToSinkBroadcastOperandsOfChainedAdds( + HloInstruction* add); + se::GpuComputeCapability compute_capability_; }; diff --git a/xla/service/gpu/transforms/algebraic_simplifier_test.cc b/xla/service/gpu/transforms/algebraic_simplifier_test.cc index c1e52e90a417c..db74d98ac01cb 100644 --- a/xla/service/gpu/transforms/algebraic_simplifier_test.cc +++ b/xla/service/gpu/transforms/algebraic_simplifier_test.cc @@ -17,9 +17,12 @@ limitations under the License. #include +#include #include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/algebraic_simplifier.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -27,8 +30,122 @@ limitations under the License. 
namespace xla::gpu { namespace { +namespace m = ::xla::match; + class GpuAlgebraicSimplifierTest : public HloTestBase {}; +TEST_F(GpuAlgebraicSimplifierTest, SinkBroadcastOperandsOfChainedAdds) { + const std::string& hlo_string = R"( + HloModule m + test { + in = bf16[1,3,3,1] parameter(0) + filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}}, + {{{3.1}}, {{4.1}}}}) + conv = bf16[1,2,2,1] convolution(in, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + const0 = bf16[2] constant({0, 0.25}) + bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1} + add0 = bf16[1,2,2,1] add(conv, bcast0) + const1 = bf16[2] constant({1, 1.25}) + bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={1} + ROOT add1 = bf16[1,2,2,1] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(true); + ASSERT_TRUE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::AddAnyOrder( + m::Broadcast(m::Add(m::Constant(), m::Constant())), + m::Convolution(m::Op(), m::Op())))); +} + +TEST_F(GpuAlgebraicSimplifierTest, + DoNotSinkBroadcastOperandsOfChainedAddsWhenDisabled) { + const std::string& hlo_string = R"( + HloModule m + test { + in = bf16[1,3,3,1] parameter(0) + filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}}, + {{{3.1}}, {{4.1}}}}) + conv = bf16[1,2,2,1] convolution(in, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + const0 = bf16[2] constant({0, 0.25}) + bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1} + add0 = bf16[1,2,2,1] add(conv, bcast0) + const1 = bf16[2] constant({1, 1.25}) + bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={1} + ROOT add1 = bf16[1,2,2,1] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + 
AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(false); + EXPECT_FALSE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); +} + +TEST_F(GpuAlgebraicSimplifierTest, + DoNotSinkBroadcastOperandsOfChainedAddsWithoutConvolution) { + const std::string& hlo_string = R"( + HloModule m + test { + p = bf16[4, 4] parameter(0) + const0 = bf16[4] constant({0, 0.25, 0.5, 0.75}) + bcast0 = bf16[4,4] broadcast(const0), dimensions={0} + add0 = bf16[4,4] add(p, bcast0) + const1 = bf16[4] constant({1, 1.25, 1.5, 1.75}) + bcast1 = bf16[4,4] broadcast(const1), dimensions={0} + ROOT add1 = bf16[4,4] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(true); + EXPECT_FALSE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); +} + +TEST_F( + GpuAlgebraicSimplifierTest, + DoNotSinkBroadcastOperandsOfChainedAddsWithMismatchedBroadcastDimensions) { + const std::string& hlo_string = R"( + HloModule m + test { + in = bf16[1,3,3,1] parameter(0) + filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}}, + {{{3.1}}, {{4.1}}}}) + conv = bf16[1,2,2,1] convolution(in, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + const0 = bf16[2] constant({0, 0.25}) + bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1} + add0 = bf16[1,2,2,1] add(conv, bcast0) + const1 = bf16[2] constant({1, 1.25}) + bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={2} + ROOT add1 = bf16[1,2,2,1] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(true); + EXPECT_FALSE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); +} + TEST_F(GpuAlgebraicSimplifierTest, 
VectorVectorDotShouldBeStrengthReduced) { const std::string& hlo_string = R"( HloModule m