From fa255a223d72f81d43d0350fcfb7c5a58c27499d Mon Sep 17 00:00:00 2001 From: Dhruva Chakrabarti Date: Thu, 5 Sep 2024 01:28:39 -0500 Subject: [PATCH] [OpenMP] [amdgpu] For Big-Jump-Loop kernels, use the tripcount for computing the number of teams. Change-Id: Iaaea3a88cb5d016228a66bb5480a9f2f26334b16 --- offload/plugins-nextgen/amdgpu/src/rtl.cpp | 35 ++++++++++++++----- .../common/include/PluginInterface.h | 3 ++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index 7beb066b5854ae..b22f21164a6eca 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -1101,13 +1101,22 @@ struct AMDGPUKernelTy : public GenericKernelTy { std::min(static_cast(NumTeamsClause[0]), NumGroups); } else { // num_teams clause is not specified. Choose lower of tripcount-based - // num-groups and a value that maximizes occupancy. At this point, aim - // to have 16 wavefronts in a CU. Allow for override with envar. - uint64_t MaxOccupancyFactor = - GenericDevice.getOMPXBigJumpLoopTeamsPerCU() > 0 - ? GenericDevice.getOMPXBigJumpLoopTeamsPerCU() - : 16 / NumWavesInGroup; - NumGroups = std::min(NumGroups, MaxOccupancyFactor * DeviceNumCUs); + // NumGroups and a value determined as follows: + // - If the number of teams per CU is specified by the user with the + // envar LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_TEAMS_PER_CU, compute + // NumGroups from that specified value. This envar is OFF by default. + // - Otherwise, use the max total teams specified with the envar + /// LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_MAX_TOTAL_TEAMS. + // This envar is used by default with 1M as the default value. + if (GenericDevice.getOMPXBigJumpLoopTeamsPerCU() > 0) { + NumGroups = + std::min(NumGroups, GenericDevice.getOMPXBigJumpLoopTeamsPerCU() * + DeviceNumCUs); + } else { + NumGroups = std::min( + NumGroups, static_cast( + GenericDevice.getOMPXBigJumpLoopMaxTotalTeams())); + } // If the user specifies a number of teams for low trip count loops, // honor it. @@ -1117,7 +1126,8 @@ struct AMDGPUKernelTy : public GenericKernelTy { NumGroups = LowTripCountBlocks; } } - return NumGroups; + return std::min(NumGroups, + static_cast(GenericDevice.getBlockLimit())); } if (isXTeamReductionsMode()) { @@ -2744,6 +2754,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { "LIBOMPTARGET_AMDGPU_GENERIC_SPMD_TEAMS_PER_CU", 0), OMPX_BigJumpLoopTeamsPerCU( "LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_TEAMS_PER_CU", 0), + OMPX_BigJumpLoopMaxTotalTeams( + "LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_MAX_TOTAL_TEAMS", 1024 * 1024), OMPX_LowTripCount("LIBOMPTARGET_AMDGPU_LOW_TRIPCOUNT", 2000), OMPX_SmallBlockSize("LIBOMPTARGET_MIN_THREADS_FOR_LOW_TRIP_COUNT", 8), OMPX_NumBlocksForLowTripcount("LIBOMPTARGET_BLOCKS_FOR_LOW_TRIP_COUNT", @@ -2802,6 +2814,9 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { virtual uint32_t getOMPXBigJumpLoopTeamsPerCU() const override { return OMPX_BigJumpLoopTeamsPerCU; } + virtual uint32_t getOMPXBigJumpLoopMaxTotalTeams() const override { + return OMPX_BigJumpLoopMaxTotalTeams; + } virtual uint32_t getOMPXLowTripCount() const override { return OMPX_LowTripCount; } @@ -4246,6 +4261,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { /// OMPX_BigJumpLoopTeamsPerCU * #CUs. UInt32Envar OMPX_BigJumpLoopTeamsPerCU; + /// Envar controlling the maximum number of teams per device for + /// Big-Jump-Loop kernels. + UInt32Envar OMPX_BigJumpLoopMaxTotalTeams; + /// Envar specifying tripcount below which the blocksize should be adjusted. UInt32Envar OMPX_LowTripCount; diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index 11b704fd8084f1..dc9959868e6a52 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -967,6 +967,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy { virtual uint32_t getOMPXBigJumpLoopTeamsPerCU() const { llvm_unreachable("Unimplemented"); } + virtual uint32_t getOMPXBigJumpLoopMaxTotalTeams() const { + llvm_unreachable("Unimplemented"); + } virtual uint32_t getOMPXLowTripCount() const { llvm_unreachable("Unimplemented"); }