Skip to content

Commit

Permalink
[SYCL][Matrix Headers] Add out of bounds load/store (intel#11210)
Browse files Browse the repository at this point in the history
Spec is in intel#11172
  • Loading branch information
dkhaldi authored Jan 4, 2024
1 parent caa4ed5 commit 4c17a7f
Show file tree
Hide file tree
Showing 3 changed files with 353 additions and 9 deletions.
33 changes: 33 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,39 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CompositeConstructCheckedINTEL(const T Value, size_t Height,
size_t Stride, size_t Width,
size_t CoordX, size_t CoordY);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_JointMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
size_t Height, size_t Width,
size_t CoordX, size_t CoordY,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
size_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S,
int MemOperand = 0);

template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
__spv::MatrixUse UA, __spv::MatrixUse UB, __spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
Expand Down
310 changes: 310 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,316 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
#endif
}

using namespace sycl::ext::oneapi::experimental::matrix;

// Begin out-of-bounds API

template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
layout Layout, typename T2>
inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
Res.spvm = __spirv_CompositeConstructCheckedINTEL<
storage_element_type, T, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
static_cast<storage_element_type>(Value), Stride, Height, Width, CoordX,
CoordY);
#else
std::ignore = Res;
std::ignore = Value;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <
typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated,
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
Group sg,
joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
&Res,
multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
static_assert(Space != access::address_space::private_space,
"Joint Matrix doesn't support load from private memory!");
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
DecorT, S, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Res;
std::ignore = Src;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = Layout;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <
typename Group, typename S, typename T, use Use, size_t NumRows,
size_t NumCols, layout Layout, access::address_space Space,
access::decorated IsDecorated,
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
(std::is_same<S, precision::tf32>::value &&
std::is_same<std::remove_const_t<T>, float>::value),
bool> = true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
size_t Width, size_t CoordX, size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
static_assert(Space != access::address_space::private_space,
"Joint Matrix doesn't support load from private memory!");
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Res;
std::ignore = Src;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
Group sg,
joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
&Src,
multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
static_assert(Space != access::address_space::private_space,
"Joint Matrix doesn't support store to private memory!");
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
__spirv_JointMatrixStoreCheckedINTEL<
DecorT, T, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Src;
std::ignore = Dst;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = Layout;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
size_t NumCols, layout Layout, access::address_space Space,
access::decorated IsDecorated,
std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
size_t Width, size_t CoordX, size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
static_assert(Space != access::address_space::private_space,
"Joint Matrix doesn't support store to private memory!");
std::ignore = sg;
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
__spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Src;
std::ignore = Dst;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

// Annotated pointer overloads:
template <typename Group, typename S, typename T, size_t NumRows,
size_t NumCols, typename PropertyListT,
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
bool> = true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
Group sg,
joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
&Res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Src,
size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Src.get();
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Res;
std::ignore = Src;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = Layout;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <
typename Group, typename S, typename T, use Use, size_t NumRows,
size_t NumCols, layout Layout, typename PropertyListT,
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
(std::is_same<S, precision::tf32>::value &&
std::is_same<std::remove_const_t<T>, float>::value),
bool> = true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Src,
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Src.get();
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Res;
std::ignore = Src;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
typename PropertyListT>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
Group sg,
joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
&Src,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Dst,
size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Dst.get();
__spirv_JointMatrixStoreCheckedINTEL<
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
sycl::detail::joint_matrix_layout_to_spv(Layout),
spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Src;
std::ignore = Dst;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = Layout;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
size_t NumCols, layout Layout, typename PropertyListT,
std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Dst,
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
T *Ptr = Dst.get();
__spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
spv_matrix_use_traits<Use>::value,
spv_matrix_layout_traits<Layout>::value>(
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
#else
std::ignore = sg;
std::ignore = Src;
std::ignore = Dst;
std::ignore = Stride;
std::ignore = Height;
std::ignore = Width;
std::ignore = CoordX;
std::ignore = CoordY;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
// End out-of-bounds API

} // namespace intel::experimental::matrix

} // namespace ext
Expand Down
19 changes: 10 additions & 9 deletions sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,23 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) {
sub_b;
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
// bounds-checked load where width and height are added
joint_matrix_fill_checked(sg, sub_c, 1, M, N);
ext::intel::experimental::matrix::joint_matrix_fill_checked(
sg, sub_c, 1, N, M, N, sg_startx * TM, sg_starty / SG_SZ * TN);
for (int k = 0; k < K; k += TK) {
// bounds-checked load where width and height are added
joint_matrix_load_checked(sg, sub_a, pA + (sg_startx * TM) * K + k,
K, M, K);
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_a, pA, K, M, K, sg_startx * TM, k);
// Assume we alreay in vnni format.
// bounds-checked load where width and height are added
joint_matrix_load_checked(
sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor,
N * vnniFactor, K / vnniFactor, N * vnniFactor);
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor,
k, sg_starty / SG_SZ * TN * vnniFactor);
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
}
// bounds-checked store where width and height are added
joint_matrix_store_checked(
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N,
layout::row_major, M, N);
ext::intel::experimental::matrix::joint_matrix_store_checked(
sg, sub_c, pC, N, layout::row_major, M, N, sg_startx * TM,
sg_starty / SG_SZ * TN);
}); // parallel for
}).wait();
}
Expand Down

0 comments on commit 4c17a7f

Please sign in to comment.