Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][SYCL] Minor refactoring in sycl::vec<> #13949

Merged
merged 16 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using rel_t = typename std::conditional_t<
} else { \
Ret.m_Data = Lhs.m_Data BINOP Rhs.m_Data; \
if constexpr (std::is_same_v<DataT, bool> && CONVERT) { \
Ret.ConvertToDataT(); \
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret); \
} \
} \
return Ret; \
Expand Down Expand Up @@ -189,7 +189,7 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
} else {
Ret = vec_t{-Lhs.m_Data};
if constexpr (std::is_same_v<DataT, bool>) {
Ret.ConvertToDataT();
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
}
return Ret;
}
Expand Down Expand Up @@ -391,12 +391,23 @@ template <typename DataT, int NumElements> class vec_arith_common {
} else {
vec_t Ret{(typename vec_t::DataType) ~Rhs.m_Data};
if constexpr (std::is_same_v<DataT, bool>) {
Ret.ConvertToDataT();
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
}
return Ret;
}
}

#ifdef __SYCL_DEVICE_ONLY__
using vec_bool_t = vec<bool, NumElements>;
// Require only for std::bool.
uditagarwal97 marked this conversation as resolved.
Show resolved Hide resolved
static void ConvertToDataT(vec_bool_t &Ret) {
for (size_t i = 0; i < NumElements; ++i) {
uditagarwal97 marked this conversation as resolved.
Show resolved Hide resolved
DataT tmp = detail::VecAccess<vec_bool_t>::getValue(Ret, i);
uditagarwal97 marked this conversation as resolved.
Show resolved Hide resolved
detail::VecAccess<vec_bool_t>::setValue(Ret, i, tmp);
uditagarwal97 marked this conversation as resolved.
Show resolved Hide resolved
}
}
#endif

// friends
template <typename T1, int T2> friend class vec;
};
Expand Down
78 changes: 43 additions & 35 deletions sycl/include/sycl/vector_preview.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
#error "SYCL device compiler is built without ext_vector_type support"
#endif

#if defined(__SYCL_DEVICE_ONLY__)
#define __SYCL_USE_EXT_VECTOR_TYPE__
#endif

#include <sycl/access/access.hpp> // for decorated, address_space
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
#include <sycl/detail/common.hpp> // for ArrayCreator, RepeatV...
Expand All @@ -47,7 +43,7 @@
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <array> // for array
#include <assert.h> // for assert
#include <cassert> // for assert
#include <cstddef> // for size_t, NULL, byte
#include <cstdint> // for uint8_t, int16_t, int...
#include <functional> // for divides, multiplies
Expand Down Expand Up @@ -363,18 +359,28 @@ template <typename T>
using vec_data_t = typename detail::vec_helper<T>::RetType;

///////////////////////// class sycl::vec /////////////////////////
/// Provides a cross-patform vector class template that works efficiently on
/// SYCL devices as well as in host C++ code.
///
/// \ingroup sycl_api
// Provides a cross-patform vector class template that works efficiently on
uditagarwal97 marked this conversation as resolved.
Show resolved Hide resolved
// SYCL devices as well as in host C++ code.
template <typename Type, int NumElements>
class vec : public detail::vec_arith<Type, NumElements> {
using DataT = Type;

static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
uditagarwal97 marked this conversation as resolved.
Show resolved Hide resolved

// This represent type of underlying value. There should be only one field
// in the class, so vec<float, 16> should be equal to float16 in memory.
using DataType = typename detail::VecStorage<DataT, NumElements>::DataType;

public:
#ifdef __SYCL_DEVICE_ONLY__
// Type used for passing sycl::vec to SPIRV builtins.
// We can not use ext_vector_type(1) as it's not supported by SPIRV
// plugins (CTS fails).
using vector_t =
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
#endif // __SYCL_DEVICE_ONLY__

private:
static constexpr bool IsHostHalf =
std::is_same_v<DataT, sycl::detail::half_impl::half> &&
std::is_same_v<sycl::detail::half_impl::StorageT,
Expand All @@ -383,7 +389,6 @@ class vec : public detail::vec_arith<Type, NumElements> {
static constexpr bool IsBfloat16 =
std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;

static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
static constexpr size_t Sz = sizeof(DataT) * AdjustedNum;
static constexpr bool IsSizeGreaterThanMaxAlign =
(Sz > detail::MaxVecAlignment);
Expand Down Expand Up @@ -456,6 +461,8 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
template <typename DataT_, typename T>
static constexpr auto FlattenVecArgHelper(const T &A) {
// static_cast required to avoid narrowing conversion warning
// when T = unsigned long int and DataT_ = int.
return std::array<DataT_, 1>{vec_data<DataT_>::get(static_cast<DataT_>(A))};
}
template <typename DataT_, typename T> struct FlattenVecArg {
Expand Down Expand Up @@ -551,6 +558,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
using EnableIfSuitableNumElements =
typename std::enable_if_t<SizeChecker<0, NumElements, argTN...>::value>;

// Implementation detail for the next public ctor.
template <size_t... Is>
constexpr vec(const std::array<vec_data_t<DataT>, NumElements> &Arr,
std::index_sequence<Is...>)
Expand All @@ -562,14 +570,13 @@ class vec : public detail::vec_arith<Type, NumElements> {
})(Arr[Is])...} {}

public:
// Aliases required by SPEC to make sycl::vec consistent
// with that of marray and buffer.
using element_type = DataT;
using value_type = DataT;
using rel_t = detail::rel_t<DataT>;
#ifdef __SYCL_DEVICE_ONLY__
using vector_t =
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
#endif // __SYCL_DEVICE_ONLY__

/****************** Constructors **************/
vec() = default;

constexpr vec(const vec &Rhs) = default;
Expand All @@ -587,7 +594,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
return *this;
}

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <typename T = void>
using EnableIfNotHostHalf = typename std::enable_if_t<!IsHostHalf, T>;

Expand All @@ -601,7 +608,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
template <typename T = void>
using EnableIfNotUsingArrayOnDevice =
typename std::enable_if_t<!IsUsingArrayOnDevice, T>;
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

template <typename T = void>
using EnableIfUsingArray =
Expand All @@ -612,7 +619,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
typename std::enable_if_t<!IsUsingArrayOnDevice && !IsUsingArrayOnHost,
T>;

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__

template <typename Ty = DataT>
explicit constexpr vec(const EnableIfNotUsingArrayOnDevice<Ty> &arg)
Expand Down Expand Up @@ -645,12 +652,17 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return *this;
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
explicit constexpr vec(const DataT &arg)
: vec{detail::RepeatValue<NumElements>(
static_cast<vec_data_t<DataT>>(arg)),
std::make_index_sequence<NumElements>()} {}

/****************** Assignment Operators **************/

// Template required to prevent ambiguous overload with the copy assignment
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
template <typename Ty = DataT>
typename std::enable_if_t<
std::is_fundamental_v<vec_data_t<Ty>> ||
Expand All @@ -662,9 +674,9 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return *this;
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
// Optimized naive constructors with NumElements of DataT values.
// We don't expect compilers to optimize vararg recursive functions well.

Expand Down Expand Up @@ -713,7 +725,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
vec_data<Ty>::get(ArgA), vec_data<Ty>::get(ArgB),
vec_data<Ty>::get(ArgC), vec_data<Ty>::get(ArgD),
vec_data<Ty>::get(ArgE), vec_data<Ty>::get(ArgF)} {}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// Constructor from values of base type or vec of base type. Checks that
// base types are match and that the NumElements == sum of lengths of args.
Expand All @@ -736,6 +748,10 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
}

/* Available only when: compiled for the device.
* Converts this SYCL vec instance to the underlying backend-native vector
* type defined by vector_t.
*/
operator vector_t() const {
if constexpr (!IsUsingArrayOnDevice) {
return m_Data;
Expand Down Expand Up @@ -986,17 +1002,9 @@ class vec : public detail::vec_arith<Type, NumElements> {
store(Offset, MultiPtr);
}

void ConvertToDataT() {
for (size_t i = 0; i < NumElements; ++i) {
DataT tmp = getValue(i);
setValue(i, tmp);
}
}

private:
// Generic method that execute "Operation" on underlying values.

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <template <typename> class Operation,
typename Ty = vec<DataT, NumElements>>
vec<DataT, NumElements>
Expand All @@ -1018,7 +1026,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return Result;
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
template <template <typename> class Operation>
vec<DataT, NumElements>
operatorHelper(const vec<DataT, NumElements> &Rhs) const {
Expand All @@ -1029,12 +1037,12 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return Result;
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// setValue and getValue should be able to operate on different underlying
// types: enum cl_float#N , builtin vector float#N, builtin type float.
// These versions are for N > 1.
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <int Num = NumElements, typename Ty = int,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(EnableIfNotHostHalf<Ty> Index, const DataT &Value,
Expand All @@ -1059,7 +1067,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
constexpr DataT getValue(EnableIfHostHalf<Ty> Index, int) const {
return vec_data<DataT>::get(m_Data.s[Index]);
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
template <int Num = NumElements,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(int Index, const DataT &Value, int) {
Expand All @@ -1071,7 +1079,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
constexpr DataT getValue(int Index, int) const {
return vec_data<DataT>::get(m_Data[Index]);
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// N==1 versions, used by host and device. Shouldn't trailing type be int?
template <int Num = NumElements,
Expand Down
Loading