Skip to content

Commit

Permalink
[NFC][SYCL] Minor refactoring in sycl::vec<> (#13949)
Browse files Browse the repository at this point in the history
Follow-up of #13947

Added comments + Rearranged code + Removed redundant MACRO
  • Loading branch information
uditagarwal97 committed Jun 10, 2024
1 parent b26b69a commit f88d72e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 38 deletions.
17 changes: 14 additions & 3 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using rel_t = typename std::conditional_t<
} else { \
Ret.m_Data = Lhs.m_Data BINOP Rhs.m_Data; \
if constexpr (std::is_same_v<DataT, bool> && CONVERT) { \
Ret.ConvertToDataT(); \
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret); \
} \
} \
return Ret; \
Expand Down Expand Up @@ -189,7 +189,7 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
} else {
Ret = vec_t{-Lhs.m_Data};
if constexpr (std::is_same_v<DataT, bool>) {
Ret.ConvertToDataT();
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
}
return Ret;
}
Expand Down Expand Up @@ -391,12 +391,23 @@ template <typename DataT, int NumElements> class vec_arith_common {
} else {
vec_t Ret{(typename vec_t::DataType) ~Rhs.m_Data};
if constexpr (std::is_same_v<DataT, bool>) {
Ret.ConvertToDataT();
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
}
return Ret;
}
}

#ifdef __SYCL_DEVICE_ONLY__
using vec_bool_t = vec<bool, NumElements>;
// Required only for std::bool.
static void ConvertToDataT(vec_bool_t &Ret) {
for (size_t I = 0; I < NumElements; ++I) {
DataT Tmp = detail::VecAccess<vec_bool_t>::getValue(Ret, I);
detail::VecAccess<vec_bool_t>::setValue(Ret, I, Tmp);
}
}
#endif

// friends
template <typename T1, int T2> friend class vec;
};
Expand Down
80 changes: 45 additions & 35 deletions sycl/include/sycl/vector_preview.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
#error "SYCL device compiler is built without ext_vector_type support"
#endif

#if defined(__SYCL_DEVICE_ONLY__)
#define __SYCL_USE_EXT_VECTOR_TYPE__
#endif

#include <sycl/access/access.hpp> // for decorated, address_space
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
#include <sycl/detail/common.hpp> // for ArrayCreator, RepeatV...
Expand All @@ -47,7 +43,7 @@
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <array> // for array
#include <assert.h> // for assert
#include <cassert> // for assert
#include <cstddef> // for size_t, NULL, byte
#include <cstdint> // for uint8_t, int16_t, int...
#include <functional> // for divides, multiplies
Expand Down Expand Up @@ -363,18 +359,30 @@ template <typename T>
using vec_data_t = typename detail::vec_helper<T>::RetType;

///////////////////////// class sycl::vec /////////////////////////
/// Provides a cross-patform vector class template that works efficiently on
/// SYCL devices as well as in host C++ code.
///
/// \ingroup sycl_api
// Provides a cross-platform vector class template that works efficiently on
// SYCL devices as well as in host C++ code.
template <typename Type, int NumElements>
class vec : public detail::vec_arith<Type, NumElements> {
using DataT = Type;

// https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
// It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;

// This represent type of underlying value. There should be only one field
// in the class, so vec<float, 16> should be equal to float16 in memory.
using DataType = typename detail::VecStorage<DataT, NumElements>::DataType;

public:
#ifdef __SYCL_DEVICE_ONLY__
// Type used for passing sycl::vec to SPIRV builtins.
// We can not use ext_vector_type(1) as it's not supported by SPIRV
// plugins (CTS fails).
using vector_t =
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
#endif // __SYCL_DEVICE_ONLY__

private:
static constexpr bool IsHostHalf =
std::is_same_v<DataT, sycl::detail::half_impl::half> &&
std::is_same_v<sycl::detail::half_impl::StorageT,
Expand All @@ -383,7 +391,6 @@ class vec : public detail::vec_arith<Type, NumElements> {
static constexpr bool IsBfloat16 =
std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;

static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
static constexpr size_t Sz = sizeof(DataT) * AdjustedNum;
static constexpr bool IsSizeGreaterThanMaxAlign =
(Sz > detail::MaxVecAlignment);
Expand Down Expand Up @@ -456,6 +463,8 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
template <typename DataT_, typename T>
static constexpr auto FlattenVecArgHelper(const T &A) {
// static_cast required to avoid narrowing conversion warning
// when T = unsigned long int and DataT_ = int.
return std::array<DataT_, 1>{vec_data<DataT_>::get(static_cast<DataT_>(A))};
}
template <typename DataT_, typename T> struct FlattenVecArg {
Expand Down Expand Up @@ -551,6 +560,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
using EnableIfSuitableNumElements =
typename std::enable_if_t<SizeChecker<0, NumElements, argTN...>::value>;

// Implementation detail for the next public ctor.
template <size_t... Is>
constexpr vec(const std::array<vec_data_t<DataT>, NumElements> &Arr,
std::index_sequence<Is...>)
Expand All @@ -562,14 +572,13 @@ class vec : public detail::vec_arith<Type, NumElements> {
})(Arr[Is])...} {}

public:
// Aliases required by SPEC to make sycl::vec consistent
// with that of marray and buffer.
using element_type = DataT;
using value_type = DataT;
using rel_t = detail::rel_t<DataT>;
#ifdef __SYCL_DEVICE_ONLY__
using vector_t =
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
#endif // __SYCL_DEVICE_ONLY__

/****************** Constructors **************/
vec() = default;

constexpr vec(const vec &Rhs) = default;
Expand All @@ -587,7 +596,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
return *this;
}

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <typename T = void>
using EnableIfNotHostHalf = typename std::enable_if_t<!IsHostHalf, T>;

Expand All @@ -601,7 +610,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
template <typename T = void>
using EnableIfNotUsingArrayOnDevice =
typename std::enable_if_t<!IsUsingArrayOnDevice, T>;
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

template <typename T = void>
using EnableIfUsingArray =
Expand All @@ -612,7 +621,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
typename std::enable_if_t<!IsUsingArrayOnDevice && !IsUsingArrayOnHost,
T>;

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__

template <typename Ty = DataT>
explicit constexpr vec(const EnableIfNotUsingArrayOnDevice<Ty> &arg)
Expand Down Expand Up @@ -645,12 +654,17 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return *this;
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
explicit constexpr vec(const DataT &arg)
: vec{detail::RepeatValue<NumElements>(
static_cast<vec_data_t<DataT>>(arg)),
std::make_index_sequence<NumElements>()} {}

/****************** Assignment Operators **************/

// Template required to prevent ambiguous overload with the copy assignment
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
template <typename Ty = DataT>
typename std::enable_if_t<
std::is_fundamental_v<vec_data_t<Ty>> ||
Expand All @@ -662,9 +676,9 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return *this;
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
// Optimized naive constructors with NumElements of DataT values.
// We don't expect compilers to optimize vararg recursive functions well.

Expand Down Expand Up @@ -713,7 +727,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
vec_data<Ty>::get(ArgA), vec_data<Ty>::get(ArgB),
vec_data<Ty>::get(ArgC), vec_data<Ty>::get(ArgD),
vec_data<Ty>::get(ArgE), vec_data<Ty>::get(ArgF)} {}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// Constructor from values of base type or vec of base type. Checks that
// base types are match and that the NumElements == sum of lengths of args.
Expand All @@ -736,6 +750,10 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
}

/* Available only when: compiled for the device.
* Converts this SYCL vec instance to the underlying backend-native vector
* type defined by vector_t.
*/
operator vector_t() const {
if constexpr (!IsUsingArrayOnDevice) {
return m_Data;
Expand Down Expand Up @@ -986,17 +1004,9 @@ class vec : public detail::vec_arith<Type, NumElements> {
store(Offset, MultiPtr);
}

void ConvertToDataT() {
for (size_t i = 0; i < NumElements; ++i) {
DataT tmp = getValue(i);
setValue(i, tmp);
}
}

private:
// Generic method that execute "Operation" on underlying values.

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <template <typename> class Operation,
typename Ty = vec<DataT, NumElements>>
vec<DataT, NumElements>
Expand All @@ -1018,7 +1028,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return Result;
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
template <template <typename> class Operation>
vec<DataT, NumElements>
operatorHelper(const vec<DataT, NumElements> &Rhs) const {
Expand All @@ -1029,12 +1039,12 @@ class vec : public detail::vec_arith<Type, NumElements> {
}
return Result;
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// setValue and getValue should be able to operate on different underlying
// types: enum cl_float#N , builtin vector float#N, builtin type float.
// These versions are for N > 1.
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <int Num = NumElements, typename Ty = int,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(EnableIfNotHostHalf<Ty> Index, const DataT &Value,
Expand All @@ -1059,7 +1069,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
constexpr DataT getValue(EnableIfHostHalf<Ty> Index, int) const {
return vec_data<DataT>::get(m_Data.s[Index]);
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
template <int Num = NumElements,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(int Index, const DataT &Value, int) {
Expand All @@ -1071,7 +1081,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
constexpr DataT getValue(int Index, int) const {
return vec_data<DataT>::get(m_Data[Index]);
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// N==1 versions, used by host and device. Shouldn't trailing type be int?
template <int Num = NumElements,
Expand Down

0 comments on commit f88d72e

Please sign in to comment.