Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Make mutable swizzle functions unusable on const vec #13026

116 changes: 83 additions & 33 deletions sycl/include/sycl/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,10 @@ class SwizzleOp {
DataT, std::common_type_t<OpLeftDataT, OpRightDataT>>;
static constexpr int getNumElements() { return sizeof...(Indexes); }

template <typename RelayVecT = VecT>
static constexpr bool VecIsMutable =
!std::is_const_v<RelayVecT> && std::is_same_v<RelayVecT, VecT>;

using rel_t = detail::rel_t<DataT>;
using vec_t = vec<DataT, sizeof...(Indexes)>;
using vec_rel_t = vec<rel_t, sizeof...(Indexes)>;
Expand All @@ -1760,13 +1764,20 @@ class SwizzleOp {
OperationCurrentT, Indexes...>,
OperationCurrentT_, Idx_...>;

template <int IdxNum>
static constexpr bool HasOneIndex =
1 == IdxNum && SwizzleOp::getNumElements() == IdxNum;

template <int IdxNum, typename T = void>
using EnableIfOneIndex = typename std::enable_if_t<
1 == IdxNum && SwizzleOp::getNumElements() == IdxNum, T>;
using EnableIfOneIndex = typename std::enable_if_t<HasOneIndex<IdxNum>, T>;

template <int IdxNum>
static constexpr bool HasMultipleIndices =
1 != IdxNum && SwizzleOp::getNumElements() == IdxNum;

template <int IdxNum, typename T = void>
using EnableIfMultipleIndexes = typename std::enable_if_t<
1 != IdxNum && SwizzleOp::getNumElements() == IdxNum, T>;
using EnableIfMultipleIndexes =
typename std::enable_if_t<HasMultipleIndices<IdxNum>, T>;

template <typename T>
using EnableIfScalarType = typename std::enable_if_t<
Expand Down Expand Up @@ -1849,6 +1860,27 @@ class SwizzleOp {
#ifdef __SYCL_OPASSIGN
#error "Undefine __SYCL_OPASSIGN macro."
#endif
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
#define __SYCL_OPASSIGN(OPASSIGN, OP) \
template <typename RelayVecT = VecT> \
friend const std::enable_if_t<VecIsMutable<RelayVecT>, SwizzleOp> \
&operator OPASSIGN(const SwizzleOp & Lhs, const DataT & Rhs) { \
Lhs.operatorHelper<OP>(vec_t(Rhs)); \
return Lhs; \
} \
template <typename RhsOperation, typename RelayVecT = VecT> \
friend const std::enable_if_t<VecIsMutable<RelayVecT>, SwizzleOp> \
&operator OPASSIGN(const SwizzleOp & Lhs, const RhsOperation & Rhs) { \
Lhs.operatorHelper<OP>(Rhs); \
return Lhs; \
} \
template <typename RelayVecT = VecT> \
friend const std::enable_if_t<VecIsMutable<RelayVecT>, SwizzleOp> \
&operator OPASSIGN(const SwizzleOp & Lhs, const vec_t & Rhs) { \
Lhs.operatorHelper<OP>(Rhs); \
return Lhs; \
}
#else // __INTEL_PREVIEW_BREAKING_CHANGES
#define __SYCL_OPASSIGN(OPASSIGN, OP) \
SwizzleOp &operator OPASSIGN(const DataT & Rhs) { \
operatorHelper<OP>(vec_t(Rhs)); \
Expand All @@ -1859,6 +1891,7 @@ class SwizzleOp {
operatorHelper<OP>(Rhs); \
return *this; \
}
#endif // __INTEL_PREVIEW_BREAKING_CHANGES

__SYCL_OPASSIGN(+=, std::plus)
__SYCL_OPASSIGN(-=, std::minus)
Expand All @@ -1875,6 +1908,22 @@ class SwizzleOp {
#ifdef __SYCL_UOP
#error "Undefine __SYCL_UOP macro"
#endif
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
#define __SYCL_UOP(UOP, OPASSIGN) \
template <typename RelayVecT = VecT> \
friend const std::enable_if_t<VecIsMutable<RelayVecT>, SwizzleOp> \
&operator UOP(const SwizzleOp & sv) { \
sv OPASSIGN static_cast<DataT>(1); \
return sv; \
} \
template <typename RelayVecT = VecT> \
friend std::enable_if_t<VecIsMutable<RelayVecT>, vec_t> operator UOP( \
const SwizzleOp & sv, int) { \
vec_t Ret = sv; \
sv OPASSIGN static_cast<DataT>(1); \
return Ret; \
}
#else // __INTEL_PREVIEW_BREAKING_CHANGES
#define __SYCL_UOP(UOP, OPASSIGN) \
SwizzleOp &operator UOP() { \
*this OPASSIGN static_cast<DataT>(1); \
Expand All @@ -1885,6 +1934,7 @@ class SwizzleOp {
*this OPASSIGN static_cast<DataT>(1); \
return Ret; \
}
#endif // __INTEL_PREVIEW_BREAKING_CHANGES

__SYCL_UOP(++, +=)
__SYCL_UOP(--, -=)
Expand Down Expand Up @@ -2010,35 +2060,33 @@ class SwizzleOp {
#undef __SYCL_RELLOGOP
#endif // defined(__INTEL_PREVIEW_BREAKING_CHANGES)

template <int IdxNum = getNumElements(),
typename = EnableIfMultipleIndexes<IdxNum>>
SwizzleOp &operator=(const vec<DataT, IdxNum> &Rhs) {
template <int IdxNum = getNumElements(), typename RelayVecT = VecT>
std::enable_if_t<HasMultipleIndices<IdxNum> && VecIsMutable<RelayVecT>,
SwizzleOp> &
operator=(const vec<DataT, IdxNum> &Rhs) {
std::array<int, IdxNum> Idxs{Indexes...};
for (size_t I = 0; I < Idxs.size(); ++I) {
m_Vector->setValue(Idxs[I], Rhs.getValue(I));
}
return *this;
}

template <int IdxNum = getNumElements(), typename = EnableIfOneIndex<IdxNum>>
SwizzleOp &operator=(const DataT &Rhs) {
std::array<int, IdxNum> Idxs{Indexes...};
m_Vector->setValue(Idxs[0], Rhs);
return *this;
}

template <int IdxNum = getNumElements(),
EnableIfMultipleIndexes<IdxNum, bool> = true>
SwizzleOp &operator=(const DataT &Rhs) {
std::array<int, IdxNum> Idxs{Indexes...};
for (auto Idx : Idxs) {
m_Vector->setValue(Idx, Rhs);
template <typename RelayVecT = VecT>
std::enable_if_t<VecIsMutable<RelayVecT>, SwizzleOp> &
operator=(const DataT &Rhs) {
std::array<int, getNumElements()> Idxs{Indexes...};
if constexpr (getNumElements() == 1) {
m_Vector->setValue(Idxs[0], Rhs);
} else {
for (auto Idx : Idxs)
m_Vector->setValue(Idx, Rhs);
}
return *this;
}

template <int IdxNum = getNumElements(), typename = EnableIfOneIndex<IdxNum>>
SwizzleOp &operator=(DataT &&Rhs) {
template <int IdxNum = getNumElements(), typename RelayVecT = VecT>
std::enable_if_t<HasOneIndex<IdxNum> && VecIsMutable<RelayVecT>, SwizzleOp> &
operator=(DataT &&Rhs) {
std::array<int, IdxNum> Idxs{Indexes...};
m_Vector->setValue(Idxs[0], Rhs);
return *this;
Expand Down Expand Up @@ -2186,10 +2234,10 @@ class SwizzleOp {
return NewLHOp<RhsOperation, LShift, Indexes...>(m_Vector, *this, Rhs);
}

template <
typename T1, typename T2, typename T3, template <typename> class T4,
int... T5,
typename = typename std::enable_if_t<sizeof...(T5) == getNumElements()>>
template <typename T1, typename T2, typename T3, template <typename> class T4,
int... T5,
typename = typename std::enable_if_t<
sizeof...(T5) == getNumElements() && VecIsMutable<>>>
SwizzleOp &operator=(const SwizzleOp<T1, T2, T3, T4, T5...> &Rhs) {
std::array<int, getNumElements()> Idxs{Indexes...};
for (size_t I = 0; I < Idxs.size(); ++I) {
Expand All @@ -2198,10 +2246,10 @@ class SwizzleOp {
return *this;
}

template <
typename T1, typename T2, typename T3, template <typename> class T4,
int... T5,
typename = typename std::enable_if_t<sizeof...(T5) == getNumElements()>>
template <typename T1, typename T2, typename T3, template <typename> class T4,
int... T5,
typename = typename std::enable_if_t<
sizeof...(T5) == getNumElements() && VecIsMutable<>>>
SwizzleOp &operator=(SwizzleOp<T1, T2, T3, T4, T5...> &&Rhs) {
std::array<int, getNumElements()> Idxs{Indexes...};
for (size_t I = 0; I < Idxs.size(); ++I) {
Expand Down Expand Up @@ -2343,8 +2391,10 @@ class SwizzleOp {

// Leave store() interface to automatic conversion to vec<>.
// Load to vec_t and then assign to swizzle.
template <access::address_space Space, access::decorated DecorateAddress>
void load(size_t offset, multi_ptr<DataT, Space, DecorateAddress> ptr) {
template <access::address_space Space, access::decorated DecorateAddress,
typename RelayVecT = VecT>
std::enable_if_t<VecIsMutable<RelayVecT>, void>
load(size_t offset, multi_ptr<DataT, Space, DecorateAddress> ptr) const {
vec_t Tmp;
Tmp.template load(offset, ptr);
*this = Tmp;
Expand Down Expand Up @@ -2416,7 +2466,7 @@ class SwizzleOp {
}

template <template <typename> class Operation, typename RhsOperation>
void operatorHelper(const RhsOperation &Rhs) {
void operatorHelper(const RhsOperation &Rhs) const {
Operation<vec_data_t<DataT>> Op;
std::array<int, getNumElements()> Idxs{Indexes...};
for (size_t I = 0; I < Idxs.size(); ++I) {
Expand Down
78 changes: 78 additions & 0 deletions sycl/test-e2e/Regression/swizzle_opassign.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// REQUIRES: aspect-usm_shared_allocations
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
// RUN: %if preview-breaking-changes-supported %{ %{build} -fpreview-breaking-changes -o %t2.out %}
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}

// Tests that the mutating operators (+=, -=, ..., ++, --) on swizzles compile
// and correctly mutate the elements in the corresponding vector.

#include <sycl/sycl.hpp>

constexpr size_t NumOps = 14;
constexpr std::string_view OpNames[NumOps] = {
"+=", "-=", "*=", "/=", "%=", "&=", "|=",
"^=", "<<=", ">>=", "prefix ++", "prefix --", "postfix ++", "prefix ++"};

int main() {
sycl::queue Q;
bool *Results = sycl::malloc_shared<bool>(NumOps, Q);
for (size_t I = 0; I < NumOps; ++I)
Results[I] = 0;

Q.single_task([=]() {
bool *ResultIt = Results;
#define TestCase(OP) \
{ \
sycl::vec<int, 4> VecVal{1, 2, 3, 4}; \
int ExpectedRes = VecVal[1] OP 2; \
*(ResultIt++) = (VecVal.swizzle<1>() OP## = 2)[0] == ExpectedRes && \
VecVal[1] == ExpectedRes; \
}
TestCase(+);
TestCase(-);
TestCase(*);
TestCase(/);
TestCase(%);
TestCase(&);
TestCase(|);
TestCase(^);
TestCase(<<);
TestCase(>>);
{
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
int ExpectedRes = VecVal[1] + 1;
*(ResultIt++) = (++VecVal.swizzle<1>())[0] == ExpectedRes &&
VecVal[1] == ExpectedRes;
}
{
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
int ExpectedRes = VecVal[1] - 1;
*(ResultIt++) = (--VecVal.swizzle<1>())[0] == ExpectedRes &&
VecVal[1] == ExpectedRes;
}
{
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
int ExpectedRes = VecVal[1] + 1;
*(ResultIt++) = (VecVal.swizzle<1>()++)[0] == (ExpectedRes - 1) &&
VecVal[1] == ExpectedRes;
}
{
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
int ExpectedRes = VecVal[1] - 1;
*(ResultIt++) = (VecVal.swizzle<1>()--)[0] == (ExpectedRes + 1) &&
VecVal[1] == ExpectedRes;
}
}).wait_and_throw();

int Failures = 0;
for (size_t I = 0; I < NumOps; ++I) {
if (!Results[I]) {
std::cout << "Failed for " << OpNames[I] << std::endl;
++Failures;
}
}

sycl::free(Results, Q);
return Failures;
}
97 changes: 97 additions & 0 deletions sycl/test/basic_tests/vectors/const_swizzle_negative.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// REQUIRES: preview-breaking-changes-supported
// RUN: %clangxx -fsycl-device-only -ferror-limit=0 -Xclang -fsycl-is-device -fsyntax-only -fpreview-breaking-changes -Xclang -verify -Xclang -verify-ignore-unexpected=note %s

#include <sycl/sycl.hpp>

int main() {
sycl::queue Q;
Q.single_task([]() {
const sycl::vec<int, 4> X{1};

// expected-error@+1 {{no viable overloaded '='}}
X.swizzle<0>() = 1;
// expected-error@+1 {{no viable overloaded '='}}
X.swizzle<0>() = sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '='}}
X.swizzle<0, 2>() = sycl::vec<int, 2>{1};

// expected-error@+1 {{no viable overloaded '+='}}
X.swizzle<0>() += 1;
// expected-error@+1 {{no viable overloaded '-='}}
X.swizzle<0>() -= 1;
// expected-error@+1 {{no viable overloaded '*='}}
X.swizzle<0>() *= 1;
// expected-error@+1 {{no viable overloaded '/='}}
X.swizzle<0>() /= 1;
// expected-error@+1 {{no viable overloaded '%='}}
X.swizzle<0>() %= 1;
// expected-error@+1 {{no viable overloaded '&='}}
X.swizzle<0>() &= 1;
// expected-error@+1 {{no viable overloaded '|='}}
X.swizzle<0>() |= 1;
// expected-error@+1 {{no viable overloaded '^='}}
X.swizzle<0>() ^= 1;
// expected-error@+1 {{no viable overloaded '>>='}}
X.swizzle<0>() >>= 1;
// expected-error@+1 {{no viable overloaded '<<='}}
X.swizzle<0>() <<= 1;

// expected-error@+1 {{no viable overloaded '+='}}
X.swizzle<0>() += sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '-='}}
X.swizzle<0>() -= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '*='}}
X.swizzle<0>() *= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '/='}}
X.swizzle<0>() /= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '%='}}
X.swizzle<0>() %= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '&='}}
X.swizzle<0>() &= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '|='}}
X.swizzle<0>() |= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '^='}}
X.swizzle<0>() ^= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '>>='}}
X.swizzle<0>() >>= sycl::vec<int, 1>{1};
// expected-error@+1 {{no viable overloaded '<<='}}
X.swizzle<0>() <<= sycl::vec<int, 1>{1};

// expected-error@+1 {{no viable overloaded '+='}}
X.swizzle<0>() += X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '-='}}
X.swizzle<0>() -= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '*='}}
X.swizzle<0>() *= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '/='}}
X.swizzle<0>() /= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '%='}}
X.swizzle<0>() %= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '&='}}
X.swizzle<0>() &= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '|='}}
X.swizzle<0>() |= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '^='}}
X.swizzle<0>() ^= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '>>='}}
X.swizzle<0>() >>= X.swizzle<1>();
// expected-error@+1 {{no viable overloaded '<<='}}
X.swizzle<0>() <<= X.swizzle<1>();

// expected-error@+1 {{cannot increment value of type}}
X.swizzle<0>()++;
// expected-error@+1 {{cannot increment value of type}}
++X.swizzle<0>();
// expected-error@+1 {{cannot decrement value of type}}
X.swizzle<0>()--;
// expected-error@+1 {{cannot decrement value of type}}
--X.swizzle<0>();

int I = 1;
// expected-error@+1 {{no matching member function for call to 'load'}}
X.load(0,
steffenlarsen marked this conversation as resolved.
Show resolved Hide resolved
sycl::address_space_cast<sycl::access::address_space::private_space,
sycl::access::decorated::no>(&I));
});
return 0;
}
Loading