From f9e4f10662c15fdd5d77313b9fb1ca8614976c2c Mon Sep 17 00:00:00 2001
From: JackAKirk
Date: Thu, 1 Feb 2024 12:35:25 +0000
Subject: [PATCH] [SYCL][CUDA] Improved joint_matrix layout test coverage. (#12483)

Improved joint_matrix layout test coverage.

The test framework that the CUDA backend tests use has been updated to
support all possible `joint_matrix` gemm API combinations, including all
matrix layouts. The gemm header is backend agnostic; hence all backends
could use this test framework in the future.

This test framework can also act as an example of how to handle different
layout combinations when computing a general GEMM.

Signed-off-by: JackAKirk
---
 .../Matrix/joint_matrix_gemm_cuda.hpp        | 105 ++++++++++++------
 .../Matrix/joint_matrix_tensorcores_sm70.cpp |  15 ++-
 .../Matrix/joint_matrix_tensorcores_sm72.cpp |  19 +++-
 .../Matrix/joint_matrix_tensorcores_sm80.cpp |  22 +++-
 4 files changed, 118 insertions(+), 43 deletions(-)

diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp
index fe5b110864e6b..9fd4f184692be 100644
--- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp
+++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp
@@ -11,7 +11,7 @@ using namespace sycl::ext::oneapi;
 using namespace sycl::ext::oneapi::experimental::matrix;
 constexpr float bf16_eps = 0.00390625;
 
-// Example usage of Nvidia matrix multiply.
+// Example usage of joint_matrix matrix multiply.
 // Optimizations such as memory paddings for avoiding bank conflicts are not
 // included in this test which aids clarity for what is going on. This example
 // forms a "Big matrix" corresponding to a single "TILE" using cuda example
@@ -30,37 +30,47 @@
 constexpr int N_THREADS_PER_MATRIX_OP = 32;
 
 // number of submatrices per row of accumulator ("C", "D") matrices.
-constexpr int SUB_TILES_M = 3;
+constexpr int SUB_TILES_M = 2;
 // number of submatrices per col of accumulator matrices.
 constexpr int SUB_TILES_N = 2;
 // number of submatrices per col of "A"/per row of "B", matrices.
-constexpr int SUB_TILES_K = 1;
+constexpr int SUB_TILES_K = 2;
 
-template
+template
 class TypeHelper;
 
-template
-using KernelName = class TypeHelper;
+template
+using KernelName =
+    class TypeHelper;
 
-template
+template
 Tc matrix_ref_mn(const int &m, const int &n, Tm *A, Tm *B, Tc *C) {
   Tc res = C[m * Big_N + n];
 
-  if constexpr (std::is_same<Tm, bfloat16>::value) {
-    for (int k = 0; k < Big_K; k++)
-      res += A[m * Big_K + k] * B[k * Big_N + n];
-  } else {
-    for (int k = 0; k < Big_K; k++)
-      res +=
-          static_cast<Tc>(A[m * Big_K + k]) * static_cast<Tc>(B[k * Big_N + n]);
+  for (int k = 0; k < Big_K; k++) {
+    auto index_a =
+        layout_A == layout::row_major ? m * Big_K + k : m + k * Big_M;
+    auto index_b =
+        layout_B == layout::row_major ? k * Big_N + n : k + n * Big_K;
+
+    if constexpr (std::is_same<Tm, bfloat16>::value) {
+      res += A[index_a] * B[index_b];
+    } else {
+      res += static_cast<Tc>(A[index_a]) * static_cast<Tc>(B[index_b]);
+    }
   }
   return res;
 }
 
-template >
+template <
+    typename Tm, typename Tc, typename Td, size_t Sub_Tiles_M,
+    size_t Sub_Tiles_K, size_t Sub_Tiles_N, size_t M, size_t K, size_t N,
+    layout layout_A = layout::row_major, layout layout_B = layout::row_major,
+    layout layout_C = layout::row_major, typename T3 = std::remove_const_t<Tm>>
 void test(queue &q) {
   // total number of M dimension matrix elements for the "Big matrix".
   constexpr auto Big_M = Sub_Tiles_M * M;
@@ -97,7 +107,8 @@ void test(queue &q) {
 
       accessor accA(bufA, cgh);
 
-      cgh.parallel_for>(
+      cgh.parallel_for>(
           range<1>(Big_M * Big_K), [=](item<1> item) {
             auto i = item.get_linear_id();
             accA[i] = 0.1f * (i % 10);
@@ -107,7 +118,8 @@ void test(queue &q) {
 
      accessor accB(bufB, cgh);
 
-      cgh.parallel_for>(
+      cgh.parallel_for>(
           range<1>(Big_K * Big_N), [=](item<1> item) {
             auto i = item.get_linear_id();
             accB[i] = 0.1f * (i % 10);
@@ -130,7 +142,8 @@ void test(queue &q) {
     range<2> GlobalRange = {Sub_Tiles_M,
                             Sub_Tiles_N * N_THREADS_PER_MATRIX_OP};
 
-    cgh.parallel_for>(
+    cgh.parallel_for<
+        KernelName>(
         nd_range<2>(GlobalRange, LocalRange), [=](nd_item<2> item) {
           sycl::sub_group sg = item.get_sub_group();
           // row id of current submatrix of BIG C matrix
@@ -138,33 +151,46 @@ void test(queue &q) {
           // column id of current submatrix of BIG C matrix
          const auto n = item.get_group().get_group_id()[1];
 
-          joint_matrix
-              sub_a;
-          joint_matrix
-              sub_b;
+          joint_matrix sub_a;
+          joint_matrix sub_b;
           joint_matrix, use::accumulator, M, N> sub_c;
           joint_matrix sub_d;
 
+          auto stride_C = layout_C == layout::row_major ? Big_N : Big_M;
+          auto load_stride_C = layout_C == layout::row_major
+                                   ? (m * M) * Big_N + n * N
+                                   : (m * M) + n * N * Big_M;
           joint_matrix_load(
               sg, sub_c,
               accC.template get_multi_ptr() +
-                  (m * M) * Big_N + n * N,
-              Big_N, layout::row_major);
+                  load_stride_C,
+              stride_C, layout_C);
+
+          auto stride_A = layout_A == layout::row_major ? Big_K : Big_M;
+          auto stride_B = layout_B == layout::row_major ? Big_N : Big_K;
+
           // k = row/col id of current submatrix of BIG A/B matrices
           for (int k = 0; k < Sub_Tiles_K; k++) {
+            auto load_stride_A = layout_A == layout::row_major
+                                     ? (k * K) + (m * M * Big_K)
+                                     : (k * K * Big_M) + (m * M);
+            auto load_stride_B = layout_B == layout::row_major
+                                     ? (k * K * Big_N) + (n * N)
+                                     : (k * K) + (n * N * Big_K);
+
             joint_matrix_load(
                 sg, sub_a,
                 accA.template get_multi_ptr() +
-                    (k * K) + (m * M * Big_K),
-                Big_K);
+                    load_stride_A,
+                stride_A);
             joint_matrix_load(
                 sg, sub_b,
                 accB.template get_multi_ptr() +
-                    (k * K * Big_N) + (n * N),
-                Big_N);
 
             // round values to correct precision if using tf32
             if constexpr (std::is_same<T3, precision::tf32>::value) {
@@ -174,12 +200,13 @@ void test(queue &q) {
             }
 
             joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);
+            joint_matrix_copy(sg, sub_d, sub_c);
           }
           joint_matrix_store(
               sg, sub_d,
               accD.template get_multi_ptr() +
-                  (m * M) * Big_N + n * N,
-              Big_N, layout::row_major);
+                  load_stride_C,
+              stride_C, layout_C);
         });
   });
   q.wait();
@@ -187,14 +214,18 @@ void test(queue &q) {
   for (int m = 0; m < Big_M; m++) {
     for (int n = 0; n < Big_N; n++) {
+      auto index_D =
+          layout_C == layout::row_major ? m * Big_N + n : m + n * Big_M;
       if constexpr (std::is_same<std::remove_const_t<Tm>, bfloat16>::value) {
-        auto res_device = matrix_ref_mn(m, n, A, B, C);
-        assert(fabs(2 * (D[m * Big_N + n] - res_device)) /
-                   (D[m * Big_N + n] + res_device) <
+        auto res_device =
+            matrix_ref_mn(m, n, A, B,
+                          C);
+        assert(fabs(2 * (D[index_D] - res_device)) / (D[index_D] + res_device) <
               bf16_eps * 2);
       } else {
-        assert(
-            (D[m * Big_N + n] == matrix_ref_mn(m, n, A, B, C)));
+        assert((D[index_D] ==
+                matrix_ref_mn(m, n, A,
+                              B, C)));
       }
     }
   }
diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp
index f28372b6277dc..a558600ad390c 100644
--- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp
+++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp
@@ -80,12 +80,23 @@ int main() {
   test(Q);
 
+  // test different layout combinations for one case
+
+  test(Q);
+  test(Q);
+  test(Q);
+  test(Q);
+
+  // joint_matrix_apply tests
+
+  auto apply_add = [](auto &x) { x = x + 2; };
   float D[MATRIX_M][MATRIX_N];
   big_matrix MD_f((float *)&D);
 
-  // joint_matrix_apply tests
-
   matrix_verify_lambda(Q, MD_f, 0.0, apply_add);
 }
diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp
index cea15392408cc..1dea8c879b5eb 100644
--- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp
+++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp
@@ -50,13 +50,28 @@ int main() {
   test(Q);
 
+  // test different layout combinations for one case
+
+  test(Q);
+  test(Q);
+  test(Q);
+  test(Q);
+
+  // joint_matrix_apply tests
+
+  auto apply_add = [](auto &x) { x = x + 2; };
   int32_t D_i[MATRIX_M][MATRIX_N];
   big_matrix MD_i((int32_t *)&D_i);
 
-  // joint_matrix_apply tests
-
   matrix_verify_lambda(Q, MD_i, 0, apply_add);
   matrix_verify_lambda(Q, MD_i, 0, apply_add);
 }
diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp
index 2a0731d9b988e..ca823161b6197 100644
--- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp
+++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp
@@ -43,9 +43,28 @@ int main() {
   // A/B tf32
   test(Q);
   test(Q);
+       16, 8, 16, layout::row_major, layout::row_major, layout::row_major,
+       precision::tf32>(Q);
+
+  // test different layout combinations for one case
+
+  test(Q);
+  test(Q);
+  test(Q);
+  test(Q);
+
+  // joint_matrix_apply tests
 
   float D[MATRIX_M][MATRIX_N];
   big_matrix MD_f((float *)&D);
@@ -54,7 +73,6 @@ int main() {
   big_matrix MD_d((double *)&D_d);
 
   auto apply_add = [](auto &x) { x = x + 2; };
 
-  // joint_matrix_apply tests
   matrix_verify_lambda(Q, MD_f, 0.0, apply_add);
   matrix_verify_lambda(Q, MD_d, -60.0, apply_add);
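
Appendix (not part of the patch): the layout handling above reduces to one indexing
rule per operand, row_major(i, j) -> i * cols + j and col_major(i, j) -> i + j * rows,
which is what matrix_ref_mn and the load_stride_*/stride_* values implement. The
standalone C++17 sketch below mirrors that rule on the host; the helper names
(index_of, naive_gemm) are illustrative only and do not exist in the test framework.

// Standalone illustration (plain C++17, no SYCL): compile with e.g.
//   g++ -std=c++17 layout_sketch.cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

enum class layout { row_major, col_major };

// Linear index of element (row, col) of a rows x cols matrix.
std::size_t index_of(layout l, std::size_t row, std::size_t col,
                     std::size_t rows, std::size_t cols) {
  return l == layout::row_major ? row * cols + col : row + col * rows;
}

// D = A * B + C with an independent layout per operand, mirroring matrix_ref_mn.
void naive_gemm(layout la, layout lb, layout lc, std::size_t M, std::size_t K,
                std::size_t N, const std::vector<float> &A,
                const std::vector<float> &B, const std::vector<float> &C,
                std::vector<float> &D) {
  for (std::size_t m = 0; m < M; m++)
    for (std::size_t n = 0; n < N; n++) {
      float acc = C[index_of(lc, m, n, M, N)];
      for (std::size_t k = 0; k < K; k++)
        acc += A[index_of(la, m, k, M, K)] * B[index_of(lb, k, n, K, N)];
      D[index_of(lc, m, n, M, N)] = acc;
    }
}

int main() {
  constexpr std::size_t M = 4, K = 3, N = 2;
  std::vector<float> A(M * K), B(K * N), C(M * N, 1.0f);
  for (std::size_t i = 0; i < A.size(); i++)
    A[i] = 0.1f * (i % 10); // same init pattern as the test kernels above
  for (std::size_t i = 0; i < B.size(); i++)
    B[i] = 0.1f * (i % 10);

  // Repack A into column-major storage; B and C stay row-major.
  std::vector<float> A_col(M * K), D_rr(M * N), D_cr(M * N);
  for (std::size_t m = 0; m < M; m++)
    for (std::size_t k = 0; k < K; k++)
      A_col[index_of(layout::col_major, m, k, M, K)] =
          A[index_of(layout::row_major, m, k, M, K)];

  naive_gemm(layout::row_major, layout::row_major, layout::row_major, M, K, N,
             A, B, C, D_rr);
  naive_gemm(layout::col_major, layout::row_major, layout::row_major, M, K, N,
             A_col, B, C, D_cr);

  // Only the index mapping changed, so the results must match element for element.
  for (std::size_t i = 0; i < D_rr.size(); i++)
    assert(D_rr[i] == D_cr[i]);
  std::cout << "row-major and col-major A give identical results\n";
  return 0;
}

Because only the index mapping changes between the two calls, D_rr and D_cr are
produced by the same sequence of floating-point operations and compare equal; this is
the same property that the new layout_A/layout_B/layout_C test combinations exercise
on the device.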