From 0eac61876fce2d524f16b9b1d4239a7733f9cbbd Mon Sep 17 00:00:00 2001 From: Dmitry Sidorov Date: Thu, 8 Feb 2024 11:59:38 +0100 Subject: [PATCH] [SYCL][Matrix] Correct Prefetch instruction usage (#12623) Signed-off-by: Sidorov, Dmitry --- sycl/include/CL/__spirv/spirv_ops.hpp | 9 +++++---- sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index ea1a6580d30e6..9af5b7e75ae38 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -174,10 +174,11 @@ extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL * __spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL *, Ts val, size_t i); -template -extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixPrefetchINTEL( - T *Ptr, std::size_t coordX, std::size_t coordY, unsigned int CacheLevel, - __spv::MatrixLayout Layout, std::size_t Stride); +template +extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL( + T *Ptr, std::size_t coordX, std::size_t coordY, std::size_t NumRows, + std::size_t NumCols, unsigned int CacheLevel, __spv::MatrixLayout Layout, + std::size_t Stride); #ifndef __SPIRV_BUILTIN_DECLARATIONS__ #error \ diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 98aea6f04a48b..a07e9c144ba6a 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -524,8 +524,9 @@ joint_matrix_prefetch(Group sg, T *Ptr, size_t stride, // Will be removed once SPIRV implementation also uses offsetpointer size_t coordX = 0; size_t coordY = 0; - __spirv_JointMatrixPrefetchINTEL( - Ptr, coordX, coordY, detail::PropertyMetaInfo::value, + __spirv_CooperativeMatrixPrefetchINTEL( + Ptr, coordX, coordY, NumRows, NumCols, + detail::PropertyMetaInfo::value, sycl::detail::joint_matrix_layout_to_spv(Layout), stride); #endif // defined(__NVPTX__) #else