From 4c17a7f395d589767f2e7996c25c07e52cf8e161 Mon Sep 17 00:00:00 2001 From: Dounia Khaldi Date: Thu, 4 Jan 2024 04:09:51 -0600 Subject: [PATCH] [SYCL][Matrix Headers] Add out of bounds load/store (#11210) Spec is in https://github.com/intel/llvm/pull/11172 --- sycl/include/CL/__spirv/spirv_ops.hpp | 33 ++ .../sycl/ext/oneapi/matrix/matrix-intel.hpp | 310 ++++++++++++++++++ .../Matrix/joint_matrix_out_bounds_impl.hpp | 19 +- 3 files changed, 353 insertions(+), 9 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 82657dbdbc8cf..1189aef849b87 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -45,6 +45,39 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL( std::size_t Stride, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, int MemOperand = 0); +template +extern __DPCPP_SYCL_EXTERNAL + __spv::__spirv_JointMatrixINTEL * + __spirv_CompositeConstructCheckedINTEL(const T Value, size_t Height, + size_t Stride, size_t Width, + size_t CoordX, size_t CoordY); + +template +extern __DPCPP_SYCL_EXTERNAL + __spv::__spirv_JointMatrixINTEL * + __spirv_JointMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride, + size_t Height, size_t Width, + size_t CoordX, size_t CoordY, + __spv::MatrixLayout Layout = L, + __spv::Scope::Flag Sc = S, + int MemOperand = 0); + +template +extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL( + T *Ptr, __spv::__spirv_JointMatrixINTEL *Object, + std::size_t Stride, size_t Height, size_t Width, size_t CoordX, + size_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S, + int MemOperand = 0); + template +inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked( + Group, joint_matrix &Res, + const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX, + size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + using storage_element_type = + typename oneapi::detail::jm_type_interpretation_helper_trait< + T>::storage_element_type; + Res.spvm = __spirv_CompositeConstructCheckedINTEL< + storage_element_type, T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + static_cast(Value), Stride, Height, Width, CoordX, + CoordY); +#else + std::ignore = Res; + std::ignore = Value; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template < + typename Group, typename S, typename T, size_t NumRows, size_t NumCols, + access::address_space Space, access::decorated IsDecorated, + std::enable_if_t>::value, bool> = + true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked( + Group sg, + joint_matrix + &Res, + multi_ptr Src, size_t Stride, layout Layout, + size_t Height, size_t Width, size_t CoordX, size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + static_assert(Space != access::address_space::private_space, + "Joint Matrix doesn't support load from private memory!"); + std::ignore = sg; + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *Ptr = sycl::detail::getDecorated(Src); + Res.spvm = __spirv_JointMatrixLoadCheckedINTEL< + DecorT, S, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, Stride, Height, Width, CoordX, CoordY, + sycl::detail::joint_matrix_layout_to_spv(Layout), + spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Res; + std::ignore = Src; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = Layout; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template < + typename Group, typename S, typename T, use Use, size_t NumRows, + size_t NumCols, layout Layout, access::address_space Space, + access::decorated IsDecorated, + std::enable_if_t>::value || + (std::is_same::value && + std::is_same, float>::value), + bool> = true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked( + Group sg, joint_matrix &Res, + multi_ptr Src, size_t Stride, size_t Height, + size_t Width, size_t CoordX, size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + static_assert(Space != access::address_space::private_space, + "Joint Matrix doesn't support load from private memory!"); + std::ignore = sg; + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *Ptr = sycl::detail::getDecorated(Src); + Res.spvm = __spirv_JointMatrixLoadCheckedINTEL< + DecorT, S, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, Stride, Height, Width, CoordX, CoordY, + spv_matrix_layout_traits::value, spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Res; + std::ignore = Src; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked( + Group sg, + joint_matrix + &Src, + multi_ptr Dst, size_t Stride, layout Layout, + size_t Height, size_t Width, size_t CoordX, size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + static_assert(Space != access::address_space::private_space, + "Joint Matrix doesn't support store to private memory!"); + std::ignore = sg; + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *Ptr = sycl::detail::getDecorated(Dst); + __spirv_JointMatrixStoreCheckedINTEL< + DecorT, T, NumRows, NumCols, + spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY, + sycl::detail::joint_matrix_layout_to_spv(Layout), + spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Src; + std::ignore = Dst; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = Layout; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template = true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked( + Group sg, const joint_matrix &Src, + multi_ptr Dst, size_t Stride, size_t Height, + size_t Width, size_t CoordX, size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + static_assert(Space != access::address_space::private_space, + "Joint Matrix doesn't support store to private memory!"); + std::ignore = sg; + using DecorT = typename sycl::detail::DecoratedType::type; + DecorT *Ptr = sycl::detail::getDecorated(Dst); + __spirv_JointMatrixStoreCheckedINTEL::value, + spv_matrix_layout_traits::value>( + Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY, + spv_matrix_layout_traits::value, spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Src; + std::ignore = Dst; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +// Annotated pointer overloads: +template >::value, + bool> = true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked( + Group sg, + joint_matrix + &Res, + ext::oneapi::experimental::annotated_ptr Src, + size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX, + size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + std::ignore = sg; + T *Ptr = Src.get(); + Res.spvm = __spirv_JointMatrixLoadCheckedINTEL< + T, S, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, Stride, Height, Width, CoordX, CoordY, + sycl::detail::joint_matrix_layout_to_spv(Layout), + spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Res; + std::ignore = Src; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = Layout; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template < + typename Group, typename S, typename T, use Use, size_t NumRows, + size_t NumCols, layout Layout, typename PropertyListT, + std::enable_if_t>::value || + (std::is_same::value && + std::is_same, float>::value), + bool> = true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked( + Group sg, joint_matrix &Res, + ext::oneapi::experimental::annotated_ptr Src, + size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + std::ignore = sg; + T *Ptr = Src.get(); + Res.spvm = __spirv_JointMatrixLoadCheckedINTEL< + T, S, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, Stride, Height, Width, CoordX, CoordY, + spv_matrix_layout_traits::value, spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Res; + std::ignore = Src; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template +inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked( + Group sg, + joint_matrix + &Src, + ext::oneapi::experimental::annotated_ptr Dst, + size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX, + size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + std::ignore = sg; + T *Ptr = Dst.get(); + __spirv_JointMatrixStoreCheckedINTEL< + T, T, NumRows, NumCols, spv_matrix_use_traits::value, + spv_matrix_layout_traits::value>( + Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY, + sycl::detail::joint_matrix_layout_to_spv(Layout), + spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Src; + std::ignore = Dst; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = Layout; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} + +template = true> +inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked( + Group sg, const joint_matrix &Src, + ext::oneapi::experimental::annotated_ptr Dst, + size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) { +#if defined(__SYCL_DEVICE_ONLY__) + std::ignore = sg; + T *Ptr = Dst.get(); + __spirv_JointMatrixStoreCheckedINTEL::value, + spv_matrix_layout_traits::value>( + Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY, + spv_matrix_layout_traits::value, spv_scope_traits::value); +#else + std::ignore = sg; + std::ignore = Src; + std::ignore = Dst; + std::ignore = Stride; + std::ignore = Height; + std::ignore = Width; + std::ignore = CoordX; + std::ignore = CoordY; + throw runtime_error("joint matrix is not supported on host device.", + PI_ERROR_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) +} +// End out-of-bounds API + } // namespace intel::experimental::matrix } // namespace ext diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp index 7dfe8ee1376e6..3607eab14fbc0 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp @@ -51,22 +51,23 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) { sub_b; joint_matrix sub_c; // bounds-checked load where width and height are added - joint_matrix_fill_checked(sg, sub_c, 1, M, N); + ext::intel::experimental::matrix::joint_matrix_fill_checked( + sg, sub_c, 1, N, M, N, sg_startx * TM, sg_starty / SG_SZ * TN); for (int k = 0; k < K; k += TK) { // bounds-checked load where width and height are added - joint_matrix_load_checked(sg, sub_a, pA + (sg_startx * TM) * K + k, - K, M, K); + ext::intel::experimental::matrix::joint_matrix_load_checked( + sg, sub_a, pA, K, M, K, sg_startx * TM, k); // Assume we alreay in vnni format. // bounds-checked load where width and height are added - joint_matrix_load_checked( - sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor, - N * vnniFactor, K / vnniFactor, N * vnniFactor); + ext::intel::experimental::matrix::joint_matrix_load_checked( + sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor, + k, sg_starty / SG_SZ * TN * vnniFactor); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } // bounds-checked store where width and height are added - joint_matrix_store_checked( - sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N, - layout::row_major, M, N); + ext::intel::experimental::matrix::joint_matrix_store_checked( + sg, sub_c, pC, N, layout::row_major, M, N, sg_startx * TM, + sg_starty / SG_SZ * TN); }); // parallel for }).wait(); }