diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index c4b1b86045..4cc095f00e 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -15,6 +15,8 @@ #include "memory.hpp" #include "queue.hpp" +#include + extern size_t imageElementByteSize(hipArray_Format ArrayFormat); namespace { @@ -49,23 +51,36 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream, } } -void simpleGuessLocalWorkSize(size_t *ThreadsPerBlock, - const size_t *GlobalWorkSize, - const size_t MaxThreadsPerBlock[3], - ur_kernel_handle_t Kernel) { +// Determine local work sizes that result in uniform work groups. +// The default threadsPerBlock only require handling the first work_dim +// dimension. +void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock, + const size_t *GlobalWorkSize, const uint32_t WorkDim, + const size_t MaxThreadsPerBlock[3], + ur_kernel_handle_t Kernel) { assert(ThreadsPerBlock != nullptr); assert(GlobalWorkSize != nullptr); assert(Kernel != nullptr); - std::ignore = Kernel; + // FIXME: The below assumes a three dimensional range but this is not + // guaranteed by UR. + size_t GlobalSizeNormalized[3] = {1, 1, 1}; + for (uint32_t i = 0; i < WorkDim; i++) { + GlobalSizeNormalized[i] = GlobalWorkSize[i]; + } + + size_t MaxBlockDim[3]; + MaxBlockDim[0] = MaxThreadsPerBlock[0]; + MaxBlockDim[1] = Device->getMaxBlockDimY(); + MaxBlockDim[2] = Device->getMaxBlockDimZ(); - ThreadsPerBlock[0] = std::min(MaxThreadsPerBlock[0], GlobalWorkSize[0]); + int MinGrid, MaxBlockSize; + UR_CHECK_ERROR(hipOccupancyMaxPotentialBlockSize( + &MinGrid, &MaxBlockSize, Kernel->get(), Kernel->getLocalSize(), + MaxThreadsPerBlock[0])); - // Find a local work group size that is a divisor of the global - // work group size to produce uniform work groups. - while (GlobalWorkSize[0] % ThreadsPerBlock[0]) { - --ThreadsPerBlock[0]; - } + roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized, + MaxBlockDim, MaxBlockSize); } ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size, @@ -344,8 +359,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( return err; } } else { - simpleGuessLocalWorkSize(ThreadsPerBlock, pGlobalWorkSize, - MaxThreadsPerBlock, hKernel); + guessLocalWorkSize(hQueue->getDevice(), ThreadsPerBlock, pGlobalWorkSize, + workDim, MaxThreadsPerBlock, hKernel); } }