Skip to content

Commit

Permalink
PR #18825: [GPU] GEMM fusions: let fusing effective parameters and th…
Browse files Browse the repository at this point in the history
…eir broadcasts in the epilogues.

Imported from GitHub PR #18825

Copybara import of the project:

--
2c9840f by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] GEMM fusions: allow fusing effective parameters and their broadcasts in the epilogues.

Merging this change closes #18825

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18825 from openxla:gemm_fusion_effective_parameters 2c9840f
PiperOrigin-RevId: 691406714
  • Loading branch information
sergachev authored and Google-ML-Automation committed Oct 30, 2024
1 parent 3254938 commit 767d251
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 25 deletions.
7 changes: 7 additions & 0 deletions xla/hlo/utils/hlo_query.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ bool IsBroadcastOfParameter(const HloInstruction& instr) {
instr.operand(0)->opcode() == HloOpcode::kParameter;
}

bool IsEffectiveParameter(const HloInstruction& instr) {
return instr.opcode() == HloOpcode::kParameter ||
((instr.opcode() == HloOpcode::kBitcast ||
instr.opcode() == HloOpcode::kGetTupleElement) &&
IsEffectiveParameter(*instr.operand(0)));
}

HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation,
const HloOpcode opcode) {
auto instructions = computation.instructions();
Expand Down
4 changes: 4 additions & 0 deletions xla/hlo/utils/hlo_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ bool IsBroadcastOfScalarConstant(const HloInstruction& instr);
// Returns whether the `instr` is a broadcast and its input is a parameter.
bool IsBroadcastOfParameter(const HloInstruction& instr);

// Returns true for a parameter or a parameter followed by a chain of no-op
// instructions (bitcast, get-tuple-element).
bool IsEffectiveParameter(const HloInstruction&);

// Returns first HLO of the computation with the opcode, otherwise nullptr.
HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation,
HloOpcode opcode);
Expand Down
19 changes: 19 additions & 0 deletions xla/service/gpu/transforms/gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,25 @@ ENTRY e {
m::Parameter(), m::Parameter()))));
}

// Checks that a broadcast fed by a parameter through get-tuple-element and
// bitcast is accepted as an extra input of the fused dot's epilogue.
TEST_F(GemmFusionTest, BroadcastsOfParametersAreFusedAsEpilogueInputs) {
  const char* const kHloText = R"(
e {
p0 = f16[4,55] parameter(0)
p1 = f16[123,55] parameter(1)
d = f16[4,123] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
p2 = (f16[123,1], f16[456]) parameter(2)
g = get-tuple-element(p2), index=0
t = f16[123] bitcast(g)
b = f16[4,123] broadcast(t), dimensions={1}
m = f16[4,123] multiply(d, b)
})";
  auto verified_module = ParseAndReturnVerifiedModule(kHloText).value();

  // The pass is expected to report that it changed the module.
  const bool was_changed =
      GemmFusion(gpu_version_).Run(verified_module.get()).value();
  EXPECT_TRUE(was_changed);

  // The root must be a fusion whose operands are the two dot parameters and
  // the get-tuple-element of the tuple parameter.
  auto* root = verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter(),
                                         m::GetTupleElement())));
}

// A test fixture class for testing the threshold for small matrices.
class SmallDotGemmFusionTest : public GemmFusionTest {
public:
Expand Down
13 changes: 9 additions & 4 deletions xla/service/gpu/triton_tiling_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1088,11 +1088,16 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
if (i == *src_operand_index) {
continue;
}
// Currently only broadcasts of scalars or parameters are accepted as
// other inputs of non-unary operations in the output fusion.
// Currently only
// - effective parameters
// - broadcasts of effective parameters
// - broadcasts of scalars
// are accepted as other inputs of non-unary operations in
// the output fusion.
if ((operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(operand->operand(0)->shape())) ||
operand->opcode() == HloOpcode::kParameter) {
(ShapeUtil::IsScalar(operand->operand(0)->shape()) ||
hlo_query::IsEffectiveParameter(*operand->operand(0)))) ||
hlo_query::IsEffectiveParameter(*operand)) {
continue;
}
return FusionDecision::Forbid(
Expand Down
8 changes: 0 additions & 8 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,6 @@ limitations under the License.
namespace stream_executor {
namespace gpu {

// Creates a new CUDA graph by calling cuGraphCreate with flags set to 0 and
// `graph` as the output handle.
//
// Returns absl::OkStatus() when the macro below does not return early.
// NOTE(review): TF_RETURN_IF_ERROR presumably returns the non-OK status
// produced by cuda::ToStatus (built with the message "Failed to create CUDA
// graph") -- confirm against the macro definition.
absl::Status GpuDriver::CreateGraph(CUgraph* graph) {
VLOG(2) << "Create new CUDA graph";
TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0),
"Failed to create CUDA graph"));
// Only reached if the macro above did not return; logs the handle value.
VLOG(2) << "Created CUDA graph " << *graph;
return absl::OkStatus();
}

int GpuDriver::GetDeviceCount() {
int device_count = 0;
auto status = cuda::ToStatus(cuDeviceGetCount(&device_count));
Expand Down
5 changes: 0 additions & 5 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ namespace gpu {
// Thread safety: these functions should not be used from signal handlers.
class GpuDriver {
public:
// Creates a new GPU graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gd885f719186010727b75c3315f865fdf
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static absl::Status CreateGraph(GpuGraphHandle* graph);

// The CUDA stream callback type signature.
// The data passed to AddStreamCallback is subsequently passed to this
// callback when it fires.
Expand Down
8 changes: 0 additions & 8 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ limitations under the License.

namespace stream_executor::gpu {

// Creates a new HIP graph by calling wrap::hipGraphCreate with flags set to 0
// and `graph` as the output handle.
//
// Returns absl::OkStatus() when the macro below does not return early.
// NOTE(review): TF_RETURN_IF_ERROR presumably returns the non-OK status
// produced by ToStatus (built with the message "Failed to create HIP graph")
// -- confirm against the macro definition.
absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) {
VLOG(2) << "Create new HIP graph";
TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphCreate(graph, /*flags=*/0),
"Failed to create HIP graph"));
// Only reached if the macro above did not return; logs the handle value.
VLOG(2) << "Created HIP graph " << *graph;
return absl::OkStatus();
}

int GpuDriver::GetDeviceCount() {
int device_count = 0;
hipError_t res = wrap::hipGetDeviceCount(&device_count);
Expand Down

0 comments on commit 767d251

Please sign in to comment.