From f88d72e5d75c636b38dc153ad9883253a9fa7e70 Mon Sep 17 00:00:00 2001 From: Udit Agarwal Date: Mon, 10 Jun 2024 14:09:56 -0700 Subject: [PATCH] [NFC][SYCL] Minor refactoring in `sycl::vec<>` (#13949) Follow-up of https://github.com/intel/llvm/pull/13947 Added comments + Rearranged code + Removed redundant MACRO --- sycl/include/sycl/detail/vector_arith.hpp | 17 ++++- sycl/include/sycl/vector_preview.hpp | 80 +++++++++++++---------- 2 files changed, 59 insertions(+), 38 deletions(-) diff --git a/sycl/include/sycl/detail/vector_arith.hpp b/sycl/include/sycl/detail/vector_arith.hpp index fb92a77389d7c..5cc54d383016e 100644 --- a/sycl/include/sycl/detail/vector_arith.hpp +++ b/sycl/include/sycl/detail/vector_arith.hpp @@ -60,7 +60,7 @@ using rel_t = typename std::conditional_t< } else { \ Ret.m_Data = Lhs.m_Data BINOP Rhs.m_Data; \ if constexpr (std::is_same_v && CONVERT) { \ - Ret.ConvertToDataT(); \ + vec_arith_common::ConvertToDataT(Ret); \ } \ } \ return Ret; \ @@ -189,7 +189,7 @@ class vec_arith : public vec_arith_common { } else { Ret = vec_t{-Lhs.m_Data}; if constexpr (std::is_same_v) { - Ret.ConvertToDataT(); + vec_arith_common::ConvertToDataT(Ret); } return Ret; } @@ -391,12 +391,23 @@ template class vec_arith_common { } else { vec_t Ret{(typename vec_t::DataType) ~Rhs.m_Data}; if constexpr (std::is_same_v) { - Ret.ConvertToDataT(); + vec_arith_common::ConvertToDataT(Ret); } return Ret; } } +#ifdef __SYCL_DEVICE_ONLY__ + using vec_bool_t = vec; + // Required only for std::bool. + static void ConvertToDataT(vec_bool_t &Ret) { + for (size_t I = 0; I < NumElements; ++I) { + DataT Tmp = detail::VecAccess::getValue(Ret, I); + detail::VecAccess::setValue(Ret, I, Tmp); + } + } +#endif + // friends template friend class vec; }; diff --git a/sycl/include/sycl/vector_preview.hpp b/sycl/include/sycl/vector_preview.hpp index f70db78e7959a..3d52b297c0ef2 100644 --- a/sycl/include/sycl/vector_preview.hpp +++ b/sycl/include/sycl/vector_preview.hpp @@ -26,10 +26,6 @@ #error "SYCL device compiler is built without ext_vector_type support" #endif -#if defined(__SYCL_DEVICE_ONLY__) -#define __SYCL_USE_EXT_VECTOR_TYPE__ -#endif - #include // for decorated, address_space #include // for half, cl_char, cl_int #include // for ArrayCreator, RepeatV... @@ -47,7 +43,7 @@ #include // bfloat16 #include // for array -#include // for assert +#include // for assert #include // for size_t, NULL, byte #include // for uint8_t, int16_t, int... #include // for divides, multiplies @@ -363,18 +359,30 @@ template using vec_data_t = typename detail::vec_helper::RetType; ///////////////////////// class sycl::vec ///////////////////////// -/// Provides a cross-patform vector class template that works efficiently on -/// SYCL devices as well as in host C++ code. -/// -/// \ingroup sycl_api +// Provides a cross-platform vector class template that works efficiently on +// SYCL devices as well as in host C++ code. template class vec : public detail::vec_arith { using DataT = Type; + // https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment + // It is required by the SPEC to align vec with vec. + static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements; + // This represent type of underlying value. There should be only one field // in the class, so vec should be equal to float16 in memory. using DataType = typename detail::VecStorage::DataType; +public: +#ifdef __SYCL_DEVICE_ONLY__ + // Type used for passing sycl::vec to SPIRV builtins. + // We can not use ext_vector_type(1) as it's not supported by SPIRV + // plugins (CTS fails). + using vector_t = + typename detail::VecStorage::VectorDataType; +#endif // __SYCL_DEVICE_ONLY__ + +private: static constexpr bool IsHostHalf = std::is_same_v && std::is_same_v { static constexpr bool IsBfloat16 = std::is_same_v; - static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements; static constexpr size_t Sz = sizeof(DataT) * AdjustedNum; static constexpr bool IsSizeGreaterThanMaxAlign = (Sz > detail::MaxVecAlignment); @@ -456,6 +463,8 @@ class vec : public detail::vec_arith { } template static constexpr auto FlattenVecArgHelper(const T &A) { + // static_cast required to avoid narrowing conversion warning + // when T = unsigned long int and DataT_ = int. return std::array{vec_data::get(static_cast(A))}; } template struct FlattenVecArg { @@ -551,6 +560,7 @@ class vec : public detail::vec_arith { using EnableIfSuitableNumElements = typename std::enable_if_t::value>; + // Implementation detail for the next public ctor. template constexpr vec(const std::array, NumElements> &Arr, std::index_sequence) @@ -562,14 +572,13 @@ class vec : public detail::vec_arith { })(Arr[Is])...} {} public: + // Aliases required by SPEC to make sycl::vec consistent + // with that of marray and buffer. using element_type = DataT; using value_type = DataT; using rel_t = detail::rel_t; -#ifdef __SYCL_DEVICE_ONLY__ - using vector_t = - typename detail::VecStorage::VectorDataType; -#endif // __SYCL_DEVICE_ONLY__ + /****************** Constructors **************/ vec() = default; constexpr vec(const vec &Rhs) = default; @@ -587,7 +596,7 @@ class vec : public detail::vec_arith { return *this; } -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ template using EnableIfNotHostHalf = typename std::enable_if_t; @@ -601,7 +610,7 @@ class vec : public detail::vec_arith { template using EnableIfNotUsingArrayOnDevice = typename std::enable_if_t; -#endif // __SYCL_USE_EXT_VECTOR_TYPE__ +#endif // __SYCL_DEVICE_ONLY__ template using EnableIfUsingArray = @@ -612,7 +621,7 @@ class vec : public detail::vec_arith { typename std::enable_if_t; -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ template explicit constexpr vec(const EnableIfNotUsingArrayOnDevice &arg) @@ -645,12 +654,17 @@ class vec : public detail::vec_arith { } return *this; } -#else // __SYCL_USE_EXT_VECTOR_TYPE__ +#else // __SYCL_DEVICE_ONLY__ explicit constexpr vec(const DataT &arg) : vec{detail::RepeatValue( static_cast>(arg)), std::make_index_sequence()} {} + /****************** Assignment Operators **************/ + + // Template required to prevent ambiguous overload with the copy assignment + // when NumElements == 1. The template prevents implicit conversion from + // vec<_, 1> to DataT. template typename std::enable_if_t< std::is_fundamental_v> || @@ -662,9 +676,9 @@ class vec : public detail::vec_arith { } return *this; } -#endif // __SYCL_USE_EXT_VECTOR_TYPE__ +#endif // __SYCL_DEVICE_ONLY__ -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ // Optimized naive constructors with NumElements of DataT values. // We don't expect compilers to optimize vararg recursive functions well. @@ -713,7 +727,7 @@ class vec : public detail::vec_arith { vec_data::get(ArgA), vec_data::get(ArgB), vec_data::get(ArgC), vec_data::get(ArgD), vec_data::get(ArgE), vec_data::get(ArgF)} {} -#endif // __SYCL_USE_EXT_VECTOR_TYPE__ +#endif // __SYCL_DEVICE_ONLY__ // Constructor from values of base type or vec of base type. Checks that // base types are match and that the NumElements == sum of lengths of args. @@ -736,6 +750,10 @@ class vec : public detail::vec_arith { } } + /* Available only when: compiled for the device. + * Converts this SYCL vec instance to the underlying backend-native vector + * type defined by vector_t. + */ operator vector_t() const { if constexpr (!IsUsingArrayOnDevice) { return m_Data; @@ -986,17 +1004,9 @@ class vec : public detail::vec_arith { store(Offset, MultiPtr); } - void ConvertToDataT() { - for (size_t i = 0; i < NumElements; ++i) { - DataT tmp = getValue(i); - setValue(i, tmp); - } - } - private: // Generic method that execute "Operation" on underlying values. - -#ifdef __SYCL_USE_EXT_VECTOR_TYPE__ +#ifdef __SYCL_DEVICE_ONLY__ template