Skip to content

Commit

Permalink
[GPU] GEMM fusions: let fusing effective parameters and their broadca…
Browse files Browse the repository at this point in the history
…sts in the epilogues.
  • Loading branch information
sergachev committed Nov 15, 2024
1 parent 7dda04e commit 37dc0d2
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
7 changes: 7 additions & 0 deletions xla/hlo/utils/hlo_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ bool IsBroadcastOfParameter(const HloInstruction& instr) {
instr.operand(0)->opcode() == HloOpcode::kParameter;
}

bool IsEffectiveParameter(const HloInstruction& instr) {
return instr.opcode() == HloOpcode::kParameter ||
((instr.opcode() == HloOpcode::kBitcast ||
instr.opcode() == HloOpcode::kGetTupleElement) &&
IsEffectiveParameter(*instr.operand(0)));
}

HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation,
const HloOpcode opcode) {
auto instructions = computation.instructions();
Expand Down
4 changes: 4 additions & 0 deletions xla/hlo/utils/hlo_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ bool IsBroadcastOfScalarConstant(const HloInstruction& instr);
// Returns whether the `instr` is a broadcast and its input is a parameter.
bool IsBroadcastOfParameter(const HloInstruction& instr);

// Returns true for a parameter or a parameter followed by a chain of no-op
// instructions (bitcast, get-tuple-element).
bool IsEffectiveParameter(const HloInstruction&);

// Returns first HLO of the computation with the opcode, otherwise nullptr.
HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation,
HloOpcode opcode);
Expand Down
19 changes: 19 additions & 0 deletions xla/service/gpu/transforms/gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,25 @@ ENTRY e {
m::Parameter(), m::Parameter()))));
}

TEST_F(GemmFusionTest, BroadcastsOfParametersAreFusedAsEpilogueInputs) {
auto module = ParseAndReturnVerifiedModule(R"(
e {
p0 = f16[4,55] parameter(0)
p1 = f16[123,55] parameter(1)
d = f16[4,123] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
p2 = (f16[123,1], f16[456]) parameter(2)
g = get-tuple-element(p2), index=0
t = f16[123] bitcast(g)
b = f16[4,123] broadcast(t), dimensions={1}
m = f16[4,123] multiply(d, b)
})")
.value();
EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
m::GetTupleElement()))));
}

// A test fixture class for testing the threshold for small matrices.
class SmallDotGemmFusionTest : public GemmFusionTest {
public:
Expand Down
13 changes: 9 additions & 4 deletions xla/service/gpu/triton_tiling_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1091,11 +1091,16 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
if (i == *src_operand_index) {
continue;
}
// Currently only broadcasts of scalars or parameters are accepted as
// other inputs of non-unary operations in the output fusion.
// Currently only
// - effective parameters
// - broadcasts of effective parameters
// - broadcasts of scalars
// are accepted as other inputs of non-unary operations in
// the output fusion.
if ((operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(operand->operand(0)->shape())) ||
operand->opcode() == HloOpcode::kParameter) {
(ShapeUtil::IsScalar(operand->operand(0)->shape()) ||
hlo_query::IsEffectiveParameter(*operand->operand(0)))) ||
hlo_query::IsEffectiveParameter(*operand)) {
continue;
}
return FusionDecision::Forbid(
Expand Down

0 comments on commit 37dc0d2

Please sign in to comment.