Skip to content

Commit

Permalink
[OpenMP] [amdgpu] For Big-Jump-Loop kernels, use the tripcount for
Browse files Browse the repository at this point in the history
computing the number of teams.

Change-Id: Iaaea3a88cb5d016228a66bb5480a9f2f26334b16
  • Loading branch information
dhruvachak authored and ronlieb committed Sep 5, 2024
1 parent dcf5a00 commit fa255a2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
35 changes: 27 additions & 8 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,13 +1101,22 @@ struct AMDGPUKernelTy : public GenericKernelTy {
std::min(static_cast<uint64_t>(NumTeamsClause[0]), NumGroups);
} else {
// num_teams clause is not specified. Choose lower of tripcount-based
// num-groups and a value that maximizes occupancy. At this point, aim
// to have 16 wavefronts in a CU. Allow for override with envar.
uint64_t MaxOccupancyFactor =
GenericDevice.getOMPXBigJumpLoopTeamsPerCU() > 0
? GenericDevice.getOMPXBigJumpLoopTeamsPerCU()
: 16 / NumWavesInGroup;
NumGroups = std::min(NumGroups, MaxOccupancyFactor * DeviceNumCUs);
// NumGroups and a value determined as follows:
// - If the number of teams per CU is specified by the user with the
// envar LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_TEAMS_PER_CU, compute
// NumGroups from that specified value. This envar is OFF by default.
// - Otherwise, use the max total teams specified with the envar
//   LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_MAX_TOTAL_TEAMS.
// This envar is used by default with 1M as the default value.
if (GenericDevice.getOMPXBigJumpLoopTeamsPerCU() > 0) {
NumGroups =
std::min(NumGroups, GenericDevice.getOMPXBigJumpLoopTeamsPerCU() *
DeviceNumCUs);
} else {
NumGroups = std::min(
NumGroups, static_cast<uint64_t>(
GenericDevice.getOMPXBigJumpLoopMaxTotalTeams()));
}

// If the user specifies a number of teams for low trip count loops,
// honor it.
Expand All @@ -1117,7 +1126,8 @@ struct AMDGPUKernelTy : public GenericKernelTy {
NumGroups = LowTripCountBlocks;
}
}
return NumGroups;
return std::min(NumGroups,
static_cast<uint64_t>(GenericDevice.getBlockLimit()));
}

if (isXTeamReductionsMode()) {
Expand Down Expand Up @@ -2744,6 +2754,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
"LIBOMPTARGET_AMDGPU_GENERIC_SPMD_TEAMS_PER_CU", 0),
OMPX_BigJumpLoopTeamsPerCU(
"LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_TEAMS_PER_CU", 0),
OMPX_BigJumpLoopMaxTotalTeams(
"LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_MAX_TOTAL_TEAMS", 1024 * 1024),
OMPX_LowTripCount("LIBOMPTARGET_AMDGPU_LOW_TRIPCOUNT", 2000),
OMPX_SmallBlockSize("LIBOMPTARGET_MIN_THREADS_FOR_LOW_TRIP_COUNT", 8),
OMPX_NumBlocksForLowTripcount("LIBOMPTARGET_BLOCKS_FOR_LOW_TRIP_COUNT",
Expand Down Expand Up @@ -2802,6 +2814,9 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
/// Returns the current value of the OMPX_BigJumpLoopTeamsPerCU envar as a
/// plain 32-bit integer. Callers treat a value greater than zero as a
/// user-requested teams-per-CU setting for Big-Jump-Loop kernels.
uint32_t getOMPXBigJumpLoopTeamsPerCU() const override {
  // Spell out the envar-to-integer conversion rather than relying on the
  // implicit one at the return statement.
  return static_cast<uint32_t>(OMPX_BigJumpLoopTeamsPerCU);
}
/// Returns the current value of the OMPX_BigJumpLoopMaxTotalTeams envar as a
/// plain 32-bit integer: the upper bound on the total number of teams used
/// for Big-Jump-Loop kernels when no teams-per-CU value is requested.
uint32_t getOMPXBigJumpLoopMaxTotalTeams() const override {
  // Spell out the envar-to-integer conversion rather than relying on the
  // implicit one at the return statement.
  return static_cast<uint32_t>(OMPX_BigJumpLoopMaxTotalTeams);
}
/// Returns the current value of the OMPX_LowTripCount envar as a plain 32-bit
/// integer: the tripcount below which the blocksize should be adjusted.
uint32_t getOMPXLowTripCount() const override {
  // Spell out the envar-to-integer conversion rather than relying on the
  // implicit one at the return statement.
  return static_cast<uint32_t>(OMPX_LowTripCount);
}
Expand Down Expand Up @@ -4246,6 +4261,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
/// OMPX_BigJumpLoopTeamsPerCU * #CUs.
UInt32Envar OMPX_BigJumpLoopTeamsPerCU;

/// Envar controlling the maximum number of teams per device for
/// Big-Jump-Loop kernels.
UInt32Envar OMPX_BigJumpLoopMaxTotalTeams;

/// Envar specifying tripcount below which the blocksize should be adjusted.
UInt32Envar OMPX_LowTripCount;

Expand Down
3 changes: 3 additions & 0 deletions offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
/// Query hook for the teams-per-CU setting used for Big-Jump-Loop kernels.
/// This generic version has no value to report: reaching it hits
/// llvm_unreachable, so a device type that is queried for this setting is
/// expected to provide its own override.
virtual uint32_t getOMPXBigJumpLoopTeamsPerCU() const {
llvm_unreachable("Unimplemented");
}
/// Query hook for the maximum total number of teams used for Big-Jump-Loop
/// kernels. This generic version has no value to report: reaching it hits
/// llvm_unreachable, so a device type that is queried for this setting is
/// expected to provide its own override.
virtual uint32_t getOMPXBigJumpLoopMaxTotalTeams() const {
llvm_unreachable("Unimplemented");
}
/// Query hook for the low-tripcount threshold. This generic version has no
/// value to report: reaching it hits llvm_unreachable, so a device type that
/// is queried for this setting is expected to provide its own override.
virtual uint32_t getOMPXLowTripCount() const {
llvm_unreachable("Unimplemented");
}
Expand Down

0 comments on commit fa255a2

Please sign in to comment.