diff --git a/source/ur/ur.hpp b/source/ur/ur.hpp index cf84bb437d..f60e7eff70 100644 --- a/source/ur/ur.hpp +++ b/source/ur/ur.hpp @@ -352,20 +352,29 @@ static inline void roundToHighestFactorOfGlobalSizeIn3d( size_t *ThreadsPerBlock, const size_t *GlobalSize, const size_t *MaxBlockDim, const size_t MaxBlockSize, const size_t WorkDim) { - ThreadsPerBlock[0] = std::min(GlobalSize[0], MaxBlockDim[0]); - // Make the X dim a factor of 2 + + auto findLowerFactor = [](size_t &ThreadsPerBlock) { + return !isPowerOf2(ThreadsPerBlock) && ThreadsPerBlock > 32 && + --ThreadsPerBlock; + }; + assert(GlobalSize[0] && "GlobalSize[0] cannot be zero"); + assert(GlobalSize[1] && "GlobalSize[1] cannot be zero"); + assert(GlobalSize[2] && "GlobalSize[2] cannot be zero"); + + ThreadsPerBlock[0] = + std::min(GlobalSize[0], std::min(MaxBlockSize, MaxBlockDim[0])); do { roundToHighestFactorOfGlobalSize(ThreadsPerBlock[0], GlobalSize[0]); - } while (WorkDim == 3 && !isPowerOf2(ThreadsPerBlock[0]) && - ThreadsPerBlock[0] > 32 && --ThreadsPerBlock[0]); + } while (WorkDim == 3 && findLowerFactor(ThreadsPerBlock[0])); + assert(ThreadsPerBlock[0] && "ThreadsPerBlock[0] cannot be zero"); ThreadsPerBlock[1] = std::min(GlobalSize[1], std::min(MaxBlockSize / ThreadsPerBlock[0], MaxBlockDim[1])); do { roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalSize[1]); - } while (WorkDim == 2 && !isPowerOf2(ThreadsPerBlock[1]) && - ThreadsPerBlock[1] > 32 && --ThreadsPerBlock[1]); + } while (WorkDim == 2 && findLowerFactor(ThreadsPerBlock[1])); + assert(ThreadsPerBlock[1] && "ThreadsPerBlock[1] cannot be zero"); ThreadsPerBlock[2] = std::min( GlobalSize[2], @@ -373,6 +382,6 @@ static inline void roundToHighestFactorOfGlobalSizeIn3d( MaxBlockDim[2])); do { roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalSize[2]); - } while (WorkDim == 1 && !isPowerOf2(ThreadsPerBlock[2]) && - ThreadsPerBlock[2] > 32 && --ThreadsPerBlock[2]); + } while (WorkDim == 1 && findLowerFactor(ThreadsPerBlock[2])); + assert(ThreadsPerBlock[2] && "ThreadsPerBlock[2] cannot be zero"); }