Skip to content

Commit

Permalink
Add comments; Rearrange code; Remove redundant MACRO
Browse files Browse the repository at this point in the history
  • Loading branch information
uditagarwal97 committed May 29, 2024
1 parent 5341760 commit 051a557
Showing 1 changed file with 51 additions and 27 deletions.
78 changes: 51 additions & 27 deletions sycl/include/sycl/vector_preview.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
#error "SYCL device compiler is built without ext_vector_type support"
#endif

#if defined(__SYCL_DEVICE_ONLY__)
#define __SYCL_USE_EXT_VECTOR_TYPE__
#endif

#include <sycl/access/access.hpp> // for decorated, address_space
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
#include <sycl/detail/common.hpp> // for ArrayCreator, RepeatV...
Expand All @@ -46,7 +42,7 @@
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <array> // for array
#include <assert.h> // for assert
#include <cassert> // for assert
#include <cstddef> // for size_t, NULL, byte
#include <cstdint> // for uint8_t, int16_t, int...
#include <functional> // for divides, multiplies
Expand Down Expand Up @@ -354,17 +350,27 @@ template <typename T>
using vec_data_t = typename detail::vec_helper<T>::RetType;

///////////////////////// class sycl::vec /////////////////////////
/// Provides a cross-patform vector class template that works efficiently on
/// SYCL devices as well as in host C++ code.
///
/// \ingroup sycl_api
// Provides a cross-patform vector class template that works efficiently on
// SYCL devices as well as in host C++ code.
template <typename Type, int NumElements> class vec {
using DataT = Type;

static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;

// This represent type of underlying value. There should be only one field
// in the class, so vec<float, 16> should be equal to float16 in memory.
using DataType = typename detail::VecStorage<DataT, NumElements>::DataType;

public:
#ifdef __SYCL_DEVICE_ONLY__
// Type used for passing sycl::vec to SPIRV builtins.
// We can not use ext_vector_type(1) as it's not supported by SPIRV
// plugins (CTS fails).
using vector_t =
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
#endif // __SYCL_DEVICE_ONLY__

private:
static constexpr bool IsHostHalf =
std::is_same_v<DataT, sycl::detail::half_impl::half> &&
std::is_same_v<sycl::detail::half_impl::StorageT,
Expand All @@ -373,7 +379,6 @@ template <typename Type, int NumElements> class vec {
static constexpr bool IsBfloat16 =
std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;

static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
static constexpr size_t Sz = sizeof(DataT) * AdjustedNum;
static constexpr bool IsSizeGreaterThanMaxAlign =
(Sz > detail::MaxVecAlignment);
Expand Down Expand Up @@ -446,6 +451,8 @@ template <typename Type, int NumElements> class vec {
}
template <typename DataT_, typename T>
static constexpr auto FlattenVecArgHelper(const T &A) {
// static_cast required to avoid narrowing conversion warning
// when T = unsigned long int and DataT_ = int.
return std::array<DataT_, 1>{vec_data<DataT_>::get(static_cast<DataT_>(A))};
}
template <typename DataT_, typename T> struct FlattenVecArg {
Expand Down Expand Up @@ -541,6 +548,7 @@ template <typename Type, int NumElements> class vec {
using EnableIfSuitableNumElements =
typename std::enable_if_t<SizeChecker<0, NumElements, argTN...>::value>;

// Implementation detail for the next public ctor.
template <size_t... Is>
constexpr vec(const std::array<vec_data_t<DataT>, NumElements> &Arr,
std::index_sequence<Is...>)
Expand All @@ -552,14 +560,13 @@ template <typename Type, int NumElements> class vec {
})(Arr[Is])...} {}

public:
// Aliases required by SPEC to make sycl::vec consistent
// with that of marray and buffer.
using element_type = DataT;
using value_type = DataT;
using rel_t = detail::rel_t<DataT>;
#ifdef __SYCL_DEVICE_ONLY__
using vector_t =
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
#endif // __SYCL_DEVICE_ONLY__

/****************** Constructors **************/
vec() = default;

constexpr vec(const vec &Rhs) = default;
Expand All @@ -577,7 +584,7 @@ template <typename Type, int NumElements> class vec {
return *this;
}

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <typename T = void>
using EnableIfNotHostHalf = typename std::enable_if_t<!IsHostHalf, T>;

Expand All @@ -591,7 +598,7 @@ template <typename Type, int NumElements> class vec {
template <typename T = void>
using EnableIfNotUsingArrayOnDevice =
typename std::enable_if_t<!IsUsingArrayOnDevice, T>;
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

template <typename T = void>
using EnableIfUsingArray =
Expand All @@ -602,7 +609,7 @@ template <typename Type, int NumElements> class vec {
typename std::enable_if_t<!IsUsingArrayOnDevice && !IsUsingArrayOnHost,
T>;

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__

template <typename Ty = DataT>
explicit constexpr vec(const EnableIfNotUsingArrayOnDevice<Ty> &arg)
Expand Down Expand Up @@ -635,12 +642,15 @@ template <typename Type, int NumElements> class vec {
}
return *this;
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
explicit constexpr vec(const DataT &arg)
: vec{detail::RepeatValue<NumElements>(
static_cast<vec_data_t<DataT>>(arg)),
std::make_index_sequence<NumElements>()} {}

// Template required to prevent ambiguous overload with the copy assignment
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
template <typename Ty = DataT>
typename std::enable_if_t<
std::is_fundamental_v<vec_data_t<Ty>> ||
Expand All @@ -652,9 +662,9 @@ template <typename Type, int NumElements> class vec {
}
return *this;
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
// Optimized naive constructors with NumElements of DataT values.
// We don't expect compilers to optimize vararg recursive functions well.

Expand Down Expand Up @@ -703,7 +713,7 @@ template <typename Type, int NumElements> class vec {
vec_data<Ty>::get(ArgA), vec_data<Ty>::get(ArgB),
vec_data<Ty>::get(ArgC), vec_data<Ty>::get(ArgD),
vec_data<Ty>::get(ArgE), vec_data<Ty>::get(ArgF)} {}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// Constructor from values of base type or vec of base type. Checks that
// base types are match and that the NumElements == sum of lengths of args.
Expand All @@ -726,6 +736,11 @@ template <typename Type, int NumElements> class vec {
}
}

/* @SYCL2020
* Available only when: compiled for the device.
* Converts this SYCL vec instance to the underlying backend-native vector
* type defined by vector_t.
*/
operator vector_t() const {
if constexpr (!IsUsingArrayOnDevice) {
return m_Data;
Expand Down Expand Up @@ -976,12 +991,17 @@ template <typename Type, int NumElements> class vec {
store(Offset, MultiPtr);
}

#ifdef __SYCL_DEVICE_ONLY__
// Require only for std::bool.
void ConvertToDataT() {
for (size_t i = 0; i < NumElements; ++i) {
DataT tmp = getValue(i);
setValue(i, tmp);
}
}
#endif

/******************* sycl::vec math operations ***********************/

#if defined(__SYCL_BINOP) || defined(BINOP_BASE)
#error "Undefine __SYCL_BINOP and BINOP_BASE macro"
Expand Down Expand Up @@ -1222,11 +1242,13 @@ template <typename Type, int NumElements> class vec {
}
return Ret;
} else {
#ifdef __SYCL_DEVICE_ONLY__
vec Ret{(typename vec::DataType) ~Rhs.m_Data};
if constexpr (std::is_same_v<Type, bool>) {
Ret.ConvertToDataT();
}
return Ret;
#endif // __SYCL_DEVICE_ONLY__
}
}

Expand Down Expand Up @@ -1293,11 +1315,13 @@ template <typename Type, int NumElements> class vec {
I, vec_data<DataT>::get(-vec_data<DataT>::get(Lhs.getValue(I))));
return Ret;
} else {
#ifdef __SYCL_DEVICE_ONLY__
Ret = vec{-Lhs.m_Data};
if constexpr (std::is_same_v<Type, bool>) {
Ret.ConvertToDataT();
}
return Ret;
#endif // __SYCL_DEVICE_ONLY__
}
}

Expand All @@ -1311,7 +1335,7 @@ template <typename Type, int NumElements> class vec {
private:
// Generic method that execute "Operation" on underlying values.

#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <template <typename> class Operation,
typename Ty = vec<DataT, NumElements>>
vec<DataT, NumElements>
Expand All @@ -1333,7 +1357,7 @@ template <typename Type, int NumElements> class vec {
}
return Result;
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
template <template <typename> class Operation>
vec<DataT, NumElements>
operatorHelper(const vec<DataT, NumElements> &Rhs) const {
Expand All @@ -1344,12 +1368,12 @@ template <typename Type, int NumElements> class vec {
}
return Result;
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// setValue and getValue should be able to operate on different underlying
// types: enum cl_float#N , builtin vector float#N, builtin type float.
// These versions are for N > 1.
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
#ifdef __SYCL_DEVICE_ONLY__
template <int Num = NumElements, typename Ty = int,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(EnableIfNotHostHalf<Ty> Index, const DataT &Value,
Expand All @@ -1374,7 +1398,7 @@ template <typename Type, int NumElements> class vec {
constexpr DataT getValue(EnableIfHostHalf<Ty> Index, int) const {
return vec_data<DataT>::get(m_Data.s[Index]);
}
#else // __SYCL_USE_EXT_VECTOR_TYPE__
#else // __SYCL_DEVICE_ONLY__
template <int Num = NumElements,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(int Index, const DataT &Value, int) {
Expand All @@ -1386,7 +1410,7 @@ template <typename Type, int NumElements> class vec {
constexpr DataT getValue(int Index, int) const {
return vec_data<DataT>::get(m_Data[Index]);
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
#endif // __SYCL_DEVICE_ONLY__

// N==1 versions, used by host and device. Shouldn't trailing type be int?
template <int Num = NumElements,
Expand Down

0 comments on commit 051a557

Please sign in to comment.