From aa015f33153756cc7c1e10d24abcb2a73830ac27 Mon Sep 17 00:00:00 2001 From: aelovikov-intel Date: Thu, 15 Feb 2024 20:42:13 -0800 Subject: [PATCH] [NFCI] More convertToOpenCLType-related simplifications (#12717) --- sycl/include/sycl/builtins_preview.hpp | 23 +- .../sycl/detail/generic_type_traits.hpp | 326 ++++-------------- sycl/include/sycl/detail/type_traits.hpp | 6 + sycl/source/detail/builtins_relational.cpp | 2 +- sycl/test/basic_tests/generic_type_traits.cpp | 103 +++--- 5 files changed, 141 insertions(+), 319 deletions(-) diff --git a/sycl/include/sycl/builtins_preview.hpp b/sycl/include/sycl/builtins_preview.hpp index 03c1c401fbcd4..b75b0e5264711 100644 --- a/sycl/include/sycl/builtins_preview.hpp +++ b/sycl/include/sycl/builtins_preview.hpp @@ -95,28 +95,13 @@ template auto convert_arg(T &&x) { __attribute__((ext_vector_type(N)))>; // TODO: We should have this bit_cast impl inside vec::convert. return bit_cast(static_cast(x)); - } else if constexpr (std::is_same_v) - return static_cast(x); - else if constexpr (is_multi_ptr_v) { - return convert_arg(x.get_decorated()); - } else if constexpr (is_scalar_arithmetic_v) { - // E.g. on linux: long long -> int64_t (long), or char -> int8_t (signed - // char) and same for unsigned; Windows has long/long long reversed. - // TODO: Inline this scalar impl. - return static_cast>(x); - } else if constexpr (std::is_pointer_v) { - using elem_type = remove_decoration_t>; - using converted_elem_type = - decltype(convert_arg(std::declval())); - using result_type = - typename DecoratedType::value>::type *; - return reinterpret_cast(x); } else if constexpr (is_swizzle_v) { return convert_arg(simplify_if_swizzle_t{x}); } else { - // TODO: should it be unreachable? What can it be? - return std::forward(x); + static_assert(is_scalar_arithmetic_v || + is_multi_ptr_v || std::is_pointer_v || + std::is_same_v); + return convertToOpenCLType(std::forward(x)); } } diff --git a/sycl/include/sycl/detail/generic_type_traits.hpp b/sycl/include/sycl/detail/generic_type_traits.hpp index bd59361ff2eeb..3bc68243859b5 100644 --- a/sycl/include/sycl/detail/generic_type_traits.hpp +++ b/sycl/include/sycl/detail/generic_type_traits.hpp @@ -433,72 +433,6 @@ template class TryToGetElementType { static constexpr bool value = !std::is_same_v; }; -template struct PointerConverter { - template static To Convert(From *t) { - return reinterpret_cast(t); - } - - template static To Convert(From &t) { - if constexpr (is_non_legacy_multi_ptr_v) { - return detail::cast_AS(t.get_decorated()); - } else if constexpr (is_legacy_multi_ptr_v) { - return detail::cast_AS(t.get()); - } else { - // TODO find the better way to get the pointer to underlying data from vec - // class - return reinterpret_cast(t.get()); - } - } -}; - -template -struct PointerConverter> { - template - static multi_ptr Convert(From *t) { - return address_space_cast( - reinterpret_cast>(t)); - } - - template - static multi_ptr Convert(From &t) { - return address_space_cast( - reinterpret_cast>(t.get())); - } - - template - static multi_ptr - Convert(multi_ptr &t) { - return t; - } -}; - -template To ConvertNonVectorType(From &t) { - if constexpr (is_pointer_v) - return PointerConverter::Convert(t); - else - return static_cast(t); -} - -template struct mptr_or_vec_elem_type { - using type = typename T::element_type; -}; -template -struct mptr_or_vec_elem_type< - multi_ptr, - std::enable_if_t> { - using type = typename multi_ptr::value_type; -}; -template -struct mptr_or_vec_elem_type> - : mptr_or_vec_elem_type> {}; - -template -using mptr_or_vec_elem_type_t = typename mptr_or_vec_elem_type::type; - // select_apply_cl_scalar_t selects from T8/T16/T32/T64 basing on // sizeof(IN). expected to handle scalar types. template @@ -518,14 +452,11 @@ using select_cl_scalar_integral_unsigned_t = select_apply_cl_scalar_t; -template -using select_cl_scalar_float_t = - select_apply_cl_scalar_t; - // Use SFINAE so that std::complex specialization could be implemented in // include/sycl/stl_wrappers/complex that would only be available if STL's -// is included by users. +// is included by users. Note that "function template partial +// specialization" is not allowed, so we cannot perform that trick on +// convertToOpenCLType function directly. template struct select_cl_scalar_complex_or_T { using type = T; }; @@ -534,179 +465,6 @@ template using select_cl_scalar_complex_or_T_t = typename select_cl_scalar_complex_or_T::type; -template -using select_cl_scalar_integral_t = - std::conditional_t, - select_cl_scalar_integral_signed_t, - select_cl_scalar_integral_unsigned_t>; - -// select_cl_scalar_t picks corresponding cl_* type for input -// scalar T or returns T if T is not scalar. -template -using select_cl_scalar_t = std::conditional_t< - std::is_integral_v, select_cl_scalar_integral_t, - std::conditional_t< - std::is_floating_point_v, select_cl_scalar_float_t, - // half and bfloat16 are special cases: they are implemented differently - // on host and device and therefore might lower to different types - std::conditional_t< - is_half_v, sycl::detail::half_impl::BIsRepresentationT, - std::conditional_t, T, - select_cl_scalar_complex_or_T_t>>>>; - -// select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type -// of a vector type T, pointer type substitution, and scalar type substitution. -// If T is not vector, scalar, or pointer unmodified T is returned. -template -struct select_cl_vector_or_scalar_or_ptr; - -template -struct select_cl_vector_or_scalar_or_ptr< - T, typename std::enable_if_t>> { - using type = - // select_cl_scalar_t may return _Float16, so, we try to instantiate vec - // class with _Float16 DataType, which is not expected there - // So, leave vector as-is - vec>, - mptr_or_vec_elem_type_t, - select_cl_scalar_t>>, - T::size()>; -}; - -template -struct select_cl_vector_or_scalar_or_ptr< - T, typename std::enable_if_t && !std::is_pointer_v>> { - using type = select_cl_scalar_t; -}; - -template -struct select_cl_vector_or_scalar_or_ptr< - T, typename std::enable_if_t && std::is_pointer_v>> { - using elem_ptr_type = - typename select_cl_vector_or_scalar_or_ptr>::type - *; -#ifdef __SYCL_DEVICE_ONLY__ - using type = typename DecoratedType::value>::type; -#else - using type = elem_ptr_type; -#endif -}; - -// select_cl_mptr_or_vector_or_scalar_or_ptr does cl_* type selection for type -// pointed by multi_ptr, for raw pointers, for element type of a vector type T, -// and does scalar type substitution. If T is not mutlti_ptr or vector or -// scalar or pointer unmodified T is returned. -template -struct select_cl_mptr_or_vector_or_scalar_or_ptr; - -// this struct helps to use std::uint8_t instead of std::byte, -// which is not supported on device -template struct TypeHelper { - using RetType = T; -}; - -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) -template <> struct TypeHelper { - using RetType = std::uint8_t; -}; -#endif - -template struct TypeHelper { - using RetType = const typename TypeHelper::RetType; -}; - -template struct TypeHelper { - using RetType = volatile typename TypeHelper::RetType; -}; - -template struct TypeHelper { - using RetType = const volatile typename TypeHelper::RetType; -}; - -template using type_helper = typename TypeHelper::RetType; - -template -struct select_cl_mptr_or_vector_or_scalar_or_ptr< - T, typename std::enable_if_t && !std::is_pointer_v>> { - using type = multi_ptr>>::type, - T::address_space, access::decorated::yes>; -}; - -template -struct select_cl_mptr_or_vector_or_scalar_or_ptr< - T, typename std::enable_if_t || std::is_pointer_v>> { - using type = typename select_cl_vector_or_scalar_or_ptr::type; -}; - -// All types converting shortcut. -template -using SelectMatchingOpenCLType_t = - typename select_cl_mptr_or_vector_or_scalar_or_ptr::type; - -// Converts T to OpenCL friendly -// -template struct ConvertToOpenCLTypeImpl { - using type = T; -}; - -#ifdef __SYCL_DEVICE_ONLY__ -template -struct ConvertToOpenCLTypeImpl> { - using type = typename vec::vector_t; -}; - -template struct ConvertToOpenCLTypeImpl> { - using type = typename Boolean::vector_t; -}; -template <> struct ConvertToOpenCLTypeImpl> { - // Or should it be "int"? - using type = Boolean<1>; -}; -#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) -// TODO: It seems we only use this to convert a pointer's element type. As such, -// although it doesn't look very clean, it should be ok having this case handled -// explicitly until further refactoring of this area. -template <> struct ConvertToOpenCLTypeImpl { - using type = uint8_t; -}; -#endif -#endif - -template struct ConvertToOpenCLTypeImpl { -#ifdef __SYCL_DEVICE_ONLY__ - using type = typename DecoratedType< - typename ConvertToOpenCLTypeImpl>::type, - deduce_AS::value>::type *; -#else - using type = typename ConvertToOpenCLTypeImpl::type *; -#endif -}; - -template -struct ConvertToOpenCLTypeImpl> { - using type = typename DecoratedType< - typename ConvertToOpenCLTypeImpl::type, Space>::type *; -}; - -template -using ConvertToOpenCLType_t = - typename ConvertToOpenCLTypeImpl>::type; - -// convertDataToType() function converts data from FROM type to TO type using -// 'as' method for vector type and copy otherwise. -template -typename std::enable_if_t -convertDataToType(FROM t) { - if constexpr (is_vgentype_v && is_vgentype_v) - return t.template as(); - else - return ConvertNonVectorType(t); -} - -// Now fuse the above into a simpler helper that's easy to use. -// TODO: That should probably be moved outside of "type_traits". template auto convertToOpenCLType(T &&x) { using no_ref = std::remove_reference_t; if constexpr (is_multi_ptr_v) { @@ -714,8 +472,8 @@ template auto convertToOpenCLType(T &&x) { } else if constexpr (std::is_pointer_v) { // TODO: Below ignores volatile, but we didn't have a need for it yet. using elem_type = remove_decoration_t>; - using converted_elem_type_no_cv = - ConvertToOpenCLType_t>; + using converted_elem_type_no_cv = decltype(convertToOpenCLType( + std::declval>())); using converted_elem_type = std::conditional_t, const converted_elem_type_no_cv, @@ -728,21 +486,85 @@ template auto convertToOpenCLType(T &&x) { using result_type = converted_elem_type *; #endif return reinterpret_cast(x); + } else if constexpr (is_vec_v) { + using ElemTy = typename no_ref::element_type; + // sycl::half may convert to _Float16, and we would try to instantiate + // vec class with _Float16 DataType, which is not expected there. As + // such, leave vector as-is. + using MatchingVec = vec, ElemTy, + decltype(convertToOpenCLType( + std::declval()))>, + no_ref::size()>; +#ifdef __SYCL_DEVICE_ONLY__ + // TODO: for some mysterious reasons on NonUniformGroups E2E tests fail if + // we use the "else" version only. I suspect that's an issues with + // non-uniform groups implementation. + if constexpr (std::is_same_v) + return static_cast(x); + else + return static_cast( + x.template as()); +#else + return x.template as(); +#endif + } else if constexpr (is_boolean_v) { +#ifdef __SYCL_DEVICE_ONLY__ + if constexpr (std::is_same_v, no_ref>) { + // Or should it be "int"? + return std::forward(x); + } else { + return static_cast(x); + } +#else + return std::forward(x); +#endif +#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) + } else if constexpr (std::is_same_v) { + return static_cast(x); +#endif + } else if constexpr (std::is_integral_v) { + using OpenCLType = + std::conditional_t, + select_cl_scalar_integral_signed_t, + select_cl_scalar_integral_unsigned_t>; + static_assert(sizeof(OpenCLType) == sizeof(T)); + return static_cast(x); + } else if constexpr (is_half_v) { + using OpenCLType = sycl::detail::half_impl::BIsRepresentationT; + static_assert(sizeof(OpenCLType) == sizeof(T)); + return static_cast(x); + } else if constexpr (is_bfloat16_v) { + return std::forward(x); + } else if constexpr (std::is_floating_point_v) { + static_assert(std::is_same_v || + std::is_same_v, + "Other FP types are not expected/supported (yet?)"); + static_assert(std::is_same_v && + std::is_same_v); + return std::forward(x); } else { - using OpenCLType = ConvertToOpenCLType_t; - return convertDataToType(std::forward(x)); + using OpenCLType = select_cl_scalar_complex_or_T_t; + static_assert(sizeof(OpenCLType) == sizeof(T)); + return static_cast(x); } } +template +using ConvertToOpenCLType_t = decltype(convertToOpenCLType(std::declval())); + template auto convertFromOpenCLTypeFor(From &&x) { if constexpr (std::is_same_v && std::is_same_v, bool>) { // FIXME: Something seems to be wrong elsewhere... return x; } else { - static_assert(std::is_same_v, - ConvertToOpenCLType_t>); - return convertDataToType(std::forward(x)); + using OpenCLType = decltype(convertToOpenCLType(std::declval())); + static_assert(std::is_same_v, OpenCLType>); + static_assert(sizeof(OpenCLType) == sizeof(To)); + if constexpr (is_vec_v && is_vec_v) + return x.template as(); + else + return static_cast(x); } } diff --git a/sycl/include/sycl/detail/type_traits.hpp b/sycl/include/sycl/detail/type_traits.hpp index 5d698f206f56c..df4ab8f37e17a 100644 --- a/sycl/include/sycl/detail/type_traits.hpp +++ b/sycl/include/sycl/detail/type_traits.hpp @@ -342,6 +342,12 @@ template struct is_bool : std::bool_constant>::value> {}; +// is_boolean +template struct Boolean; +template struct is_boolean : std::false_type {}; +template struct is_boolean> : std::true_type {}; +template inline constexpr bool is_boolean_v = is_boolean::value; + // is_pointer template struct is_pointer_impl : std::false_type {}; diff --git a/sycl/source/detail/builtins_relational.cpp b/sycl/source/detail/builtins_relational.cpp index 92fc579c3c0c6..407aad6c8a77f 100644 --- a/sycl/source/detail/builtins_relational.cpp +++ b/sycl/source/detail/builtins_relational.cpp @@ -147,7 +147,7 @@ template inline T2 __vselect(T2 a, T2 b, T c) { // ---------- 4.13.7 Relational functions. Host implementations. --------------- -using rel_res_t = d::select_cl_scalar_t; +using rel_res_t = d::ConvertToOpenCLType_t; // FOrdEqual-isequal __SYCL_EXPORT rel_res_t sycl_host_FOrdEqual(s::cl_float x, diff --git a/sycl/test/basic_tests/generic_type_traits.cpp b/sycl/test/basic_tests/generic_type_traits.cpp index 7805adec45789..10c6c803da6a2 100644 --- a/sycl/test/basic_tests/generic_type_traits.cpp +++ b/sycl/test/basic_tests/generic_type_traits.cpp @@ -164,79 +164,88 @@ int main() { */ // checks for some type conversions. - static_assert(std::is_same, - s::opencl::cl_int>::value); + static_assert(std::is_same_v, + s::opencl::cl_int>); +#ifdef __SYCL_DEVICE_ONLY__ static_assert( - std::is_same>, - s::vec>::value); + std::is_same_v>, + s::opencl::cl_int __attribute__((ext_vector_type(2)))>); - static_assert(std::is_same< - d::SelectMatchingOpenCLType_t>, - s::multi_ptr>::value); + __attribute__((opencl_global)) s::opencl::cl_int *>); + using int_vec2 = s::opencl::cl_int __attribute__((ext_vector_type(2))); static_assert( - std::is_same, - s::access::address_space::global_space, - s::access::decorated::yes>>, - s::multi_ptr, - s::access::address_space::global_space, - s::access::decorated::yes>>::value); + std::is_same_v, + s::access::address_space::global_space, + s::access::decorated::yes>>, + __attribute__((opencl_global)) int_vec2 *>); - static_assert(std::is_same, - s::opencl::cl_long>::value); + static_assert( + std::is_same_v, s::opencl::cl_long>); static_assert( - std::is_same>, - s::vec>::value); + std::is_same_v>, + s::opencl::cl_long __attribute__((ext_vector_type(2)))>); static_assert( - std::is_same>, - s::multi_ptr>::value); + std::is_same_v>, + __attribute__((opencl_global)) s::opencl::cl_long *>); + using long_vec2 = s::opencl::cl_long __attribute__((ext_vector_type(2))); static_assert( - std::is_same< - d::SelectMatchingOpenCLType_t, s::access::address_space::global_space, s::access::decorated::yes>>, - s::multi_ptr, - s::access::address_space::global_space, - s::access::decorated::yes>>::value); + __attribute__((opencl_global)) long_vec2 *>); + + using signed_char2 = s::opencl::cl_char __attribute__((ext_vector_type(2))); + static_assert(std::is_same_v< + d::ConvertToOpenCLType_t, s::access::address_space::global_space, + s::access::decorated::yes>>, + __attribute__((opencl_global)) signed_char2 *>); + static_assert( + std::is_same_v< + d::ConvertToOpenCLType_t, s::access::address_space::global_space, + s::access::decorated::yes>>, + __attribute__((opencl_global)) signed_char2 *>); + +#endif #ifdef __SYCL_DEVICE_ONLY__ static_assert( - std::is_same>, - s::vec::vector_t>::value); - static_assert(std::is_same>, - s::vec::vector_t>::value); - static_assert(std::is_same< + std::is_same_v>, + s::vec::vector_t>); + static_assert(std::is_same_v>, + s::vec::vector_t>); + static_assert(std::is_same_v< d::ConvertToOpenCLType_t>, s::multi_ptr::pointer>::value); + s::access::decorated::yes>::pointer>); static_assert( - std::is_same, - s::access::address_space::global_space, - s::access::decorated::yes>>, - s::multi_ptr::vector_t, - s::access::address_space::global_space, - s::access::decorated::yes>::pointer>::value); + std::is_same_v, + s::access::address_space::global_space, + s::access::decorated::yes>>, + s::multi_ptr::vector_t, + s::access::address_space::global_space, + s::access::decorated::yes>::pointer>); #endif - static_assert(std::is_same, - d::half_impl::BIsRepresentationT>::value, - ""); + static_assert(std::is_same_v, + d::half_impl::BIsRepresentationT>); s::multi_ptr