Skip to content

Commit

Permalink
Add an algebraic simplification pattern for add(broadcast(const_0), a…
Browse files Browse the repository at this point in the history
…dd(broadcast(const_1), conv(...))) -> add(broadcast(add(const_0, const_1)), conv(...))

PiperOrigin-RevId: 684600060
  • Loading branch information
eunjaekim-0 authored and Google-ML-Automation committed Oct 10, 2024
1 parent cae9085 commit 1ffcdaa
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 0 deletions.
7 changes: 7 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@ cc_library(
"algebraic_simplifier.h",
],
deps = [
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/service:algebraic_simplifier",
"//xla/service:pattern_matcher",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu/fusions/triton:triton_support",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -47,6 +52,8 @@ xla_cc_test(
":algebraic_simplifier",
"//xla/hlo/ir:hlo",
"//xla/service:algebraic_simplifier",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
Expand Down
54 changes: 54 additions & 0 deletions xla/service/gpu/transforms/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,69 @@ limitations under the License.
#include "xla/service/gpu/transforms/algebraic_simplifier.h"

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/gpu/fusions/triton/triton_support_legacy.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {

namespace m = ::xla::match;

// Rewrites add(broadcast(const_0), add(conv(...), broadcast(const_1))) into
// add(broadcast(add(const_0, const_1)), conv(...)).  Returns true iff the
// rewrite was applied.  See the declaration in the header for the rationale
// (keeping a single side input for cudnn conv-bias-activation fusion).
absl::StatusOr<bool>
GpuAlgebraicSimplifierVisitor::TryToSinkBroadcastOperandsOfChainedAdds(
    HloInstruction* add) {
  // This rewrite is gated on the generic sink-broadcast option.
  if (!options_.enable_sink_broadcast()) {
    return false;
  }

  // Instructions captured while matching the chained-add pattern.
  HloInstruction* conv = nullptr;
  HloInstruction* inner_add = nullptr;
  HloInstruction* lhs_broadcast = nullptr;
  HloInstruction* lhs_constant = nullptr;
  HloInstruction* rhs_broadcast = nullptr;
  HloInstruction* rhs_constant = nullptr;
  const bool matched = Match(
      add, m::AddAnyOrder(
               m::AddAnyOrder(
                   &inner_add, m::Convolution(&conv, m::Op(), m::Op()),
                   m::Broadcast(&lhs_broadcast, m::Constant(&lhs_constant))),
               m::Broadcast(&rhs_broadcast, m::Constant(&rhs_constant))));
  if (!matched) {
    return false;
  }

  // The two broadcasts can only be folded into one when they broadcast
  // identically shaped constants along the same output dimensions.
  if (!ShapeUtil::Equal(lhs_constant->shape(), rhs_constant->shape()) ||
      lhs_broadcast->dimensions() != rhs_broadcast->dimensions()) {
    return false;
  }

  // add(const_0, const_1) — a later constant-folding pass collapses this.
  HloInstruction* folded_constants =
      add->AddInstruction(HloInstruction::CreateBinary(
          lhs_constant->shape(), HloOpcode::kAdd, lhs_constant, rhs_constant));
  // Re-broadcast the folded constant exactly like the original broadcasts.
  HloInstruction* folded_broadcast =
      add->AddInstruction(HloInstruction::CreateBroadcast(
          lhs_broadcast->shape(), folded_constants,
          lhs_broadcast->dimensions()));
  TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
      add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd,
                                        folded_broadcast, conv)));
  return true;
}

// GPU-specific add handling: first attempt the broadcast-sinking rewrite for
// chained adds; if it did not fire, fall back to the generic simplifier.
absl::Status GpuAlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  TF_ASSIGN_OR_RETURN(const bool sank_broadcasts,
                      TryToSinkBroadcastOperandsOfChainedAdds(add));
  return sank_broadcasts ? absl::OkStatus()
                         : AlgebraicSimplifierVisitor::HandleAdd(add);
}

bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
const HloInstruction* hlo) {
if (!options_.enable_dot_strength_reduction()) {
Expand Down
12 changes: 12 additions & 0 deletions xla/service/gpu/transforms/algebraic_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -38,9 +39,20 @@ class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
: AlgebraicSimplifierVisitor(options, simplifier),
compute_capability_(std::move(compute_capability)) {}

absl::Status HandleAdd(HloInstruction* add) override;

bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;

private:
// Try to convert add(broadcast(const_0), add(broadcast(const_1), conv(...)))
// into add(broadcast(add(const_0, const_1)), conv(...)) and return true if
// successful. The particular sink happens only when enable_sink_broadcast is
// true and the broadcast shapes and dimensions match. The sink only happens
// when following a convolution to avoid having a side input when the
// instructions are fused to cudnnConvolutionBiasActivationForward later.
absl::StatusOr<bool> TryToSinkBroadcastOperandsOfChainedAdds(
HloInstruction* add);

se::GpuComputeCapability compute_capability_;
};

Expand Down
117 changes: 117 additions & 0 deletions xla/service/gpu/transforms/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,135 @@ limitations under the License.

#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

namespace m = ::xla::match;

// Fixture for GPU-specific algebraic-simplifier tests; inherits HLO parsing
// and verification helpers from HloTestBase.
class GpuAlgebraicSimplifierTest : public HloTestBase {};

TEST_F(GpuAlgebraicSimplifierTest, SinkBroadcastOperandsOfChainedAdds) {
  // Two broadcasted constants added on top of a convolution should fold into
  // a single broadcast(add(const, const)) operand of one add.
  const std::string module_text = R"(
HloModule m
test {
in = bf16[1,3,3,1] parameter(0)
filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}},
{{{3.1}}, {{4.1}}}})
conv = bf16[1,2,2,1] convolution(in, filter),
window={size=2x2}, dim_labels=b01f_01io->b01f
const0 = bf16[2] constant({0, 0.25})
bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1}
add0 = bf16[1,2,2,1] add(conv, bcast0)
const1 = bf16[2] constant({1, 1.25})
bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={1}
ROOT add1 = bf16[1,2,2,1] add(add0, bcast1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  AlgebraicSimplifierOptions opts;
  opts.set_enable_sink_broadcast(true);
  GpuAlgebraicSimplifier simplifier(opts, se::CudaComputeCapability::Ampere());
  // The pass must report a change and rewrite the root accordingly.
  ASSERT_TRUE(simplifier.Run(module.get()).value());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::AddAnyOrder(
                  m::Broadcast(m::Add(m::Constant(), m::Constant())),
                  m::Convolution(m::Op(), m::Op()))));
}

TEST_F(GpuAlgebraicSimplifierTest,
       DoNotSinkBroadcastOperandsOfChainedAddsWhenDisabled) {
  // With enable_sink_broadcast off the chained-add pattern must be left
  // untouched and the pass must report no change.
  const std::string module_text = R"(
HloModule m
test {
in = bf16[1,3,3,1] parameter(0)
filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}},
{{{3.1}}, {{4.1}}}})
conv = bf16[1,2,2,1] convolution(in, filter),
window={size=2x2}, dim_labels=b01f_01io->b01f
const0 = bf16[2] constant({0, 0.25})
bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1}
add0 = bf16[1,2,2,1] add(conv, bcast0)
const1 = bf16[2] constant({1, 1.25})
bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={1}
ROOT add1 = bf16[1,2,2,1] add(add0, bcast1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  AlgebraicSimplifierOptions opts;
  opts.set_enable_sink_broadcast(false);
  GpuAlgebraicSimplifier simplifier(opts, se::CudaComputeCapability::Ampere());
  EXPECT_FALSE(simplifier.Run(module.get()).value());
}

TEST_F(GpuAlgebraicSimplifierTest,
       DoNotSinkBroadcastOperandsOfChainedAddsWithoutConvolution) {
  // The rewrite is restricted to add chains rooted at a convolution; a plain
  // parameter in place of the conv must not trigger it.
  const std::string module_text = R"(
HloModule m
test {
p = bf16[4, 4] parameter(0)
const0 = bf16[4] constant({0, 0.25, 0.5, 0.75})
bcast0 = bf16[4,4] broadcast(const0), dimensions={0}
add0 = bf16[4,4] add(p, bcast0)
const1 = bf16[4] constant({1, 1.25, 1.5, 1.75})
bcast1 = bf16[4,4] broadcast(const1), dimensions={0}
ROOT add1 = bf16[4,4] add(add0, bcast1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  AlgebraicSimplifierOptions opts;
  opts.set_enable_sink_broadcast(true);
  GpuAlgebraicSimplifier simplifier(opts, se::CudaComputeCapability::Ampere());
  EXPECT_FALSE(simplifier.Run(module.get()).value());
}

TEST_F(
    GpuAlgebraicSimplifierTest,
    DoNotSinkBroadcastOperandsOfChainedAddsWithMismatchedBroadcastDimensions) {
  // The two broadcasts expand along different dimensions ({1} vs {2}), so
  // folding their constants would change semantics; the pass must decline.
  const std::string module_text = R"(
HloModule m
test {
in = bf16[1,3,3,1] parameter(0)
filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}},
{{{3.1}}, {{4.1}}}})
conv = bf16[1,2,2,1] convolution(in, filter),
window={size=2x2}, dim_labels=b01f_01io->b01f
const0 = bf16[2] constant({0, 0.25})
bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1}
add0 = bf16[1,2,2,1] add(conv, bcast0)
const1 = bf16[2] constant({1, 1.25})
bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={2}
ROOT add1 = bf16[1,2,2,1] add(add0, bcast1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  AlgebraicSimplifierOptions opts;
  opts.set_enable_sink_broadcast(true);
  GpuAlgebraicSimplifier simplifier(opts, se::CudaComputeCapability::Ampere());
  EXPECT_FALSE(simplifier.Run(module.get()).value());
}

TEST_F(GpuAlgebraicSimplifierTest, VectorVectorDotShouldBeStrengthReduced) {
const std::string& hlo_string = R"(
HloModule m
Expand Down

0 comments on commit 1ffcdaa

Please sign in to comment.