Skip to content

Commit

Permalink
[NFCI][SYCL] Support multi_ptr in convertToOpenCLType (#12693)
Browse files Browse the repository at this point in the history
  • Loading branch information
aelovikov-intel authored Feb 13, 2024
1 parent 7999e27 commit e5da67b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 38 deletions.
34 changes: 32 additions & 2 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,14 @@ template <> struct ConvertToOpenCLTypeImpl<Boolean<1>> {
// Or should it be "int"?
using type = Boolean<1>;
};
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
// TODO: It seems we only use this to convert a pointer's element type. As such,
// although it doesn't look very clean, it should be ok having this case handled
// explicitly until further refactoring of this area.
template <> struct ConvertToOpenCLTypeImpl<std::byte> {
using type = uint8_t;
};
#endif
#endif

template <typename T> struct ConvertToOpenCLTypeImpl<T *> {
Expand Down Expand Up @@ -700,8 +708,30 @@ convertDataToType(FROM t) {
// Now fuse the above into a simpler helper that's easy to use.
// TODO: That should probably be moved outside of "type_traits".
template <typename T> auto convertToOpenCLType(T &&x) {
using OpenCLType = ConvertToOpenCLType_t<std::remove_reference_t<T>>;
return convertDataToType<T, OpenCLType>(std::forward<T>(x));
using no_ref = std::remove_reference_t<T>;
if constexpr (is_multi_ptr_v<no_ref>) {
return convertToOpenCLType(x.get_decorated());
} else if constexpr (std::is_pointer_v<no_ref>) {
// TODO: Below ignores volatile, but we didn't have a need for it yet.
using elem_type = remove_decoration_t<std::remove_pointer_t<no_ref>>;
using converted_elem_type_no_cv =
ConvertToOpenCLType_t<std::remove_const_t<elem_type>>;
using converted_elem_type =
std::conditional_t<std::is_const_v<elem_type>,
const converted_elem_type_no_cv,
converted_elem_type_no_cv>;
#ifdef __SYCL_DEVICE_ONLY__
using result_type =
typename DecoratedType<converted_elem_type,
deduce_AS<no_ref>::value>::type *;
#else
using result_type = converted_elem_type *;
#endif
return reinterpret_cast<result_type>(x);
} else {
using OpenCLType = ConvertToOpenCLType_t<no_ref>;
return convertDataToType<T, OpenCLType>(std::forward<T>(x));
}
}

template <typename To, typename From> auto convertFromOpenCLTypeFor(From &&x) {
Expand Down
28 changes: 8 additions & 20 deletions sycl/include/sycl/group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
global_ptr<dataT> src,
size_t numElements,
size_t srcStride) const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, srcStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, srcStride, 0);
return device_event(E);
}

Expand All @@ -337,12 +334,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
size_t numElements,
size_t destStride)
const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, destStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, destStride, 0);
return device_event(E);
}

Expand All @@ -359,12 +353,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
async_work_group_copy(decorated_local_ptr<DestDataT> dest,
decorated_global_ptr<SrcDataT> src, size_t numElements,
size_t srcStride) const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, srcStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, srcStride, 0);
return device_event(E);
}

Expand All @@ -381,12 +372,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
async_work_group_copy(decorated_global_ptr<DestDataT> dest,
decorated_local_ptr<SrcDataT> src, size_t numElements,
size_t destStride) const {
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;

__ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal(
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
numElements, destStride, 0);
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
detail::convertToOpenCLType(src), numElements, destStride, 0);
return device_event(E);
}

Expand Down
38 changes: 22 additions & 16 deletions sycl/include/sycl/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ namespace sub_group {
template <typename T>
using SelectBlockT = select_cl_scalar_integral_unsigned_t<T>;

template <typename MultiPtrTy> auto convertToBlockPtr(MultiPtrTy MultiPtr) {
static_assert(is_multi_ptr_v<MultiPtrTy>);
auto DecoratedPtr = convertToOpenCLType(MultiPtr);
using DecoratedPtrTy = decltype(DecoratedPtr);
using ElemTy = remove_decoration_t<std::remove_pointer_t<DecoratedPtrTy>>;

using TargetElemTy = SelectBlockT<ElemTy>;
// TODO: Handle cv qualifiers.
#ifdef __SYCL_DEVICE_ONLY__
using ResultTy =
typename DecoratedType<TargetElemTy,
deduce_AS<DecoratedPtrTy>::value>::type *;
#else
using ResultTy = TargetElemTy *;
#endif
return reinterpret_cast<ResultTy>(DecoratedPtr);
}

template <typename T, access::address_space Space>
using AcceptableForGlobalLoadStore =
std::bool_constant<!std::is_same_v<void, SelectBlockT<T>> &&
Expand All @@ -57,11 +75,7 @@ template <typename T, access::address_space Space,
access::decorated DecorateAddress>
T load(const multi_ptr<T, Space, DecorateAddress> src) {
using BlockT = SelectBlockT<T>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
const multi_ptr<BlockT, Space, DecorateAddress>>;

BlockT Ret =
__spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast<PtrT>(src.get()));
BlockT Ret = __spirv_SubgroupBlockReadINTEL<BlockT>(convertToBlockPtr(src));

return sycl::bit_cast<T>(Ret);
}
Expand All @@ -71,11 +85,7 @@ template <int N, typename T, access::address_space Space,
vec<T, N> load(const multi_ptr<T, Space, DecorateAddress> src) {
using BlockT = SelectBlockT<T>;
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
const multi_ptr<BlockT, Space, DecorateAddress>>;

VecT Ret =
__spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast<PtrT>(src.get()));
VecT Ret = __spirv_SubgroupBlockReadINTEL<VecT>(convertToBlockPtr(src));

return sycl::bit_cast<typename vec<T, N>::vector_t>(Ret);
}
Expand All @@ -84,10 +94,8 @@ template <typename T, access::address_space Space,
access::decorated DecorateAddress>
void store(multi_ptr<T, Space, DecorateAddress> dst, const T &x) {
using BlockT = SelectBlockT<T>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
multi_ptr<BlockT, Space, DecorateAddress>>;

__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
__spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst),
sycl::bit_cast<BlockT>(x));
}

Expand All @@ -96,10 +104,8 @@ template <int N, typename T, access::address_space Space,
void store(multi_ptr<T, Space, DecorateAddress> dst, const vec<T, N> &x) {
using BlockT = SelectBlockT<T>;
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
using PtrT = sycl::detail::ConvertToOpenCLType_t<
const multi_ptr<BlockT, Space, DecorateAddress>>;

__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
__spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst),
sycl::bit_cast<VecT>(x));
}
#endif // __SYCL_DEVICE_ONLY__
Expand Down

0 comments on commit e5da67b

Please sign in to comment.