Skip to content

Commit

Permalink
Test CI
Browse files Browse the repository at this point in the history
  • Loading branch information
hdelan committed Mar 15, 2024
1 parent b5af496 commit 12b3bf0
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions source/ur/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,27 +352,36 @@ static inline void roundToHighestFactorOfGlobalSizeIn3d(
size_t *ThreadsPerBlock, const size_t *GlobalSize,
const size_t *MaxBlockDim, const size_t MaxBlockSize,
const size_t WorkDim) {
ThreadsPerBlock[0] = std::min(GlobalSize[0], MaxBlockDim[0]);
// Make the X dim a factor of 2

auto findLowerFactor = [](size_t &ThreadsPerBlock) {
return !isPowerOf2(ThreadsPerBlock) && ThreadsPerBlock > 32 &&
--ThreadsPerBlock;
};
assert(GlobalSize[0] && "GlobalSize[0] cannot be zero");
assert(GlobalSize[1] && "GlobalSize[1] cannot be zero");
assert(GlobalSize[2] && "GlobalSize[2] cannot be zero");

ThreadsPerBlock[0] =
std::min(GlobalSize[0], std::min(MaxBlockSize, MaxBlockDim[0]));
do {
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[0], GlobalSize[0]);
} while (WorkDim == 3 && !isPowerOf2(ThreadsPerBlock[0]) &&
ThreadsPerBlock[0] > 32 && --ThreadsPerBlock[0]);
} while (WorkDim == 3 && findLowerFactor(ThreadsPerBlock[0]));
assert(ThreadsPerBlock[0] && "ThreadsPerBlock[0] cannot be zero");

ThreadsPerBlock[1] =
std::min(GlobalSize[1],
std::min(MaxBlockSize / ThreadsPerBlock[0], MaxBlockDim[1]));
do {
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[1], GlobalSize[1]);
} while (WorkDim == 2 && !isPowerOf2(ThreadsPerBlock[1]) &&
ThreadsPerBlock[1] > 32 && --ThreadsPerBlock[1]);
} while (WorkDim == 2 && findLowerFactor(ThreadsPerBlock[1]));
assert(ThreadsPerBlock[1] && "ThreadsPerBlock[1] cannot be zero");

ThreadsPerBlock[2] = std::min(
GlobalSize[2],
std::min(MaxBlockSize / (ThreadsPerBlock[1] * ThreadsPerBlock[0]),
MaxBlockDim[2]));
do {
roundToHighestFactorOfGlobalSize(ThreadsPerBlock[2], GlobalSize[2]);
} while (WorkDim == 1 && !isPowerOf2(ThreadsPerBlock[2]) &&
ThreadsPerBlock[2] > 32 && --ThreadsPerBlock[2]);
} while (WorkDim == 1 && findLowerFactor(ThreadsPerBlock[2]));
assert(ThreadsPerBlock[2] && "ThreadsPerBlock[2] cannot be zero");
}

0 comments on commit 12b3bf0

Please sign in to comment.