From e5da67bbb6ed55f77dd823e3cb9c77abfabdb37b Mon Sep 17 00:00:00 2001 From: aelovikov-intel Date: Tue, 13 Feb 2024 15:46:41 -0800 Subject: [PATCH] [NFCI][SYCL] Support multi_ptr in convertToOpenCLType (#12693) --- .../sycl/detail/generic_type_traits.hpp | 34 ++++++++++++++++- sycl/include/sycl/group.hpp | 28 ++++---------- sycl/include/sycl/sub_group.hpp | 38 +++++++++++-------- 3 files changed, 62 insertions(+), 38 deletions(-) diff --git a/sycl/include/sycl/detail/generic_type_traits.hpp b/sycl/include/sycl/detail/generic_type_traits.hpp index 7cf893778394c..bd59361ff2eeb 100644 --- a/sycl/include/sycl/detail/generic_type_traits.hpp +++ b/sycl/include/sycl/detail/generic_type_traits.hpp @@ -663,6 +663,14 @@ template <> struct ConvertToOpenCLTypeImpl> { // Or should it be "int"? using type = Boolean<1>; }; +#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +// TODO: It seems we only use this to convert a pointer's element type. As such, +// although it doesn't look very clean, it should be ok having this case handled +// explicitly until further refactoring of this area. +template <> struct ConvertToOpenCLTypeImpl { + using type = uint8_t; +}; +#endif #endif template struct ConvertToOpenCLTypeImpl { @@ -700,8 +708,30 @@ convertDataToType(FROM t) { // Now fuse the above into a simpler helper that's easy to use. // TODO: That should probably be moved outside of "type_traits". template auto convertToOpenCLType(T &&x) { - using OpenCLType = ConvertToOpenCLType_t>; - return convertDataToType(std::forward(x)); + using no_ref = std::remove_reference_t; + if constexpr (is_multi_ptr_v) { + return convertToOpenCLType(x.get_decorated()); + } else if constexpr (std::is_pointer_v) { + // TODO: Below ignores volatile, but we didn't have a need for it yet. + using elem_type = remove_decoration_t>; + using converted_elem_type_no_cv = + ConvertToOpenCLType_t>; + using converted_elem_type = + std::conditional_t, + const converted_elem_type_no_cv, + converted_elem_type_no_cv>; +#ifdef __SYCL_DEVICE_ONLY__ + using result_type = + typename DecoratedType::value>::type *; +#else + using result_type = converted_elem_type *; +#endif + return reinterpret_cast(x); + } else { + using OpenCLType = ConvertToOpenCLType_t; + return convertDataToType(std::forward(x)); + } } template auto convertFromOpenCLTypeFor(From &&x) { diff --git a/sycl/include/sycl/group.hpp b/sycl/include/sycl/group.hpp index ceb0c58dcf99c..faeacd10f1998 100644 --- a/sycl/include/sycl/group.hpp +++ b/sycl/include/sycl/group.hpp @@ -315,12 +315,9 @@ template class __SYCL_TYPE(group) group { global_ptr src, size_t numElements, size_t srcStride) const { - using DestT = detail::ConvertToOpenCLType_t; - using SrcT = detail::ConvertToOpenCLType_t; - __ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal( - __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), - numElements, srcStride, 0); + __spv::Scope::Workgroup, detail::convertToOpenCLType(dest), + detail::convertToOpenCLType(src), numElements, srcStride, 0); return device_event(E); } @@ -337,12 +334,9 @@ template class __SYCL_TYPE(group) group { size_t numElements, size_t destStride) const { - using DestT = detail::ConvertToOpenCLType_t; - using SrcT = detail::ConvertToOpenCLType_t; - __ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal( - __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), - numElements, destStride, 0); + __spv::Scope::Workgroup, detail::convertToOpenCLType(dest), + detail::convertToOpenCLType(src), numElements, destStride, 0); return device_event(E); } @@ -359,12 +353,9 @@ template class __SYCL_TYPE(group) group { async_work_group_copy(decorated_local_ptr dest, decorated_global_ptr src, size_t numElements, size_t srcStride) const { - using DestT = detail::ConvertToOpenCLType_t; - using SrcT = detail::ConvertToOpenCLType_t; - __ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal( - __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), - numElements, srcStride, 0); + __spv::Scope::Workgroup, detail::convertToOpenCLType(dest), + detail::convertToOpenCLType(src), numElements, srcStride, 0); return device_event(E); } @@ -381,12 +372,9 @@ template class __SYCL_TYPE(group) group { async_work_group_copy(decorated_global_ptr dest, decorated_local_ptr src, size_t numElements, size_t destStride) const { - using DestT = detail::ConvertToOpenCLType_t; - using SrcT = detail::ConvertToOpenCLType_t; - __ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal( - __spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()), - numElements, destStride, 0); + __spv::Scope::Workgroup, detail::convertToOpenCLType(dest), + detail::convertToOpenCLType(src), numElements, destStride, 0); return device_event(E); } diff --git a/sycl/include/sycl/sub_group.hpp b/sycl/include/sycl/sub_group.hpp index 8cfd4ea9b555b..908b5d4aebff0 100644 --- a/sycl/include/sycl/sub_group.hpp +++ b/sycl/include/sycl/sub_group.hpp @@ -42,6 +42,24 @@ namespace sub_group { template using SelectBlockT = select_cl_scalar_integral_unsigned_t; +template auto convertToBlockPtr(MultiPtrTy MultiPtr) { + static_assert(is_multi_ptr_v); + auto DecoratedPtr = convertToOpenCLType(MultiPtr); + using DecoratedPtrTy = decltype(DecoratedPtr); + using ElemTy = remove_decoration_t>; + + using TargetElemTy = SelectBlockT; + // TODO: Handle cv qualifiers. +#ifdef __SYCL_DEVICE_ONLY__ + using ResultTy = + typename DecoratedType::value>::type *; +#else + using ResultTy = TargetElemTy *; +#endif + return reinterpret_cast(DecoratedPtr); +} + template using AcceptableForGlobalLoadStore = std::bool_constant> && @@ -57,11 +75,7 @@ template T load(const multi_ptr src) { using BlockT = SelectBlockT; - using PtrT = sycl::detail::ConvertToOpenCLType_t< - const multi_ptr>; - - BlockT Ret = - __spirv_SubgroupBlockReadINTEL(reinterpret_cast(src.get())); + BlockT Ret = __spirv_SubgroupBlockReadINTEL(convertToBlockPtr(src)); return sycl::bit_cast(Ret); } @@ -71,11 +85,7 @@ template load(const multi_ptr src) { using BlockT = SelectBlockT; using VecT = sycl::detail::ConvertToOpenCLType_t>; - using PtrT = sycl::detail::ConvertToOpenCLType_t< - const multi_ptr>; - - VecT Ret = - __spirv_SubgroupBlockReadINTEL(reinterpret_cast(src.get())); + VecT Ret = __spirv_SubgroupBlockReadINTEL(convertToBlockPtr(src)); return sycl::bit_cast::vector_t>(Ret); } @@ -84,10 +94,8 @@ template void store(multi_ptr dst, const T &x) { using BlockT = SelectBlockT; - using PtrT = sycl::detail::ConvertToOpenCLType_t< - multi_ptr>; - __spirv_SubgroupBlockWriteINTEL(reinterpret_cast(dst.get()), + __spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst), sycl::bit_cast(x)); } @@ -96,10 +104,8 @@ template dst, const vec &x) { using BlockT = SelectBlockT; using VecT = sycl::detail::ConvertToOpenCLType_t>; - using PtrT = sycl::detail::ConvertToOpenCLType_t< - const multi_ptr>; - __spirv_SubgroupBlockWriteINTEL(reinterpret_cast(dst.get()), + __spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst), sycl::bit_cast(x)); } #endif // __SYCL_DEVICE_ONLY__