diff --git a/sycl/include/sycl/vector_preview.hpp b/sycl/include/sycl/vector_preview.hpp index 2f26ab47cb079..1cdfd1ba1fad6 100644 --- a/sycl/include/sycl/vector_preview.hpp +++ b/sycl/include/sycl/vector_preview.hpp @@ -696,12 +696,191 @@ template struct LShift { } }; +// Some assignment operators on swizzle are dependent on how many indices are +// present. This base class represents the two variants: one index and multiple +// indices. +template +class SwizzleOpDependentAssignmentOperatorBase; + +// Specialization for a swizzles with a single index. +template class OperationCurrentT, int Index> +class SwizzleOpDependentAssignmentOperatorBase> { + using SwizzleOpParentT = SwizzleOp; + + using DataT = typename VecT::element_type; + +public: + const SwizzleOpDependentAssignmentOperatorBase &operator=(DataT &&Rhs) const { + const SwizzleOpParentT &Self = *static_cast(this); + (*Self.m_Vector)[Index] = Rhs; + return *this; + } +}; + +// Specialization for a swizzles with multiple indices index. +template class OperationCurrentT, int Index0, + int... IndexRest> +class SwizzleOpDependentAssignmentOperatorBase< + SwizzleOp> { + using SwizzleOpParentT = SwizzleOp; + + using DataT = typename VecT::element_type; + +public: + const SwizzleOpDependentAssignmentOperatorBase & + operator=(const vec &Rhs) const { + const SwizzleOpParentT &Self = *static_cast(this); + for (size_t I = 0; I < Self.Idxs.size(); ++I) { + (*Self.m_Vector)[Self.Idxs[I]] = Rhs[I]; + } + return *this; + } +}; + +// Base class for mutable swizzles to mask out mutating operators. +template +class SwizzleOpMutatingOperatorBase { +public: + SwizzleOpMutatingOperatorBase & + operator=(const SwizzleOpMutatingOperatorBase &) = delete; +}; + +// Specialization for mutable swizzles. +template class OperationCurrentT, int... Indexes> +class SwizzleOpMutatingOperatorBase< + SwizzleOp, + std::enable_if_t>> + : public SwizzleOpDependentAssignmentOperatorBase< + SwizzleOp> { + using SwizzleOpParentT = SwizzleOp; + + static constexpr int getNumElements() { return sizeof...(Indexes); } + static constexpr std::array Idxs{Indexes...}; + + using DataT = typename VecT::element_type; + using vec_t = vec; + +public: + // TODO: Check that Rhs arg is suitable. +#ifdef __SYCL_OPASSIGN +#error "Undefine __SYCL_OPASSIGN macro." +#endif +#define __SYCL_OPASSIGN(OPASSIGN, OP) \ + friend const SwizzleOpParentT &operator OPASSIGN( \ + const SwizzleOpParentT & Lhs, const DataT & Rhs) { \ + Lhs.template operatorHelper(vec_t(Rhs)); \ + return Lhs; \ + } \ + template \ + friend const SwizzleOpParentT &operator OPASSIGN( \ + const SwizzleOpParentT & Lhs, const RhsOperation & Rhs) { \ + Lhs.template operatorHelper(Rhs); \ + return Lhs; \ + } \ + friend const SwizzleOpParentT &operator OPASSIGN( \ + const SwizzleOpParentT & Lhs, const vec_t & Rhs) { \ + Lhs.template operatorHelper(Rhs); \ + return Lhs; \ + } + + __SYCL_OPASSIGN(+=, std::plus) + __SYCL_OPASSIGN(-=, std::minus) + __SYCL_OPASSIGN(*=, std::multiplies) + __SYCL_OPASSIGN(/=, std::divides) + __SYCL_OPASSIGN(%=, std::modulus) + __SYCL_OPASSIGN(&=, std::bit_and) + __SYCL_OPASSIGN(|=, std::bit_or) + __SYCL_OPASSIGN(^=, std::bit_xor) + __SYCL_OPASSIGN(>>=, RShift) + __SYCL_OPASSIGN(<<=, LShift) +#undef __SYCL_OPASSIGN + +#ifdef __SYCL_UOP +#error "Undefine __SYCL_UOP macro" +#endif +#define __SYCL_UOP(UOP, OPASSIGN) \ + friend const SwizzleOpParentT &operator UOP(const SwizzleOpParentT & sv) { \ + sv OPASSIGN static_cast(1); \ + return sv; \ + } \ + friend vec_t operator UOP(const SwizzleOpParentT &sv, int) { \ + vec_t Ret = sv; \ + sv OPASSIGN static_cast(1); \ + return Ret; \ + } + + __SYCL_UOP(++, +=) + __SYCL_UOP(--, -=) +#undef __SYCL_UOP + + using SwizzleOpDependentAssignmentOperatorBase::operator=; + + const SwizzleOpMutatingOperatorBase &operator=(const DataT &Rhs) const { + const SwizzleOpParentT &Self = *static_cast(this); + if constexpr (SwizzleOpParentT::getNumElements() == 1) { + (*Self.m_Vector)[Self.Idxs[0]] = Rhs; + } else { + for (auto Idx : Self.Idxs) + (*Self.m_Vector)[Idx] = Rhs; + } + return *this; + } + + template < + typename T1, typename T2, typename T3, template class T4, + int... T5, + typename = typename std::enable_if_t> + const SwizzleOpMutatingOperatorBase & + operator=(const SwizzleOp &Rhs) const { + const SwizzleOpParentT &Self = *static_cast(this); + for (size_t I = 0; I < Self.Idxs.size(); ++I) { + (*Self.m_Vector)[Self.Idxs[I]] = Rhs.getValue(I); + } + return *this; + } + + template < + typename T1, typename T2, typename T3, template class T4, + int... T5, + typename = typename std::enable_if_t> + const SwizzleOpMutatingOperatorBase & + operator=(SwizzleOp &&Rhs) const { + const SwizzleOpParentT &Self = *static_cast(this); + for (size_t I = 0; I < Self.Idxs.size(); ++I) { + (*Self.m_Vector)[Self.Idxs[I]] = Rhs.getValue(I); + } + return *this; + } + + // Leave store() interface to automatic conversion to vec<>. + // Load to vec_t and then assign to swizzle. + template + void load(size_t offset, multi_ptr ptr) const { + const SwizzleOpParentT &Self = *static_cast(this); + vec_t Tmp; + Tmp.load(offset, ptr); + Self = Tmp; + } +}; + ///////////////////////// class SwizzleOp ///////////////////////// // SwizzleOP represents expression templates that operate on vec. // Actual computation performed on conversion or assignment operators. template class OperationCurrentT, int... Indexes> -class SwizzleOp { +class SwizzleOp : public SwizzleOpMutatingOperatorBase< + SwizzleOp> { using DataT = typename VecT::element_type; // Certain operators return a vector with a different element type. Also, the // left and right operand types may differ. CommonDataT selects a result type @@ -736,6 +915,8 @@ class SwizzleOp { DataT, std::common_type_t>; static constexpr int getNumElements() { return sizeof...(Indexes); } + static constexpr std::array Idxs{Indexes...}; + using rel_t = detail::rel_t; using vec_t = vec; using vec_rel_t = vec; @@ -760,14 +941,20 @@ class SwizzleOp { SwizzleOp, OperationCurrentT_, Idx_...>; + template + static constexpr bool HasOneIndex = + 1 == IdxNum && SwizzleOp::getNumElements() == IdxNum; template - using EnableIfOneIndex = typename std::enable_if_t< - 1 == IdxNum && SwizzleOp::getNumElements() == IdxNum, T>; + using EnableIfOneIndex = typename std::enable_if_t, T>; + + template + static constexpr bool HasMultipleIndices = + 1 != IdxNum && SwizzleOp::getNumElements() == IdxNum; template - using EnableIfMultipleIndexes = typename std::enable_if_t< - 1 != IdxNum && SwizzleOp::getNumElements() == IdxNum, T>; + using EnableIfMultipleIndexes = + typename std::enable_if_t, T>; template using EnableIfScalarType = typename std::enable_if_t< @@ -797,14 +984,16 @@ class SwizzleOp { using vector_t = typename vec_t::vector_t; #endif // __SYCL_DEVICE_ONLY__ + using SwizzleOpMutatingOperatorBase< + SwizzleOp>::operator=; + const DataT &operator[](int i) const { - std::array Idxs{Indexes...}; return (*m_Vector)[Idxs[i]]; } template std::enable_if_t, DataT> &operator[](int i) { - std::array Idxs{Indexes...}; return (*m_Vector)[Idxs[i]]; } @@ -851,58 +1040,6 @@ class SwizzleOp { Rhs.m_Vector, GetScalarOp(Lhs), Rhs); } - // TODO: Check that Rhs arg is suitable. -#ifdef __SYCL_OPASSIGN -#error "Undefine __SYCL_OPASSIGN macro." -#endif -#define __SYCL_OPASSIGN(OPASSIGN, OP) \ - friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \ - const DataT & Rhs) { \ - Lhs.operatorHelper(vec_t(Rhs)); \ - return Lhs; \ - } \ - template \ - friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \ - const RhsOperation & Rhs) { \ - Lhs.operatorHelper(Rhs); \ - return Lhs; \ - } \ - friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \ - const vec_t & Rhs) { \ - Lhs.operatorHelper(Rhs); \ - return Lhs; \ - } - - __SYCL_OPASSIGN(+=, std::plus) - __SYCL_OPASSIGN(-=, std::minus) - __SYCL_OPASSIGN(*=, std::multiplies) - __SYCL_OPASSIGN(/=, std::divides) - __SYCL_OPASSIGN(%=, std::modulus) - __SYCL_OPASSIGN(&=, std::bit_and) - __SYCL_OPASSIGN(|=, std::bit_or) - __SYCL_OPASSIGN(^=, std::bit_xor) - __SYCL_OPASSIGN(>>=, RShift) - __SYCL_OPASSIGN(<<=, LShift) -#undef __SYCL_OPASSIGN - -#ifdef __SYCL_UOP -#error "Undefine __SYCL_UOP macro" -#endif -#define __SYCL_UOP(UOP, OPASSIGN) \ - friend const SwizzleOp &operator UOP(const SwizzleOp & sv) { \ - sv OPASSIGN static_cast(1); \ - return sv; \ - } \ - friend vec_t operator UOP(const SwizzleOp &sv, int) { \ - vec_t Ret = sv; \ - sv OPASSIGN static_cast(1); \ - return Ret; \ - } - - __SYCL_UOP(++, +=) - __SYCL_UOP(--, -=) -#undef __SYCL_UOP - template friend typename std::enable_if_t< std::is_same_v && !detail::is_vgenfloat_v, vec_t> @@ -1027,40 +1164,6 @@ class SwizzleOp { __SYCL_RELLOGOP(||, (!detail::is_byte_v && !detail::is_vgenfloat_v)) #undef __SYCL_RELLOGOP - template > - SwizzleOp &operator=(const vec &Rhs) { - std::array Idxs{Indexes...}; - for (size_t I = 0; I < Idxs.size(); ++I) { - (*m_Vector)[Idxs[I]] = Rhs[I]; - } - return *this; - } - - template > - SwizzleOp &operator=(const DataT &Rhs) { - std::array Idxs{Indexes...}; - (*m_Vector)[Idxs[0]] = Rhs; - return *this; - } - - template = true> - SwizzleOp &operator=(const DataT &Rhs) { - std::array Idxs{Indexes...}; - for (auto Idx : Idxs) { - (*m_Vector)[Idx] = Rhs; - } - return *this; - } - - template > - SwizzleOp &operator=(DataT &&Rhs) { - std::array Idxs{Indexes...}; - (*m_Vector)[Idxs[0]] = Rhs; - return *this; - } - template > NewLHOp, std::multiplies, Indexes...> operator*(const T &Rhs) const { @@ -1203,30 +1306,6 @@ class SwizzleOp { return NewLHOp(m_Vector, *this, Rhs); } - template < - typename T1, typename T2, typename T3, template class T4, - int... T5, - typename = typename std::enable_if_t> - SwizzleOp &operator=(const SwizzleOp &Rhs) { - std::array Idxs{Indexes...}; - for (size_t I = 0; I < Idxs.size(); ++I) { - (*m_Vector)[Idxs[I]] = Rhs.getValue(I); - } - return *this; - } - - template < - typename T1, typename T2, typename T3, template class T4, - int... T5, - typename = typename std::enable_if_t> - SwizzleOp &operator=(SwizzleOp &&Rhs) { - std::array Idxs{Indexes...}; - for (size_t I = 0; I < Idxs.size(); ++I) { - (*m_Vector)[Idxs[I]] = Rhs.getValue(I); - } - return *this; - } - template > NewRelOp, EqualTo, Indexes...> operator==(const T &Rhs) const { return NewRelOp, EqualTo, Indexes...>(NULL, *this, @@ -1358,20 +1437,10 @@ class SwizzleOp { #undef __SYCL_ACCESS_RETURN // End of hi/lo, even/odd, xyzw, and rgba swizzles. - // Leave store() interface to automatic conversion to vec<>. - // Load to vec_t and then assign to swizzle. - template - void load(size_t offset, multi_ptr ptr) { - vec_t Tmp; - Tmp.load(offset, ptr); - *this = Tmp; - } - template vec convert() const { // First materialize the swizzle to vec_t and then apply convert() to it. vec_t Tmp; - std::array Idxs{Indexes...}; for (size_t I = 0; I < Idxs.size(); ++I) { Tmp[I] = (*m_Vector)[Idxs[I]]; } @@ -1415,7 +1484,6 @@ class SwizzleOp { template CommonDataT getValue(EnableIfOneIndex Index) const { if (std::is_same_v, GetOp>) { - std::array Idxs{Indexes...}; return (*m_Vector)[Idxs[Index]]; } auto Op = OperationCurrentT(); @@ -1426,7 +1494,6 @@ class SwizzleOp { template DataT getValue(EnableIfMultipleIndexes Index) const { if (std::is_same_v, GetOp>) { - std::array Idxs{Indexes...}; return (*m_Vector)[Idxs[Index]]; } auto Op = OperationCurrentT(); @@ -1437,7 +1504,6 @@ class SwizzleOp { template