[SYCL][Matrix spec] Add joint_matrix_prefetch and overloads of load/store with annotated_ptr (#11473)
dkhaldi authored Feb 21, 2024
1 parent 16e06ff commit 04a222f
Showing 2 changed files with 115 additions and 2 deletions.
@@ -148,14 +148,28 @@ template <typename Group, typename T, size_t Rows, size_t Cols,
access::decorated IsDecorated>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
multi_ptr<T, Space, IsDecorated> src, size_t stride);
multi_ptr<T, Space, IsDecorated> dest, size_t stride);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, access::address_space Space,
access::decorated IsDecorated>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
multi_ptr<T, Space, IsDecorated> src, size_t stride);
multi_ptr<T, Space, IsDecorated> dest, size_t stride);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, typename PropertyListT>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
size_t stride);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, typename PropertyListT>
void joint_matrix_store(Group g,
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
size_t stride);

} // namespace sycl::ext::intel::experimental::matrix
```
@@ -327,6 +341,7 @@ q.submit([&](sycl::handler& cgh) {
});
q.wait();
```

== Revision History

[frame="none",options="header"]
@@ -228,6 +228,23 @@ void joint_matrix_load(Group g,
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
multi_ptr<T2, Space, IsDecorated> src, size_t stride);

// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
size_t Rows, size_t Cols,
typename PropertyListT>
void joint_matrix_load(Group g,
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
annotated_ptr<T2, PropertyListT> src, size_t stride, layout Layout);

// Only available when Layout != layout::dynamic
// and when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
size_t Rows, size_t Cols, use Use, layout Layout,
typename PropertyListT>
void joint_matrix_load(Group g,
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
annotated_ptr<T2, PropertyListT> src, size_t stride);

} // namespace sycl::ext::oneapi::experimental::matrix
```
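
The "Only available when" comments on the `annotated_ptr` overloads of `joint_matrix_load` can be read as a type-trait check. The following is a minimal host-side sketch of that check, not the extension's implementation; the helper name `load_types_compatible_v` is hypothetical and exists only for illustration.

```cpp
#include <type_traits>

// Hypothetical helper (not part of the extension): mirrors the condition
// std::is_same_v<T1, std::remove_const_t<T2>> stated in the comments, where
// T1 is the joint_matrix element type and T2 is the pointee type of `src`.
template <typename T1, typename T2>
inline constexpr bool load_types_compatible_v =
    std::is_same_v<T1, std::remove_const_t<T2>>;

// A pointer to const data satisfies the condition.
static_assert(load_types_compatible_v<float, const float>);
// So does a pointer to non-const data of the same type.
static_assert(load_types_compatible_v<float, float>);
// A different element type does not: the condition allows no conversion.
static_assert(!load_types_compatible_v<float, double>);
```

Read this way, the condition lets `src` point to either `T1` or `const T1`, and nothing else.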

@@ -248,6 +265,33 @@ fashion. `stride` describes the number of elements between consecutive
rows for the row major layout, or between columns for the column major
layout.
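
The meaning of `stride` can be sketched with a plain C++ emulation of a tile read. This is only an illustration of the indexing described above, assuming nothing about how a device distributes the work; `tile_layout`, `element_offset`, and `load_tile` are hypothetical names, not part of the extension.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for the extension's layout enum.
enum class tile_layout { row_major, col_major };

// Offset of tile element (r, c) from the tile's base pointer.
// Row major: consecutive rows are `stride` elements apart.
// Column major: consecutive columns are `stride` elements apart.
constexpr std::size_t element_offset(std::size_t r, std::size_t c,
                                     std::size_t stride, tile_layout l) {
  return l == tile_layout::row_major ? r * stride + c : c * stride + r;
}

// Sequential emulation of reading a Rows x Cols tile starting at `src`.
// The result is stored row by row regardless of the memory layout.
template <std::size_t Rows, std::size_t Cols, typename T>
constexpr std::array<T, Rows * Cols> load_tile(const T *src,
                                               std::size_t stride,
                                               tile_layout l) {
  std::array<T, Rows * Cols> tile{};
  for (std::size_t r = 0; r < Rows; ++r)
    for (std::size_t c = 0; c < Cols; ++c)
      tile[r * Cols + c] = src[element_offset(r, c, stride, l)];
  return tile;
}
```

For a row-major matrix with 6 columns, a 2 x 3 tile read with `stride = 6` takes three adjacent elements from each of two consecutive rows. With the column major layout, `stride` would instead be the number of rows of the enclosing matrix.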

The last two overloads of `joint_matrix_load` take a
`sycl::ext::oneapi::experimental::annotated_ptr` argument instead
of a `sycl::multi_ptr`. The property list associated with the
`annotated_ptr` argument carries the compile-time constant
cache-control properties defined in the SYCL extension
link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls],
as illustrated in the example below.

```c++
namespace syclex = sycl::ext::oneapi::experimental;
namespace syclintelex = sycl::ext::intel::experimental;

auto A_ptr = syclex::annotated_ptr{A,
syclex::properties{syclintelex::read_hint<
syclintelex::cache_control<syclintelex::cache_mode::cached,
syclex::cache_level::L2>>}};
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
sub_group sg = it.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
for (int k = 0; k < K; k += tileK) {
// User specifies that this load will be cached to L2
joint_matrix_load(sg, tA, A_ptr + sg_startx * tM * K + k, K);
...
}
});
```
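
The pointer arithmetic in the load call above selects the tile of a row-major `A` that a sub-group reads at step `k`. A minimal sketch of that offset follows, assuming `A` has `K` columns and that `tile_row` plays the role of `sg_startx`; the function name `a_tile_offset` is hypothetical.

```cpp
#include <cstddef>

// Element offset, from the start of a row-major matrix A with K columns,
// of the tile whose first row is tile_row * tM and whose first column is k.
// Skipping tile_row * tM full rows costs tile_row * tM * K elements;
// adding k then moves to the tile's first column within that row.
constexpr std::size_t a_tile_offset(std::size_t tile_row, std::size_t k,
                                    std::size_t tM, std::size_t K) {
  return tile_row * tM * K + k;
}

// With tM = 8 and K = 64, the tile in block row 1 at column 16 starts
// 8 * 64 + 16 = 528 elements into A.
static_assert(a_tile_offset(1, 16, 8, 64) == 528);
```

Because the layout is row major, `K` is also the value passed as `stride`.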

==== Store
```c++
namespace sycl::ext::oneapi::experimental::matrix {
@@ -259,6 +303,12 @@ void joint_matrix_store(Group g,
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);

template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
typename PropertyListT>
void joint_matrix_store(Group g,
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
annotated_ptr<T2, PropertyListT> dest, size_t stride, layout Layout);

} // namespace sycl::ext::oneapi::experimental::matrix
```
This function stores the data in the accumulator matrix from the
@@ -270,6 +320,11 @@ written in a row (`row_major`), column major (`col_major`)
fashion. `stride` describes the number of elements between consecutive
rows for the row major layout, or between columns for the column major layout.

The second overload of `joint_matrix_store` takes a
`sycl::ext::oneapi::experimental::annotated_ptr` argument instead
of a `sycl::multi_ptr`. The property list associated with the
`annotated_ptr` argument carries the compile-time constant
cache-control properties defined in the SYCL extension link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls].

==== Multiply and Add

@@ -372,6 +427,47 @@ joint_matrix_apply(sg, C, [=](T &x) {
});
```

==== Prefetch

```c++
namespace sycl::ext::oneapi::experimental::matrix {

template <size_t Rows, size_t Cols, typename Group, typename T,
typename Properties = empty_properties_t>
void joint_matrix_prefetch(Group g, T* ptr, size_t stride, layout Layout,
Properties properties = {});

} // namespace sycl::ext::oneapi::experimental::matrix
```

`joint_matrix_prefetch` lets a group of work-items cooperatively
prefetch a two-dimensional tile of `Rows x Cols` elements. This function
is a group function, as defined in Section 4.17.3 of the core SYCL
specification.

The cache level that `joint_matrix_prefetch` targets is specified
through its last argument, using the compile-time properties defined in
the SYCL extension
link:../../proposed/sycl_ext_oneapi_prefetch.asciidoc[sycl_ext_oneapi_prefetch],
as illustrated in the example below. When no cache level is
specified, the default behavior is to prefetch into the lowest-level
cache (i.e., L1).

```c++
namespace syclex = sycl::ext::oneapi::experimental;

bfloat16 *memA = malloc_shared<bfloat16>(M*K, q);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
sub_group sg = it.get_sub_group();
for (int k = 0; k < K; k += tileK) {
syclex::matrix::joint_matrix_prefetch<tM, tK>(sg, memA + tM * K + tK, K,
layout::row_major,
syclex::properties{syclex::prefetch_hint_L2});
...
}
});
```

=== Support for Machine Learning Types
Some devices support special matrix element types that are commonly
used in machine learning algorithms.
@@ -1035,4 +1131,6 @@ and Intel XMX
|8 |2023-10-05 |Mahmoud Moadeli |Add AMD Matrix Core supported combinations
|9 |2023-11-13 |Dounia Khaldi |Add Granite Rapids Intel AMX
supported combinations
|10 |2023-12-04 |Dounia Khaldi |Add prefetch and `annotated_ptr`
load/store overloads
|======================
