Commit

Address Greg's comments: improve the examples and some wording, add the specific requirements as an appendix subsection
dkhaldi committed Mar 12, 2024
1 parent df2c2d4 commit 95cdb68
Showing 1 changed file with 107 additions and 64 deletions.
@@ -230,35 +230,46 @@ supporting the out of bounds checked APIs that are defined in this section.
In this section, we refer to the memory buffer where a `joint_matrix`
is loaded from or stored to as the global matrix. This global matrix
is also interpreted as a two-dimensional memory region as follows, where
`Height` is number of rows in the global matrix, `Width` is number of
columns in the global matrix, `stride` is number of columns that include
`GlobalRows` is the number of rows in the global matrix, `GlobalCols` is the number of
columns in the global matrix, `Stride` is the number of columns that include
the out of bounds data (depicted as x here).

|<-----------`Width`-------->|

|DDDDDDDDDDDDDDDDDxxx|

|DDDDDDDDDDDDDDDDDxxx|

|......................................................| `Height`

|DDDDDDDDDDDDDDDDDxxx|

|DDDDDDDDDDDDDDDDDxxx|
```
GlobalCols
<----------->
dddddddddddddxxx ^
dddddddddddddxxx | GlobalRows
dddddddddddddxxx v
xxxxxxxxxxxxxxxx
<-------------->
Stride
```

|<------------`Stride`---------->|
In the diagram above, the global matrix has 13 columns and 3
rows. This is padded out to be evenly divisible by a joint matrix with
8 columns and 2 rows, which results in a stride of 16.

Note that joint matrix shape `Rows` and `Cols` represents a sub-block
of the picture above. The out of bounds data results when the global
matrix size is not evenly divisible by the joint matrix size.
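
The padding arithmetic in the diagram can be checked with a small rounding helper. This is only a host-side sketch: `round_up` and `PaddedRows` are hypothetical names that are not part of the extension, and the sketch assumes the stride is simply the column count rounded up to a whole number of joint matrix blocks.

```c++
#include <cstddef>

// Hypothetical helper, not part of the extension: round `value` up to the
// nearest multiple of `block` (assumes block > 0).
constexpr std::size_t round_up(std::size_t value, std::size_t block) {
  return (value + block - 1) / block * block;
}

// Numbers from the diagram: a 3 x 13 global matrix tiled by a joint matrix
// with 2 rows and 8 columns.
constexpr std::size_t GlobalRows = 3, GlobalCols = 13;
constexpr std::size_t Rows = 2, Cols = 8;

constexpr std::size_t Stride = round_up(GlobalCols, Cols);      // 16
constexpr std::size_t PaddedRows = round_up(GlobalRows, Rows);  // 4

static_assert(Stride == 16, "13 columns pad out to a stride of 16");
static_assert(PaddedRows == 4, "3 rows pad out to 4 rows");
```

The three extra columns and the one extra row are the elements marked "x" in the diagram.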

==== Checked APIs
`joint_matrix_load`, `joint_matrix_store`, and `joint_matrix_fill`
operations do not perform bounds checking. When the global matrix size is
not multiple of the joint matrix size, this extension adds a new form
of API to load, store, and fill joint matrix while checking the
bounds. For load and fill, the out-of-bounds elements are set to
0. For the store, they are dropped out.
When an algorithm iterates over the global matrix, it loads or stores
elements that correspond to a joint matrix. When the global matrix
size does not evenly divide by the joint matrix size, some of these
loads or stores access the extra elements marked "x" in the diagram
above. The standard joint matrix functions (`joint_matrix_load`,
`joint_matrix_store`, and `joint_matrix_fill`) do not do any bounds
checking in this case, so they simply load or store to these extra
elements. This could cause unexpected values to be loaded into the
joint matrix for these elements. These functions could also cause a
memory fault if the extra elements are not valid addresses.

The checked APIs described below do not attempt to access the extra
memory. The checked load is guaranteed to return 0 for the extra
elements, and the checked store simply ignores stores to the extra
elements. Neither function will cause a memory fault if the extra
elements correspond to invalid addresses.
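
The store behavior described above can be modeled on the host as follows. This is a reference sketch of the semantics only, not the device implementation: `store_checked_model` is a hypothetical name, the joint matrix is represented as a plain row-major `std::vector`, and `base` is assumed to point at element [0, 0] of a row-major global matrix.

```c++
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model only (hypothetical name, not part of the extension).
// Writes the Rows x Cols block `tile` into the global matrix at
// [RowIndex, ColIndex], silently dropping elements that fall outside
// GlobalRows x GlobalCols.
template <typename T>
void store_checked_model(const std::vector<T> &tile, T *base,
                         std::size_t Stride, std::size_t GlobalRows,
                         std::size_t GlobalCols, std::size_t RowIndex,
                         std::size_t ColIndex, std::size_t Rows,
                         std::size_t Cols) {
  assert(tile.size() == Rows * Cols);
  assert(Stride >= GlobalCols);
  for (std::size_t r = 0; r < Rows; ++r) {
    for (std::size_t c = 0; c < Cols; ++c) {
      const std::size_t gr = RowIndex + r;
      const std::size_t gc = ColIndex + c;
      // The extra elements are never touched, so they may even be
      // unallocated memory.
      if (gr < GlobalRows && gc < GlobalCols)
        base[gr * Stride + gc] = tile[r * Cols + c];
    }
  }
}
```

Because the bounds test happens before any address is formed into an access, the model never dereferences memory outside the `GlobalRows` by `GlobalCols` region.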

These functions are similar to the existing ones without bounds
checking, namely `joint_matrix_fill`, `joint_matrix_load`, and
@@ -268,32 +279,55 @@ checking, namely `joint_matrix_fill`, `joint_matrix_load`, and
the global memory matrix, which is different from the APIs that do not
do bounds checking. Those non-bounds-checking APIs take a pointer to
the base of the joint matrix.
* The coordinates `CoordX` and `CoordY` into the global matrix to
* The coordinates `RowIndex` and `ColIndex` into the global matrix to
calculate the pointer offset to load/store are given as separate
arguments.
* These variants take extra arguments to determine the global bounds
`Height` and `Width` of the global matrix.
`GlobalRows` and `GlobalCols` of the global matrix.

The out of bounds functions check that the joint matrix block starting at the
address `base_src` and ending at `base_src + Rows * Stride + Cols`
does not exceed the global matrix block starting at `[base_src]` and
ending at `base_src + Height * stride + Width`.
For instance, load of joint matrix `sub_b` defined as
`joint_matrix<sub_group, bfloat16, use::b, 2, 3, layout::row_major>
sub_b;` from global matrix of `Height=2` and `Width=10`, using
`joint_matrix_load_checked(sg, sub_b, base_src, 12, 2, 10, 0, 9);`,
checks that the load starting at `base_src + 9` does not exceed
`base_ptr + 10`. This results into reading only 1 column. The two
columns left that exceed the `Width=10` columns are filled by zero.
ending at `base_src + GlobalRows * Stride + GlobalCols`.

To illustrate, consider the global matrix shown above which has 13
columns and 3 rows (`GlobalRows=3` and `GlobalCols=13`), where the
joint matrix size is 8 columns by 2 rows defined as
```
joint_matrix<sub_group, bfloat16, use::b, 2, 8, layout::row_major> sub_b;
```
The load of the joint matrix at coordinate [8, 2] (column number 8,
row number 2 in the global matrix) overlaps the extra elements in
both dimensions. This is shown below, where capital letters correspond
to the elements that are accessed by this joint matrix load:

```
GlobalCols
<----------->
dddddddddddddxxx ^
dddddddddddddxxx | GlobalRows
ddddddddDDDDDXXX v
xxxxxxxxXXXXXXXX
<-------------->
Stride
```

If the joint matrix is loaded via `joint_matrix_load_checked` using
```
joint_matrix_load_checked(sg, sub_b, base_src, 16, 3, 13, 2, 8);
```
the extra elements that are shown with capital `X` are not accessed in
memory, and those elements are guaranteed to have the value zero in
the joint matrix after the load operation completes.
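
The outcome of this example can be reproduced with a host-side model of the checked load. This is a sketch of the documented semantics under stated assumptions, not the device implementation: `load_checked_model` and `example_tile` are hypothetical names, the joint matrix is returned as a row-major `std::vector`, and `int` stands in for `bfloat16` so the sketch compiles without SYCL headers.

```c++
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model only (hypothetical name, not part of the extension).
// Returns the Rows x Cols block at [RowIndex, ColIndex] of a row-major
// global matrix, with out-of-bounds elements set to zero.
template <typename T>
std::vector<T> load_checked_model(const T *base, std::size_t Stride,
                                  std::size_t GlobalRows,
                                  std::size_t GlobalCols,
                                  std::size_t RowIndex, std::size_t ColIndex,
                                  std::size_t Rows, std::size_t Cols) {
  assert(Stride >= GlobalCols);
  std::vector<T> tile(Rows * Cols, T{0});  // extra elements stay zero
  for (std::size_t r = 0; r < Rows; ++r) {
    for (std::size_t c = 0; c < Cols; ++c) {
      const std::size_t gr = RowIndex + r;
      const std::size_t gc = ColIndex + c;
      if (gr < GlobalRows && gc < GlobalCols)  // skip the "X" elements
        tile[r * Cols + c] = base[gr * Stride + gc];
    }
  }
  return tile;
}

// The example above: GlobalRows=3, GlobalCols=13, Stride=16, and a joint
// matrix with 2 rows and 8 columns loaded at RowIndex=2, ColIndex=8.
inline std::vector<int> example_tile() {
  std::vector<int> global(3 * 16, 7);  // only the 3 real rows are allocated
  return load_checked_model(global.data(), 16, 3, 13, 2, 8, 2, 8);
}
```

In this model, the first row of the result holds the five in-bounds values (the capital `D` elements) followed by three zeros, and the second row is all zeros.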

```c++
namespace sycl::ext::intel::experimental::matrix {

template <typename Group, typename T, size_t Rows, size_t Cols,
use Use, layout Layout, typename Tv>
void joint_matrix_fill_checked(Group g, joint_matrix<Group, T, Use, Rows,
Cols, Layout> &m, Tv v, size_t Height, size_t Width,
size_t CoordX, size_t CoordY);
Cols, Layout> &m, Tv v, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
@@ -302,7 +336,8 @@ template <typename Group, typename T1, typename T2,
void joint_matrix_load_checked(Group g,
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
multi_ptr<T2, Space, IsDecorated> base_src, size_t Stride,
layout Layout, size_t Height, size_t Width, size_t CoordX, size_t CoordY);
layout Layout, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

// Only available when Layout != layout::dynamic
// and when std::is_same_v<T1, std::remove_const_t<T2>>
@@ -313,16 +348,16 @@ template <typename Group, typename T1, typename T2,
void joint_matrix_load_checked(Group g,
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
multi_ptr<T2, Space, IsDecorated> base_src, size_t Stride,
size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);

// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
size_t Rows, size_t Cols, typename PropertyListT>
void joint_matrix_load_checked(Group g,
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
ext::oneapi::experimental::annotated_ptr<T2, PropertyListT> base_src,
size_t Stride, layout Layout, size_t Height, size_t Width, size_t
CoordX, size_t CoordY);
size_t Stride, layout Layout, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

// Only available when Layout != layout::dynamic
// and when std::is_same_v<T1, std::remove_const_t<T2>>
@@ -331,52 +366,55 @@ template <typename Group, typename T1, typename T2, size_t Rows,
void joint_matrix_load_checked(Group g,
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T2, PropertyListT> base_src,
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t Stride, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

template <typename Group, typename T, size_t Rows, size_t Cols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store_checked(Group g,
const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride, layout Layout,
size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, access::address_space Space,
access::decorated IsDecorated>
void joint_matrix_store_checked(Group g,
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride,
size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, access::address_space Space,
access::decorated IsDecorated>
void joint_matrix_store_checked(Group g,
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
multi_ptr<T, Space, IsDecorated> base_dest, size_t Stride,
size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t GlobalRows, size_t GlobalCols, size_t RowIndex, size_t ColIndex);

template <typename Group, typename T, size_t Rows, size_t Cols,
typename PropertyListT>
void joint_matrix_store_checked(Group g,
const joint_matrix<Group, T, use::accumulator, Rows, Cols, layout::dynamic> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
size_t Stride, layout Layout, size_t Height, size_t Width,
size_t CoordX, size_t CoordY);
size_t Stride, layout Layout, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, typename PropertyListT>
void joint_matrix_store_checked(Group g,
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t Stride, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

template <typename Group, typename T, size_t Rows, size_t Cols,
layout Layout, typename PropertyListT>
void joint_matrix_store_checked(Group g,
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> base_dest,
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY);
size_t Stride, size_t GlobalRows, size_t GlobalCols,
size_t RowIndex, size_t ColIndex);

} // namespace sycl::ext::intel::experimental::matrix
```
@@ -387,33 +425,38 @@ in the SYCL extenion
link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls].

==== Restrictions and Device Information Descriptors
The following restrictions apply to these checked APIs:

- The base pointer must be 4 bytes aligned.

- For 8 bits data type, `CoordX` must be a multiple of 4. For 16 bits
data type, `CoordX` must be a multiple of 2. So `CoordX` must be a
multiple of 4 divided by size of the element type (`4/sizeof(T)`).

- For 8 bits data type, `Width` must be a multiple of 4, For 16 bits
data type, `Width` must be a multiple of 2. So `Width` must be a
multiple of 4 divided by size of the element type (`4/sizeof(T)`).

These requirements are added as part of the matrix query. The table
below provides a description for the additional device matrix
descriptors that can be queried using `get_info` API.
Applications must adhere to certain alignment restrictions when using
the checked APIs described in this section. This extension provides
the following queries to get these requirements:

[frame="none",options="header"]
|======================
| Device descriptors | Return type| Description
|`ext::intel::experimental::info::device::matrix_checked_alignment`| `size_t`
|Returns the alignment requirement for the checked APIs pointer.
|`ext::intel::experimental::info::device::matrix_checked_coordx_multiple_of<T>`|`size_t`|Returns
a value, of which `CoordX` must be multiple of.
|`ext::intel::experimental::info::device::matrix_checked_width_multiple_of<T>`|`size_t`|Returns
a value, of which `Width` must be multiple of.
|Tells the required alignment (in bytes) of the base pointer for
`joint_matrix_load_checked` and `joint_matrix_store_checked`.
|`ext::intel::experimental::info::device::matrix_checked_rowindex_multiple_of<T>`|
`size_t`|Returns the value that `RowIndex` must be a multiple of.
|`ext::intel::experimental::info::device::matrix_checked_globalcols_multiple_of<T>`|
`size_t` | Returns the value that `GlobalCols` must be a multiple of.
|======================

==== Appendix: Restrictions Per Hardware
===== Intel XMX
The checked APIs are currently available in devices with the architecture
`architecture::intel_gpu_pvc`. The following restrictions apply to
these checked APIs:

- The base pointer must be 4-byte aligned.

- For 8-bit data types, `RowIndex` must be a multiple of 4. For 16-bit
data types, `RowIndex` must be a multiple of 2. In general, `RowIndex` must be a
multiple of 4 divided by the size of the element type (`4/sizeof(T)`).

- For 8-bit data types, `GlobalCols` must be a multiple of 4. For 16-bit
data types, `GlobalCols` must be a multiple of 2. In general, `GlobalCols` must be a
multiple of 4 divided by the size of the element type (`4/sizeof(T)`).
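
The three restrictions above can be written as a compile-time check. This is a minimal sketch with hypothetical helper names (`required_multiple` and `meets_xmx_restrictions` are not part of the extension). It hard-codes the values listed here, whereas portable code would obtain them from the device queries instead.

```c++
#include <cstddef>
#include <cstdint>

// Hypothetical helper: the `4/sizeof(T)` rule from the list above.
template <typename T>
constexpr std::size_t required_multiple() {
  static_assert(sizeof(T) <= 4, "rule is only stated for small element types");
  return 4 / sizeof(T);
}

// Hypothetical helper: true when the base address, RowIndex, and GlobalCols
// satisfy the restrictions listed above for element type T.
template <typename T>
constexpr bool meets_xmx_restrictions(std::uintptr_t base_addr,
                                      std::size_t RowIndex,
                                      std::size_t GlobalCols) {
  return base_addr % 4 == 0 &&
         RowIndex % required_multiple<T>() == 0 &&
         GlobalCols % required_multiple<T>() == 0;
}

static_assert(required_multiple<std::int8_t>() == 4, "8-bit types");
static_assert(required_multiple<std::uint16_t>() == 2, "16-bit types");
```

Under this reading, a 16-bit element type with `GlobalCols=13`, as in the illustrative diagram earlier, would not satisfy the `GlobalCols` restriction on this hardware; a width such as 14 or 16 would.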

=== New Device Information Descriptor
Besides the query we provide in
link:sycl_ext_oneapi_matrix.asciidoc[sycl_ext_oneapi_matrix],