diff --git a/source/adapters/cuda/enqueue.cpp b/source/adapters/cuda/enqueue.cpp index 5046f4c865..b21da2ce97 100644 --- a/source/adapters/cuda/enqueue.cpp +++ b/source/adapters/cuda/enqueue.cpp @@ -161,26 +161,32 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock, cuOccupancyMaxPotentialBlockSize(&MinGrid, &MaxBlockSize, Kernel->get(), NULL, LocalSize, MaxThreadsPerBlock[0])); + // Helper lambda to make sure each x, y, z dim divide the global dimension. + // Can optionally specify that we want the wg size to be a power of 2 in a + // given dimension, which is useful for the X dim for performance reasons. + static auto roundToHighestFactorOfGlobalSize = + [](size_t &ThreadsPerBlockInDim, const size_t GlobalWorkSizeInDim, + bool MakePowerOfTwo) { + auto IsPowerOf2 = [](size_t Value) -> bool { + return Value && !(Value & (Value - 1)); + }; + while (GlobalWorkSizeInDim % ThreadsPerBlockInDim || + (MakePowerOfTwo && !IsPowerOf2(ThreadsPerBlockInDim))) + --ThreadsPerBlockInDim; + }; + ThreadsPerBlock[2] = std::min(GlobalSizeNormalized[2], MaxBlockDim[2]); + roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalWorkSize[2], + false); ThreadsPerBlock[1] = std::min(GlobalSizeNormalized[1], std::min(MaxBlockSize / ThreadsPerBlock[2], MaxBlockDim[1])); + roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalWorkSize[1], + false); MaxBlockDim[0] = MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[2]); ThreadsPerBlock[0] = std::min( MaxThreadsPerBlock[0], std::min(GlobalSizeNormalized[0], MaxBlockDim[0])); - - static auto IsPowerOf2 = [](size_t Value) -> bool { - return Value && !(Value & (Value - 1)); - }; - - // Find a local work group size that is a divisor of the global - // work group size to produce uniform work groups. - // Additionally, for best compute utilisation, the local size has - // to be a power of two. - while (0u != (GlobalSizeNormalized[0] % ThreadsPerBlock[0]) || - !IsPowerOf2(ThreadsPerBlock[0])) { - --ThreadsPerBlock[0]; - } + roundToHighestFactorOfGlobalSize(ThreadsPerBlock[0], GlobalWorkSize[0], true); } // Helper to verify out-of-registers case (exceeded block max registers).