[CUDA] Fix cuda group/non-uniform group shuffles. (#13230)
This follows on from the discussion in
#12705 (comment) to
implement and fix non-uniform group shuffles on the CUDA backend.

- Fix the non-uniform group algorithm implementations for
  permute/shift-left/shift-right.
- Generalize group shuffles so that double/half/long/short are handled
  correctly for both uniform and non-uniform groups.
- Make the fixed_size_group test fail if the group member "local id"
  mapping is incorrect or removed.
- Update ballot_group_algorithms.cpp to test the previously failing cases
  on the CUDA backend.

The shuffle implementations in ::detail match those in SYCLomatic for the
masked shuffle builtins (which do not exist in oneAPI outside SYCLomatic).
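For context, a minimal sketch (not part of this change) of the kind of
non-uniform group shuffle that previously produced wrong results on CUDA.
get_ballot_group and shift_group_left are the extension APIs; the queue
setup is assumed, and the header layout may differ by toolchain:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental;

void shift_within_ballot_group(sycl::queue &q, double *data) {
  q.parallel_for(sycl::nd_range<1>{32, 32}, [=](sycl::nd_item<1> it) {
    auto sg = it.get_sub_group();
    // Partition the sub-group into two non-uniform ballot groups with
    // non-contiguous member masks (even lanes vs. odd lanes).
    auto bg = syclex::get_ballot_group(sg, it.get_global_id(0) % 2 == 0);
    // A 64-bit shuffle within a non-uniform group: exactly the case the
    // generalized shuffles in this commit now handle correctly.
    data[it.get_global_id(0)] =
        sycl::shift_group_left(bg, data[it.get_global_id(0)], 1);
  });
}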

---------

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
JackAKirk authored Apr 9, 2024
1 parent 0939f39 commit a0c3b32
Showing 5 changed files with 136 additions and 121 deletions.
55 changes: 41 additions & 14 deletions sycl/include/sycl/detail/spirv.hpp
@@ -12,6 +12,10 @@

#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp> // for IdToMaskPosition

#if defined(__NVPTX__)
#include <sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp>
#endif

#include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy

namespace sycl {
@@ -870,10 +874,10 @@ EnableIfNativeShuffle<T> Shuffle(GroupT g, T x, id<1> local_id) {
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __nvvm_shfl_sync_idx_i32(detail::ExtractMask(detail::GetMask(g))[0],
x, LocalId, 0x1f);
return cuda_shfl_sync_idx_i32(detail::ExtractMask(detail::GetMask(g))[0], x,
LocalId, 31);
} else {
return __nvvm_shfl_sync_idx_i32(membermask(), x, LocalId, 0x1f);
return cuda_shfl_sync_idx_i32(membermask(), x, LocalId, 31);
}
#endif
}
@@ -908,12 +912,20 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __nvvm_shfl_sync_bfly_i32(detail::ExtractMask(detail::GetMask(g))[0],
x, static_cast<uint32_t>(mask.get(0)),
0x1f);
} else {
return __nvvm_shfl_sync_bfly_i32(membermask(), x,
auto MemberMask = detail::ExtractMask(detail::GetMask(g))[0];
if constexpr (is_fixed_size_group_v<GroupT>) {
return cuda_shfl_sync_bfly_i32(MemberMask, x,
static_cast<uint32_t>(mask.get(0)), 0x1f);

} else {
int unfoldedSrcSetBit =
(g.get_local_id()[0] ^ static_cast<uint32_t>(mask.get(0))) + 1;
return cuda_shfl_sync_idx_i32(
MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
}
} else {
return cuda_shfl_sync_bfly_i32(membermask(), x,
static_cast<uint32_t>(mask.get(0)), 0x1f);
}
#endif
}
@@ -948,10 +960,17 @@ EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __nvvm_shfl_sync_down_i32(detail::ExtractMask(detail::GetMask(g))[0],
x, delta, 0x1f);
auto MemberMask = detail::ExtractMask(detail::GetMask(g))[0];
if constexpr (is_fixed_size_group_v<GroupT>) {
return cuda_shfl_sync_down_i32(MemberMask, x, delta, 31);
} else {
unsigned localSetBit = g.get_local_id()[0] + 1;
int unfoldedSrcSetBit = localSetBit + delta;
return cuda_shfl_sync_idx_i32(
MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
}
} else {
return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
return cuda_shfl_sync_down_i32(membermask(), x, delta, 31);
}
#endif
}
@@ -985,10 +1004,18 @@ EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __nvvm_shfl_sync_up_i32(detail::ExtractMask(detail::GetMask(g))[0],
x, delta, 0);
auto MemberMask = detail::ExtractMask(detail::GetMask(g))[0];
if constexpr (is_fixed_size_group_v<GroupT>) {
return cuda_shfl_sync_up_i32(MemberMask, x, delta, 0);
} else {
unsigned localSetBit = g.get_local_id()[0] + 1;
int unfoldedSrcSetBit = localSetBit - delta;

return cuda_shfl_sync_idx_i32(
MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
}
} else {
return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
return cuda_shfl_sync_up_i32(membermask(), x, delta, 0);
}
#endif
}
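All four non-fixed-size branches above use the same mapping: a member's
"local id" is its rank among the set bits of the member mask, so the source
lane of a shuffle is recovered with __nvvm_fns (find the n-th set bit). A
host-side model of that mapping, for illustration only; the device code uses
the builtin, which additionally wraps out-of-range n:

#include <cstdint>

// Returns the lane index of the n-th set bit of mask (n is 1-based, to
// match the "local id + 1" uses above), or -1 if fewer than n bits are set.
int find_nth_set_lane(uint32_t mask, int n) {
  if (n < 1)
    return -1;
  for (int lane = 0; lane < 32; ++lane)
    if ((mask >> lane) & 1)
      if (--n == 0)
        return lane;
  return -1;
}

// Example: mask 0b10100110 has members on lanes 1, 2, 5, 7 (local ids
// 0..3). ShuffleDown by 1 from local id 1 computes unfoldedSrcSetBit =
// (1 + 1) + 1 = 3 and reads from find_nth_set_lane(mask, 3) == lane 5.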
62 changes: 62 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp
@@ -0,0 +1,62 @@
//==--------- masked_shuffles.hpp - cuda masked shuffle algorithms ---------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)

namespace sycl {
inline namespace _V1 {
namespace detail {

#define CUDA_SHFL_SYNC(SHUFFLE_INSTR) \
template <typename T> \
inline __SYCL_ALWAYS_INLINE T cuda_shfl_sync_##SHUFFLE_INSTR( \
unsigned int mask, T val, unsigned int shfl_param, int c) { \
T res; \
if constexpr (std::is_same_v<T, double>) { \
int x_a, x_b; \
asm("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(val)); \
auto tmp_a = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_a, shfl_param, c); \
auto tmp_b = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_b, shfl_param, c); \
asm("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b)); \
} else if constexpr (std::is_same_v<T, long> || \
std::is_same_v<T, unsigned long>) { \
int x_a, x_b; \
asm("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(val)); \
auto tmp_a = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_a, shfl_param, c); \
auto tmp_b = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_b, shfl_param, c); \
asm("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b)); \
} else if constexpr (std::is_same_v<T, half>) { \
short tmp_b16; \
asm("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(val)); \
auto tmp_b32 = __nvvm_shfl_sync_##SHUFFLE_INSTR( \
mask, static_cast<int>(tmp_b16), shfl_param, c); \
asm("mov.b16 %0,%1;" : "=h"(res) : "h"(static_cast<short>(tmp_b32))); \
} else if constexpr (std::is_same_v<T, float>) { \
auto tmp_b32 = __nvvm_shfl_sync_##SHUFFLE_INSTR( \
mask, __nvvm_bitcast_f2i(val), shfl_param, c); \
res = __nvvm_bitcast_i2f(tmp_b32); \
} else { \
res = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, val, shfl_param, c); \
} \
return res; \
}

CUDA_SHFL_SYNC(bfly_i32)
CUDA_SHFL_SYNC(up_i32)
CUDA_SHFL_SYNC(down_i32)
CUDA_SHFL_SYNC(idx_i32)

#undef CUDA_SHFL_SYNC

} // namespace detail
} // namespace _V1
} // namespace sycl

#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
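The macro above widens the hardware's 32-bit shuffle to 64-bit and 16-bit
types by splitting and recombining values around the __nvvm_shfl_sync_*
builtins. A portable model of the double case, for illustration only; the
real code performs the split with mov.b64 inline asm, and shfl32 here is a
stand-in for the 32-bit builtin:

#include <cstdint>
#include <cstring>

int shfl32(unsigned mask, int val, unsigned shfl_param, int c); // stand-in

double shfl64(unsigned mask, double val, unsigned shfl_param, int c) {
  // Split the 64-bit value into two 32-bit halves.
  uint64_t bits;
  std::memcpy(&bits, &val, sizeof(bits));
  int lo = static_cast<int>(static_cast<uint32_t>(bits));
  int hi = static_cast<int>(static_cast<uint32_t>(bits >> 32));
  // Shuffle each half independently, then recombine.
  uint32_t lo_s = static_cast<uint32_t>(shfl32(mask, lo, shfl_param, c));
  uint32_t hi_s = static_cast<uint32_t>(shfl32(mask, hi, shfl_param, c));
  uint64_t out = (static_cast<uint64_t>(hi_s) << 32) | lo_s;
  double res;
  std::memcpy(&res, &out, sizeof(res));
  return res;
}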
sycl/include/sycl/ext/oneapi/experimental/cuda/non_uniform_algorithms.hpp
@@ -9,6 +9,7 @@
#pragma once

#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
#include "masked_shuffles.hpp"

namespace sycl {
inline namespace _V1 {
@@ -100,87 +101,12 @@ inline __SYCL_ALWAYS_INLINE std::enable_if_t<is_fixed_size_group_v<Group>, T>
masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
const uint32_t MemberMask) {
for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) {
T tmp;
if constexpr (std::is_same_v<T, double>) {
int x_a, x_b;
asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
asm volatile("mov.b64 %0,{%1,%2};" : "=d"(tmp) : "r"(tmp_a), "r"(tmp_b));
} else if constexpr (std::is_same_v<T, long> ||
std::is_same_v<T, unsigned long>) {
int x_a, x_b;
asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
asm volatile("mov.b64 %0,{%1,%2};" : "=l"(tmp) : "r"(tmp_a), "r"(tmp_b));
} else if constexpr (std::is_same_v<T, half>) {
short tmp_b16;
asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
auto tmp_b32 = __nvvm_shfl_sync_bfly_i32(
MemberMask, static_cast<int>(tmp_b16), -1, i);
asm volatile("mov.b16 %0,%1;"
: "=h"(tmp)
: "h"(static_cast<short>(tmp_b32)));
} else if constexpr (std::is_same_v<T, float>) {
auto tmp_b32 =
__nvvm_shfl_sync_bfly_i32(MemberMask, __nvvm_bitcast_f2i(x), -1, i);
tmp = __nvvm_bitcast_i2f(tmp_b32);
} else {
tmp = __nvvm_shfl_sync_bfly_i32(MemberMask, x, -1, i);
}
T tmp = cuda_shfl_sync_bfly_i32(MemberMask, x, i, 0x1f);
x = binary_op(x, tmp);
}
return x;
}

template <typename Group, typename T>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
non_uniform_shfl_T(const uint32_t MemberMask, T x, int shfl_param) {
if constexpr (is_fixed_size_group_v<Group>) {
return __nvvm_shfl_sync_up_i32(MemberMask, x, shfl_param, 0);
} else {
return __nvvm_shfl_sync_idx_i32(MemberMask, x, shfl_param, 31);
}
}

template <typename Group, typename T>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
non_uniform_shfl(Group g, const uint32_t MemberMask, T x, int shfl_param) {
T res;
if constexpr (std::is_same_v<T, double>) {
int x_a, x_b;
asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
asm volatile("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b));
} else if constexpr (std::is_same_v<T, long> ||
std::is_same_v<T, unsigned long>) {
int x_a, x_b;
asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
asm volatile("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b));
} else if constexpr (std::is_same_v<T, half>) {
short tmp_b16;
asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
auto tmp_b32 = non_uniform_shfl_T<Group>(
MemberMask, static_cast<int>(tmp_b16), shfl_param);
asm volatile("mov.b16 %0,%1;"
: "=h"(res)
: "h"(static_cast<short>(tmp_b32)));
} else if constexpr (std::is_same_v<T, float>) {
auto tmp_b32 = non_uniform_shfl_T<Group>(MemberMask, __nvvm_bitcast_f2i(x),
shfl_param);
res = __nvvm_bitcast_i2f(tmp_b32);
} else {
res = non_uniform_shfl_T<Group>(MemberMask, x, shfl_param);
}
return res;
}

// Opportunistic/Ballot group reduction using shfls
template <typename Group, typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE std::enable_if_t<
@@ -207,8 +133,8 @@ masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,

// __nvvm_fns automatically wraps around to the correct bit position.
// There is no performance impact on src_set_bit position wrt localSetBit
auto tmp = non_uniform_shfl(g, MemberMask, x,
__nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
T tmp = cuda_shfl_sync_idx_i32(
MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);

if (!(localSetBit == 1 && remainder != 0)) {
x = binary_op(x, tmp);
@@ -224,7 +150,8 @@ masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
: "=r"(broadID)
: "r"(MemberMask));

return non_uniform_shfl(g, MemberMask, x, broadID);
x = cuda_shfl_sync_idx_i32(MemberMask, x, broadID, 31);
return x;
}
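Semantically, the opportunistic/ballot-group reduction above folds values by
member rank rather than by lane id, then broadcasts the result to every
member. A host-side model of the value it must produce, for illustration
only; the device version performs the fold in log2(n) shuffle steps:

#include <cstdint>
#include <vector>

template <typename T, typename Op>
T model_masked_reduce(const std::vector<T> &lanes, uint32_t mask, Op op) {
  // Collect participating lanes in rank order (rank == "local id").
  std::vector<T> members;
  for (int lane = 0; lane < 32; ++lane)
    if ((mask >> lane) & 1)
      members.push_back(lanes[lane]);
  T acc = members.front();
  for (std::size_t i = 1; i < members.size(); ++i)
    acc = op(acc, members[i]);
  return acc; // every member of the group receives this value
}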

// Non Redux types must fall back to shfl based implementations.
@@ -265,18 +192,19 @@ inline __SYCL_ALWAYS_INLINE
return ~0;
}

#define GET_ID(OP_CHECK, OP) \
template <typename T, class BinaryOperation> \
inline __SYCL_ALWAYS_INLINE \
std::enable_if_t<OP_CHECK<T, BinaryOperation>::value, T> \
get_identity() { \
return std::numeric_limits<T>::OP(); \
}

GET_ID(IsMinimum, max)
GET_ID(IsMaximum, min)
template <typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<IsMinimum<T, BinaryOperation>::value, T>
get_identity() {
return std::numeric_limits<T>::min();
}

#undef GET_ID
template <typename T, class BinaryOperation>
inline __SYCL_ALWAYS_INLINE
std::enable_if_t<IsMaximum<T, BinaryOperation>::value, T>
get_identity() {
return std::numeric_limits<T>::max();
}

//// Shuffle based masked reduction impls

@@ -288,13 +216,12 @@ masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
const uint32_t MemberMask) {
unsigned localIdVal = g.get_local_id()[0];
for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
auto tmp = non_uniform_shfl(g, MemberMask, x, i);
T tmp = cuda_shfl_sync_up_i32(MemberMask, x, i, 0);
if (localIdVal >= i)
x = binary_op(x, tmp);
}
if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {

x = non_uniform_shfl(g, MemberMask, x, 1);
x = cuda_shfl_sync_up_i32(MemberMask, x, 1, 0);
if (localIdVal == 0) {
return get_identity<T, BinaryOperation>();
}
@@ -316,14 +243,15 @@ masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
int unfoldedSrcSetBit = localSetBit - i;

auto tmp = non_uniform_shfl(g, MemberMask, x,
__nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
T tmp = cuda_shfl_sync_idx_i32(
MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);

if (localIdVal >= i)
x = binary_op(x, tmp);
}
if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {
x = non_uniform_shfl(g, MemberMask, x,
__nvvm_fns(MemberMask, 0, localSetBit - 1));
x = cuda_shfl_sync_idx_i32(MemberMask, x,
__nvvm_fns(MemberMask, 0, localSetBit - 1), 31);
if (localIdVal == 0) {
return get_identity<T, BinaryOperation>();
}
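The scans above are Hillis-Steele scans over member ranks: at step i every
member whose rank is at least i folds in the value held by the member i
ranks below it, and the exclusive variant then shifts the inclusive result
down by one rank, substituting the identity at rank 0. A host-side model of
the inclusive part, for illustration only:

#include <vector>

template <typename T, typename Op>
std::vector<T> model_inclusive_scan(std::vector<T> v, Op op) {
  for (std::size_t i = 1; i < v.size(); i *= 2) {
    std::vector<T> prev = v; // device code reads pre-step values via shfl
    for (std::size_t r = i; r < v.size(); ++r)
      v[r] = op(prev[r - i], v[r]);
  }
  return v;
}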
12 changes: 4 additions & 8 deletions sycl/test-e2e/NonUniformGroups/ballot_group_algorithms.cpp
@@ -147,14 +147,10 @@ int main() {
assert(ReduceAcc[WI] == true);
assert(ExScanAcc[WI] == true);
assert(IncScanAcc[WI] == true);
// TODO: Enable for CUDA devices when issue with shuffles have been
// addressed.
if (Q.get_backend() != sycl::backend::ext_oneapi_cuda) {
assert(ShiftLeftAcc[WI] == true);
assert(ShiftRightAcc[WI] == true);
assert(SelectAcc[WI] == true);
assert(PermuteXorAcc[WI] == true);
}
assert(ShiftLeftAcc[WI] == true);
assert(ShiftRightAcc[WI] == true);
assert(SelectAcc[WI] == true);
assert(PermuteXorAcc[WI] == true);
}
return 0;
}
sycl/test-e2e/NonUniformGroups/fixed_size_group_algorithms.cpp
@@ -113,8 +113,10 @@ template <size_t PartitionSize> void test() {
ShiftRightAcc[WI] = (LID < 2 || ShiftRightResult == LID - 2);

uint32_t SelectResult = sycl::select_from_group(
Partition, LID, (Partition.get_local_id() + 2) % PartitionSize);
SelectAcc[WI] = (SelectResult == (LID + 2) % PartitionSize);
Partition, OriginalLID,
(Partition.get_local_id() + 2) % PartitionSize);
SelectAcc[WI] =
SelectResult == OriginalLID - LID + ((LID + 2) % PartitionSize);

uint32_t Mask = PartitionSize <= 2 ? 0 : 2;
uint32_t PermuteXorResult =
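A worked check of the updated SelectAcc expectation, for illustration only,
assuming OriginalLID (defined earlier in this test) is the work-item's
sub-group local id before partitioning, so OriginalLID - LID is the first
original id in the partition:

#include <cassert>

int main() {
  unsigned PartitionSize = 4;
  unsigned OriginalLID = 5; // second member of the partition covering 4..7
  unsigned LID = OriginalLID % PartitionSize;       // 1
  unsigned SrcLocal = (LID + 2) % PartitionSize;    // 3
  unsigned Expected = OriginalLID - LID + SrcLocal; // 4 + 3 == 7
  assert(Expected == 7); // select_from_group reads member 7's OriginalLID
  return 0;
}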
