diff --git a/xla/hlo/utils/hlo_query.cc b/xla/hlo/utils/hlo_query.cc
index 90b6ddfd4d2b2a..1fd417a8a2cddf 100644
--- a/xla/hlo/utils/hlo_query.cc
+++ b/xla/hlo/utils/hlo_query.cc
@@ -178,6 +178,13 @@ bool IsBroadcastOfParameter(const HloInstruction& instr) {
          instr.operand(0)->opcode() == HloOpcode::kParameter;
 }
 
+bool IsEffectiveParameter(const HloInstruction& instr) {
+  return instr.opcode() == HloOpcode::kParameter ||
+         ((instr.opcode() == HloOpcode::kBitcast ||
+           instr.opcode() == HloOpcode::kGetTupleElement) &&
+          IsEffectiveParameter(*instr.operand(0)));
+}
+
 HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation,
                                               const HloOpcode opcode) {
   auto instructions = computation.instructions();
diff --git a/xla/hlo/utils/hlo_query.h b/xla/hlo/utils/hlo_query.h
index f219594024dc7e..3612baf0803d5e 100644
--- a/xla/hlo/utils/hlo_query.h
+++ b/xla/hlo/utils/hlo_query.h
@@ -79,6 +79,10 @@ bool IsBroadcastOfScalarConstant(const HloInstruction& instr);
 // Returns whether the `instr` is a broadcast and its input is a parameter.
 bool IsBroadcastOfParameter(const HloInstruction& instr);
 
+// Returns true for a parameter or a parameter followed by a chain of no-op
+// instructions (bitcast, get-tuple-element).
+bool IsEffectiveParameter(const HloInstruction&);
+
 // Returns first HLO of the computation with the opcode, otherwise nullptr.
 HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation,
                                               HloOpcode opcode);
diff --git a/xla/service/gpu/transforms/gemm_fusion_test.cc b/xla/service/gpu/transforms/gemm_fusion_test.cc
index ebda001f2a7ba1..509cc8d76b320b 100644
--- a/xla/service/gpu/transforms/gemm_fusion_test.cc
+++ b/xla/service/gpu/transforms/gemm_fusion_test.cc
@@ -1238,6 +1238,25 @@ ENTRY e {
                             m::Parameter(), m::Parameter()))));
 }
 
+TEST_F(GemmFusionTest, BroadcastsOfParametersAreFusedAsEpilogueInputs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+e {
+  p0 = f16[4,55] parameter(0)
+  p1 = f16[123,55] parameter(1)
+  d = f16[4,123] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+  p2 = (f16[123,1], f16[456]) parameter(2)
+  g = get-tuple-element(p2), index=0
+  t = f16[123] bitcast(g)
+  b = f16[4,123] broadcast(t), dimensions={1}
+  m = f16[4,123] multiply(d, b)
+})")
+                    .value();
+  EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value());
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch((m::Fusion(m::Parameter(), m::Parameter(),
+                                    m::GetTupleElement()))));
+}
+
 // A test fixture class for testing the threshold for small matrices.
 class SmallDotGemmFusionTest : public GemmFusionTest {
  public:
diff --git a/xla/service/gpu/triton_tiling_propagation.cc b/xla/service/gpu/triton_tiling_propagation.cc
index 1d5230ee66855e..86ed1f67679790 100644
--- a/xla/service/gpu/triton_tiling_propagation.cc
+++ b/xla/service/gpu/triton_tiling_propagation.cc
@@ -1088,11 +1088,16 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
       if (i == *src_operand_index) {
         continue;
       }
-      // Currently only broadcasts of scalars or parameters are accepted as
-      // other inputs of non-unary operations in the output fusion.
+      // Currently only
+      // - effective parameters
+      // - broadcasts of effective parameters
+      // - broadcasts of scalars
+      // are accepted as other inputs of non-unary operations in
+      // the output fusion.
       if ((operand->opcode() == HloOpcode::kBroadcast &&
-           ShapeUtil::IsScalar(operand->operand(0)->shape())) ||
-          operand->opcode() == HloOpcode::kParameter) {
+           (ShapeUtil::IsScalar(operand->operand(0)->shape()) ||
+            hlo_query::IsEffectiveParameter(*operand->operand(0)))) ||
+          hlo_query::IsEffectiveParameter(*operand)) {
         continue;
       }
       return FusionDecision::Forbid(
diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc
index a66a9920b59cee..8c7ab4a53c3a1a 100644
--- a/xla/stream_executor/cuda/cuda_driver.cc
+++ b/xla/stream_executor/cuda/cuda_driver.cc
@@ -57,14 +57,6 @@ limitations under the License.
 namespace stream_executor {
 namespace gpu {
 
-absl::Status GpuDriver::CreateGraph(CUgraph* graph) {
-  VLOG(2) << "Create new CUDA graph";
-  TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0),
-                                    "Failed to create CUDA graph"));
-  VLOG(2) << "Created CUDA graph " << *graph;
-  return absl::OkStatus();
-}
-
 int GpuDriver::GetDeviceCount() {
   int device_count = 0;
   auto status = cuda::ToStatus(cuDeviceGetCount(&device_count));
diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h
index 79aed4c74b5059..0f7058ba445cb6 100644
--- a/xla/stream_executor/gpu/gpu_driver.h
+++ b/xla/stream_executor/gpu/gpu_driver.h
@@ -57,11 +57,6 @@ namespace gpu {
 // Thread safety: these functions should not be used from signal handlers.
 class GpuDriver {
  public:
-  // Creates a new GPU graph.
-  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gd885f719186010727b75c3315f865fdf
-  // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
-  static absl::Status CreateGraph(GpuGraphHandle* graph);
-
   // The CUDA stream callback type signature.
   // The data passed to AddStreamCallback is subsequently passed to this
   // callback when it fires.
diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc
index 3bf8e7da5505cb..b0138c4c79bc53 100644
--- a/xla/stream_executor/rocm/rocm_driver.cc
+++ b/xla/stream_executor/rocm/rocm_driver.cc
@@ -54,14 +54,6 @@ limitations under the License.
 
 namespace stream_executor::gpu {
 
-absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) {
-  VLOG(2) << "Create new HIP graph";
-  TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphCreate(graph, /*flags=*/0),
-                              "Failed to create HIP graph"));
-  VLOG(2) << "Created HIP graph " << *graph;
-  return absl::OkStatus();
-}
-
 int GpuDriver::GetDeviceCount() {
   int device_count = 0;
   hipError_t res = wrap::hipGetDeviceCount(&device_count);