diff --git a/sycl/include/sycl/vector_preview.hpp b/sycl/include/sycl/vector_preview.hpp index 8c1660ec1d338..5f1e40a9b09ed 100644 --- a/sycl/include/sycl/vector_preview.hpp +++ b/sycl/include/sycl/vector_preview.hpp @@ -26,10 +26,6 @@ #error "SYCL device compiler is built without ext_vector_type support" #endif -#if defined(__SYCL_DEVICE_ONLY__) -#define __SYCL_USE_EXT_VECTOR_TYPE__ -#endif - #include // for decorated, address_space #include // for half, cl_char, cl_int #include // for ArrayCreator, RepeatV... @@ -46,7 +42,7 @@ #include // bfloat16 #include // for array -#include // for assert +#include // for assert #include // for size_t, NULL, byte #include // for uint8_t, int16_t, int... #include // for divides, multiplies @@ -354,17 +350,27 @@ template using vec_data_t = typename detail::vec_helper::RetType; ///////////////////////// class sycl::vec ///////////////////////// -/// Provides a cross-patform vector class template that works efficiently on -/// SYCL devices as well as in host C++ code. -/// -/// \ingroup sycl_api +// Provides a cross-patform vector class template that works efficiently on +// SYCL devices as well as in host C++ code. template class vec { using DataT = Type; + static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements; + // This represent type of underlying value. There should be only one field // in the class, so vec should be equal to float16 in memory. using DataType = typename detail::VecStorage::DataType; +public: +#ifdef __SYCL_DEVICE_ONLY__ + // Type used for passing sycl::vec to SPIRV builtins. + // We can not use ext_vector_type(1) as it's not supported by SPIRV + // plugins (CTS fails). + using vector_t = + typename detail::VecStorage::VectorDataType; +#endif // __SYCL_DEVICE_ONLY__ + +private: static constexpr bool IsHostHalf = std::is_same_v && std::is_same_v class vec { static constexpr bool IsBfloat16 = std::is_same_v; - static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements; static constexpr size_t Sz = sizeof(DataT) * AdjustedNum; static constexpr bool IsSizeGreaterThanMaxAlign = (Sz > detail::MaxVecAlignment); @@ -446,6 +451,8 @@ template class vec { } template static constexpr auto FlattenVecArgHelper(const T &A) { + // static_cast required to avoid narrowing conversion warning + // when T = unsigned long int and DataT_ = int. return std::array{vec_data::get(static_cast(A))}; } template struct FlattenVecArg { @@ -541,6 +548,7 @@ template class vec { using EnableIfSuitableNumElements = typename std::enable_if_t::value>; + // Implementation detail for the next public ctor. template constexpr vec(const std::array, NumElements> &Arr, std::index_sequence) @@ -552,14 +560,13 @@ template class vec { })(Arr[Is])...} {} public: + // Aliases required by SPEC to make sycl::vec consistent + // with that of marray and buffer. using element_type = DataT; using value_type = DataT; using rel_t = detail::rel_t; -#ifdef __SYCL_DEVICE_ONLY__ - using vector_t = - typename detail::VecStorage::VectorDataType; -#endif // __SYCL_DEVICE_ONLY__ + /****************** Constructors **************/ vec() = default; constexpr vec(const vec &Rhs) = default; @@ -577,7 +584,7 @@ template class vec { return *this; } -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ template using EnableIfNotHostHalf = typename std::enable_if_t; @@ -591,7 +598,7 @@ template class vec { template using EnableIfNotUsingArrayOnDevice = typename std::enable_if_t; -#endif // __SYCL_USE_EXT_VECTOR_TYPE__ +#endif // __SYCL_DEVICE_ONLY__ template using EnableIfUsingArray = @@ -602,7 +609,7 @@ template class vec { typename std::enable_if_t; -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ template explicit constexpr vec(const EnableIfNotUsingArrayOnDevice &arg) @@ -635,12 +642,15 @@ template class vec { } return *this; } -#else // __SYCL_USE_EXT_VECTOR_TYPE__ +#else // __SYCL_DEVICE_ONLY__ explicit constexpr vec(const DataT &arg) : vec{detail::RepeatValue( static_cast>(arg)), std::make_index_sequence()} {} + // Template required to prevent ambiguous overload with the copy assignment + // when NumElements == 1. The template prevents implicit conversion from + // vec<_, 1> to DataT. template typename std::enable_if_t< std::is_fundamental_v> || @@ -652,9 +662,9 @@ template class vec { } return *this; } -#endif // __SYCL_USE_EXT_VECTOR_TYPE__ +#endif // __SYCL_DEVICE_ONLY__ -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ // Optimized naive constructors with NumElements of DataT values. // We don't expect compilers to optimize vararg recursive functions well. @@ -703,7 +713,7 @@ template class vec { vec_data::get(ArgA), vec_data::get(ArgB), vec_data::get(ArgC), vec_data::get(ArgD), vec_data::get(ArgE), vec_data::get(ArgF)} {} -#endif // __SYCL_USE_EXT_VECTOR_TYPE__ +#endif // __SYCL_DEVICE_ONLY__ // Constructor from values of base type or vec of base type. Checks that // base types are match and that the NumElements == sum of lengths of args. @@ -726,6 +736,11 @@ template class vec { } } + /* @SYCL2020 + * Available only when: compiled for the device. + * Converts this SYCL vec instance to the underlying backend-native vector + * type defined by vector_t. + */ operator vector_t() const { if constexpr (!IsUsingArrayOnDevice) { return m_Data; @@ -976,12 +991,17 @@ template class vec { store(Offset, MultiPtr); } +#ifdef __SYCL_DEVICE_ONLY__ + // Require only for std::bool. void ConvertToDataT() { for (size_t i = 0; i < NumElements; ++i) { DataT tmp = getValue(i); setValue(i, tmp); } } +#endif + + /******************* sycl::vec math operations ***********************/ #if defined(__SYCL_BINOP) || defined(BINOP_BASE) #error "Undefine __SYCL_BINOP and BINOP_BASE macro" @@ -1222,11 +1242,13 @@ template class vec { } return Ret; } else { +#ifdef __SYCL_DEVICE_ONLY__ vec Ret{(typename vec::DataType) ~Rhs.m_Data}; if constexpr (std::is_same_v) { Ret.ConvertToDataT(); } return Ret; +#endif // __SYCL_DEVICE_ONLY__ } } @@ -1293,11 +1315,13 @@ template class vec { I, vec_data::get(-vec_data::get(Lhs.getValue(I)))); return Ret; } else { +#ifdef __SYCL_DEVICE_ONLY__ Ret = vec{-Lhs.m_Data}; if constexpr (std::is_same_v) { Ret.ConvertToDataT(); } return Ret; +#endif // __SYCL_DEVICE_ONLY__ } } @@ -1311,7 +1335,7 @@ template class vec { private: // Generic method that execute "Operation" on underlying values. -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ template