Skip to content

Commit

Permalink
[NFCI][SYCL] Use convertToOpenCLType in more places (#12692)
Browse files Browse the repository at this point in the history
Follow-up for #12674, updating places
where `ConvertToOpenCLType_t` was used with a plain cast instead of
`convertDataToType`.

Not touching `multi_ptr` related uses just yet.
  • Loading branch information
aelovikov-intel authored Feb 13, 2024
1 parent b06cfb5 commit 0c48d9c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 15 deletions.
14 changes: 6 additions & 8 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta);
template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
#ifndef __NVPTX__
using OCLT = detail::ConvertToOpenCLType_t<T>;
return __spirv_SubgroupShuffleINTEL(OCLT(x),
return __spirv_SubgroupShuffleINTEL(convertToOpenCLType(x),
static_cast<uint32_t>(local_id.get(0)));
#else
return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
Expand All @@ -814,9 +813,8 @@ EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
#ifndef __NVPTX__
using OCLT = detail::ConvertToOpenCLType_t<T>;
return __spirv_SubgroupShuffleXorINTEL(
OCLT(x), static_cast<uint32_t>(local_id.get(0)));
convertToOpenCLType(x), static_cast<uint32_t>(local_id.get(0)));
#else
return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
#endif
Expand All @@ -825,8 +823,8 @@ EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
#ifndef __NVPTX__
using OCLT = detail::ConvertToOpenCLType_t<T>;
return __spirv_SubgroupShuffleDownINTEL(OCLT(x), OCLT(x), delta);
return __spirv_SubgroupShuffleDownINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
#else
return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
#endif
Expand All @@ -835,8 +833,8 @@ EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
template <typename T>
EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
#ifndef __NVPTX__
using OCLT = detail::ConvertToOpenCLType_t<T>;
return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x), delta);
return __spirv_SubgroupShuffleUpINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
#else
return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ inline sycl::vec<unsigned, 4> ExtractMask(ext::oneapi::sub_group_mask Mask) {
// TODO: This may need to be generalized beyond uint32_t for big masks
inline uint32_t CallerPositionInMask(ext::oneapi::sub_group_mask Mask) {
sycl::vec<unsigned, 4> MemberMask = ExtractMask(Mask);
auto OCLMask =
sycl::detail::ConvertToOpenCLType_t<sycl::vec<unsigned, 4>>(MemberMask);
return __spirv_GroupNonUniformBallotBitCount(
__spv::Scope::Subgroup, (int)__spv::GroupOperation::ExclusiveScan,
OCLMask);
sycl::detail::convertToOpenCLType(MemberMask));
}
#endif

Expand Down
5 changes: 2 additions & 3 deletions sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ struct sub_group_mask {
for (int i = 0; i < 4; ++i) {
MemberMask[i] = TmpMArray[i];
}
auto OCLMask =
sycl::detail::ConvertToOpenCLType_t<sycl::vec<unsigned, 4>>(MemberMask);
return __spirv_GroupNonUniformBallotBitCount(
__spv::Scope::Subgroup, (int)__spv::GroupOperation::Reduce, OCLMask);
__spv::Scope::Subgroup, (int)__spv::GroupOperation::Reduce,
sycl::detail::convertToOpenCLType(MemberMask));
#else
unsigned int count = 0;
auto word = (Bits & valuable_bits(bits_num));
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <sycl/detail/common.hpp> // for NDLoop, __SYCL_ASSERT
#include <sycl/detail/defines.hpp> // for __SYCL_TYPE
#include <sycl/detail/defines_elementary.hpp> // for __SYCL2020_DEPRECATED
#include <sycl/detail/generic_type_traits.hpp> // for ConvertToOpenCLType_t
#include <sycl/detail/generic_type_traits.hpp> // for convertToOpenCLType
#include <sycl/detail/helpers.hpp> // for Builder, getSPIRVMemo...
#include <sycl/detail/item_base.hpp> // for id, range
#include <sycl/detail/type_traits.hpp> // for is_bool, change_base_...
Expand Down

0 comments on commit 0c48d9c

Please sign in to comment.