[SYCL][CUDA] Improved joint_matrix layout test coverage. (#12483)
Improved joint_matrix layout test coverage.

The test framework used by the CUDA backend tests has been updated to
support all possible `joint_matrix` GEMM API combinations, including all
matrix layouts. The GEMM header is backend-agnostic, so all backends
could use this test framework in the future.

This test framework can also serve as an example of how to handle
different layout combinations when computing a general GEMM.
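As a rough sketch of the layout handling (mirroring the updated
`matrix_ref_mn` helper; the name `reference_mn` below is illustrative, and
`Big_M`/`Big_K`/`Big_N` denote the full "Big matrix" dimensions), the
reference computation selects each operand's linear index from its layout:

    #include <sycl/sycl.hpp>
    using sycl::ext::oneapi::experimental::matrix::layout;

    // Reference GEMM element: D[m][n] = C[m][n] + sum_k A[m][k] * B[k][n],
    // with the linear index of A and B chosen by each operand's layout.
    template <size_t Big_M, size_t Big_K, size_t Big_N, layout layout_A,
              layout layout_B, typename Tm, typename Tc>
    Tc reference_mn(int m, int n, const Tm *A, const Tm *B, const Tc *C) {
      Tc res = C[m * Big_N + n]; // C kept row major in this sketch
      for (int k = 0; k < Big_K; k++) {
        auto index_a =
            layout_A == layout::row_major ? m * Big_K + k : m + k * Big_M;
        auto index_b =
            layout_B == layout::row_major ? k * Big_N + n : k + n * Big_K;
        res += static_cast<Tc>(A[index_a]) * static_cast<Tc>(B[index_b]);
      }
      return res;
    }

The templated `test` entry point is instantiated once per layout
combination, as the sm70/sm72/sm80 test files below show.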

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
JackAKirk committed Feb 1, 2024
1 parent e402523 commit f9e4f10
Showing 4 changed files with 118 additions and 43 deletions.
105 changes: 68 additions & 37 deletions sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp
@@ -11,7 +11,7 @@ using namespace sycl::ext::oneapi;
using namespace sycl::ext::oneapi::experimental::matrix;
constexpr float bf16_eps = 0.00390625;

// Example usage of Nvidia matrix multiply.
// Example usage of joint_matrix matrix multiply.
// Optimizations such as memory paddings for avoiding bank conflicts are not
// included in this test which aids clarity for what is going on. This example
// forms a "Big matrix" corresponding to a single "TILE" using cuda example
@@ -30,37 +30,47 @@ constexpr float bf16_eps = 0.00390625;
constexpr int N_THREADS_PER_MATRIX_OP = 32;

// number of submatrices per row of accumulator ("C", "D") matrices.
constexpr int SUB_TILES_M = 3;
constexpr int SUB_TILES_M = 2;
// number of submatrices per col of accumulator matrices.
constexpr int SUB_TILES_N = 2;
// number of submatrices per col of "A"/per row of "B", matrices.
constexpr int SUB_TILES_K = 1;
constexpr int SUB_TILES_K = 2;

template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N>
template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N,
layout layout_A, layout layout_B, layout layout_C>
class TypeHelper;

template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N>
using KernelName = class TypeHelper<Tm, Tc, Td, M, K, N>;
template <typename Tm, typename Tc, typename Td, size_t M, size_t K, size_t N,
layout layout_A, layout layout_B, layout layout_C>
using KernelName =
class TypeHelper<Tm, Tc, Td, M, K, N, layout_A, layout_B, layout_C>;

template <size_t Big_N, size_t Big_K, typename Tm, typename Tc>
template <size_t Big_N, size_t Big_K, size_t Big_M, layout layout_A,
layout layout_B, typename Tm, typename Tc>
Tc matrix_ref_mn(const int &m, const int &n, Tm *A, Tm *B, Tc *C) {
Tc res = C[m * Big_N + n];

if constexpr (std::is_same<Tm, bfloat16>::value) {
for (int k = 0; k < Big_K; k++)
res += A[m * Big_K + k] * B[k * Big_N + n];
} else {
for (int k = 0; k < Big_K; k++)
res +=
static_cast<Tc>(A[m * Big_K + k]) * static_cast<Tc>(B[k * Big_N + n]);
for (int k = 0; k < Big_K; k++) {
auto index_a =
layout_A == layout::row_major ? m * Big_K + k : m + k * Big_M;
auto index_b =
layout_B == layout::row_major ? k * Big_N + n : k + n * Big_K;

if constexpr (std::is_same<Tm, bfloat16>::value) {
res += A[index_a] * B[index_b];
} else {
res += static_cast<Tc>(A[index_a]) * static_cast<Tc>(B[index_b]);
}
}

return res;
}

template <typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
typename T3 = std::remove_const_t<Tm>>
template <
typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
layout layout_A = layout::row_major, layout layout_B = layout::row_major,
layout layout_C = layout::row_major, typename T3 = std::remove_const_t<Tm>>
void test(queue &q) {
// total number of M dimension matrix elements for the "Big matrix".
constexpr auto Big_M = Sub_Tiles_M * M;
@@ -97,7 +107,8 @@ void test(queue &q) {
accessor<bfloat16, 1, access::mode::write, target::device> accA(bufA,
cgh);

cgh.parallel_for<KernelName<Tm, Tc, class copyA, M, K, N>>(
cgh.parallel_for<KernelName<Tm, Tc, class copyA, M, K, N, layout_A,
layout_B, layout_C>>(
range<1>(Big_M * Big_K), [=](item<1> item) {
auto i = item.get_linear_id();
accA[i] = 0.1f * (i % 10);
@@ -107,7 +118,8 @@ void test(queue &q) {
accessor<bfloat16, 1, access::mode::write, target::device> accB(bufB,
cgh);

cgh.parallel_for<KernelName<Tm, Tc, class copyB, M, K, N>>(
cgh.parallel_for<KernelName<Tm, Tc, class copyB, M, K, N, layout_A,
layout_B, layout_C>>(
range<1>(Big_K * Big_N), [=](item<1> item) {
auto i = item.get_linear_id();
accB[i] = 0.1f * (i % 10);
@@ -130,41 +142,55 @@ void test(queue &q) {
range<2> GlobalRange = {Sub_Tiles_M,
Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};

cgh.parallel_for<KernelName<Tm, Tc, Td, M, K, N>>(
cgh.parallel_for<
KernelName<Tm, Tc, Td, M, K, N, layout_A, layout_B, layout_C>>(
nd_range<2>(GlobalRange, LocalRange), [=](nd_item<2> item) {
sycl::sub_group sg = item.get_sub_group();
// row id of current submatrix of BIG C matrix
const auto m = item.get_group().get_group_id()[0];
// column id of current submatrix of BIG C matrix
const auto n = item.get_group().get_group_id()[1];

joint_matrix<sycl::sub_group, T3, use::a, M, K, layout::row_major>
sub_a;
joint_matrix<sycl::sub_group, T3, use::b, K, N, layout::row_major>
sub_b;
joint_matrix<sycl::sub_group, T3, use::a, M, K, layout_A> sub_a;
joint_matrix<sycl::sub_group, T3, use::b, K, N, layout_B> sub_b;
joint_matrix<sycl::sub_group, std::remove_const_t<Tc>,
use::accumulator, M, N>
sub_c;
joint_matrix<sycl::sub_group, Td, use::accumulator, M, N> sub_d;
auto stride_C = layout_C == layout::row_major ? Big_N : Big_M;
auto load_stride_C = layout_C == layout::row_major
? (m * M) * Big_N + n * N
: (m * M) + n * N * Big_M;

joint_matrix_load(
sg, sub_c,
accC.template get_multi_ptr<access::decorated::no>() +
(m * M) * Big_N + n * N,
Big_N, layout::row_major);
load_stride_C,
stride_C, layout_C);

auto stride_A = layout_A == layout::row_major ? Big_K : Big_M;
auto stride_B = layout_B == layout::row_major ? Big_N : Big_K;

// k = row/col id of current submatrix of BIG A/B matrices
for (int k = 0; k < Sub_Tiles_K; k++) {
auto load_stride_A = layout_A == layout::row_major
? (k * K) + (m * M * Big_K)
: (k * K * Big_M) + (m * M);
auto load_stride_B = layout_B == layout::row_major
? (k * K * Big_N) + (n * N)
: (k * K) + (n * N * Big_K);

joint_matrix_load(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(k * K) + (m * M * Big_K),
Big_K);
load_stride_A,
stride_A);

joint_matrix_load(
sg, sub_b,
accB.template get_multi_ptr<access::decorated::no>() +
(k * K * Big_N) + (n * N),
Big_N);
load_stride_B,
stride_B);

// round values to correct precision if using tf32
if constexpr (std::is_same<T3, precision::tf32>::value) {
Expand All @@ -174,27 +200,32 @@ void test(queue &q) {
}

joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
joint_matrix_copy(sg, sub_d, sub_c);
}
joint_matrix_store(
sg, sub_d,
accD.template get_multi_ptr<access::decorated::no>() +
(m * M) * Big_N + n * N,
Big_N, layout::row_major);
load_stride_C,
stride_C, layout_C);
});
});
q.wait();
}

for (int m = 0; m < Big_M; m++) {
for (int n = 0; n < Big_N; n++) {
auto index_D =
layout_C == layout::row_major ? m * Big_N + n : m + n * Big_M;
if constexpr (std::is_same<std::remove_const_t<Tm>, bfloat16>::value) {
auto res_device = matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C);
assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
(D[m * Big_N + n] + res_device) <
auto res_device =
matrix_ref_mn<Big_N, Big_K, Big_M, layout_A, layout_B>(m, n, A, B,
C);
assert(fabs(2 * (D[index_D] - res_device)) / (D[index_D] + res_device) <
bf16_eps * 2);
} else {
assert(
(D[m * Big_N + n] == matrix_ref_mn<Big_N, Big_K>(m, n, A, B, C)));
assert((D[index_D] ==
matrix_ref_mn<Big_N, Big_K, Big_M, layout_A, layout_B>(m, n, A,
B, C)));
}
}
}
15 changes: 13 additions & 2 deletions sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp
@@ -80,12 +80,23 @@ int main() {
test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
32, 16, 8>(Q);

// test different layout combinations for one case

test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
32, 16, 8, layout::row_major, layout::row_major, layout::col_major>(Q);
test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
32, 16, 8, layout::row_major, layout::col_major, layout::row_major>(Q);
test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
32, 16, 8, layout::col_major, layout::row_major, layout::row_major>(Q);
test<const half, const half, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
32, 16, 8, layout::col_major, layout::col_major, layout::row_major>(Q);

// joint_matrix_apply tests

auto apply_add = [](auto &x) { x = x + 2; };
float D[MATRIX_M][MATRIX_N];
big_matrix<float, MATRIX_M, MATRIX_N> MD_f((float *)&D);

// joint_matrix_apply tests

matrix_verify_lambda<half, float, M, 16, N>(Q, MD_f, 0.0, apply_add);
}

19 changes: 17 additions & 2 deletions sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp
@@ -50,13 +50,28 @@ int main() {
test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 32, 16, 8>(Q);

// test different layout combinations for one case

test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 32, 16, 8, layout::row_major, layout::row_major,
layout::col_major>(Q);
test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 32, 16, 8, layout::col_major, layout::row_major,
layout::row_major>(Q);
test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 32, 16, 8, layout::row_major, layout::col_major,
layout::row_major>(Q);
test<const uint8_t, const int32_t, int32_t, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 32, 16, 8, layout::col_major, layout::col_major,
layout::row_major>(Q);

// joint_matrix_apply tests

auto apply_add = [](auto &x) { x = x + 2; };

int32_t D_i[MATRIX_M][MATRIX_N];
big_matrix<int32_t, MATRIX_M, MATRIX_N> MD_i((int32_t *)&D_i);

// joint_matrix_apply tests

matrix_verify_lambda<uint8_t, int32_t, M, 16, N>(Q, MD_i, 0, apply_add);
matrix_verify_lambda<int8_t, int32_t, M, 16, N>(Q, MD_i, 0, apply_add);
}
22 changes: 20 additions & 2 deletions sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp
@@ -43,9 +43,28 @@ int main() {

// A/B tf32
test<float, float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N, 16, 8, 16,
layout::row_major, layout::row_major, layout::row_major,
precision::tf32>(Q);
test<const float, const float, float, SUB_TILES_M, SUB_TILES_K, SUB_TILES_N,
16, 8, 16, precision::tf32>(Q);
16, 8, 16, layout::row_major, layout::row_major, layout::row_major,
precision::tf32>(Q);

// test different layout combinations for one case

test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 8, 16, 32, layout::row_major, layout::col_major,
layout::row_major>(Q);
test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 8, 16, 32, layout::col_major, layout::row_major,
layout::row_major>(Q);
test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 8, 16, 32, layout::col_major, layout::col_major,
layout::row_major>(Q);
test<const bfloat16, const float, float, SUB_TILES_M, SUB_TILES_K,
SUB_TILES_N, 8, 16, 32, layout::col_major, layout::col_major,
layout::col_major>(Q);

// joint_matrix_apply tests

float D[MATRIX_M][MATRIX_N];
big_matrix<float, MATRIX_M, MATRIX_N> MD_f((float *)&D);
Expand All @@ -54,7 +73,6 @@ int main() {
big_matrix<double, 8 * nWGperDim, 8 * nWGperDim> MD_d((double *)&D_d);
auto apply_add = [](auto &x) { x = x + 2; };

// joint_matrix_apply tests
matrix_verify_lambda<bfloat16, float, 16, 16, 16>(Q, MD_f, 0.0, apply_add);

matrix_verify_lambda<double, double, 8, 4, 8>(Q, MD_d, -60.0, apply_add);
