diff --git a/source/adapters/cuda/enqueue.cpp b/source/adapters/cuda/enqueue.cpp index e7ee2bf523..755ddb0000 100644 --- a/source/adapters/cuda/enqueue.cpp +++ b/source/adapters/cuda/enqueue.cpp @@ -164,7 +164,7 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock, MaxThreadsPerBlock[0])); roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized, - MaxBlockDim, MaxBlockSize); + MaxBlockDim, MaxBlockSize, WorkDim); } // Helper to verify out-of-registers case (exceeded block max registers). diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index d3abe7e7c4..4db76cc158 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -74,7 +74,8 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock, MaxBlockDim[2] = Device->getMaxBlockDimZ(); roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized, - MaxBlockDim, MaxThreadsPerBlock[0]); + MaxBlockDim, MaxThreadsPerBlock[0], + WorkDim); } namespace { diff --git a/source/ur/ur.hpp b/source/ur/ur.hpp index 303a011796..cf84bb437d 100644 --- a/source/ur/ur.hpp +++ b/source/ur/ur.hpp @@ -347,23 +347,32 @@ template inline bool isPowerOf2(const T &Value) { // dims == 1) // In: MaxBlockDim - The max size of block in 3d // In: MaxBlockSize - The max total size of block in all dimensions +// In: WorkDim - The number of work dimensions in use (1, 2 or 3) static inline void roundToHighestFactorOfGlobalSizeIn3d( size_t *ThreadsPerBlock, const size_t *GlobalSize, - const size_t *MaxBlockDim, const size_t MaxBlockSize) { + const size_t *MaxBlockDim, const size_t MaxBlockSize, + const size_t WorkDim) { ThreadsPerBlock[0] = std::min(GlobalSize[0], MaxBlockDim[0]); // Make the X dim a factor of 2 do { roundToHighestFactorOfGlobalSize(ThreadsPerBlock[0], GlobalSize[0]); - } while (!isPowerOf2(ThreadsPerBlock[0]) && ThreadsPerBlock[0] > 32 && - --ThreadsPerBlock[0]); + } while (WorkDim == 3 && !isPowerOf2(ThreadsPerBlock[0]) && + 
ThreadsPerBlock[0] > 32 && --ThreadsPerBlock[0]); ThreadsPerBlock[1] = std::min(GlobalSize[1], std::min(MaxBlockSize / ThreadsPerBlock[0], MaxBlockDim[1])); - roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalSize[1]); + do { + roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalSize[1]); + } while (WorkDim == 2 && !isPowerOf2(ThreadsPerBlock[1]) && + ThreadsPerBlock[1] > 32 && --ThreadsPerBlock[1]); ThreadsPerBlock[2] = std::min( - GlobalSize[2], MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[0])); - roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalSize[2]); - + GlobalSize[2], + std::min(MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[0]), + MaxBlockDim[2])); + do { + roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalSize[2]); + } while (WorkDim == 1 && !isPowerOf2(ThreadsPerBlock[2]) && + ThreadsPerBlock[2] > 32 && --ThreadsPerBlock[2]); }