From ea0d4f88c86c1dab256838221f422bd5536f6d03 Mon Sep 17 00:00:00 2001 From: Yury Plyakhin Date: Wed, 31 Jan 2024 16:58:18 -0800 Subject: [PATCH 1/5] Added big combinations to test --- .../test-e2e/Matrix/SG32/element_wise_ops.cpp | 13 +- .../test-e2e/Matrix/XMX8/element_wise_ops.cpp | 22 --- sycl/test-e2e/Matrix/common.hpp | 14 +- sycl/test-e2e/Matrix/element_wise_all_ops.cpp | 1 - .../Matrix/element_wise_all_ops_impl.hpp | 83 +++++--- sycl/test-e2e/Matrix/element_wise_ops.cpp | 12 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 183 ++++++++---------- 7 files changed, 160 insertions(+), 168 deletions(-) delete mode 100644 sycl/test-e2e/Matrix/XMX8/element_wise_ops.cpp diff --git a/sycl/test-e2e/Matrix/SG32/element_wise_ops.cpp b/sycl/test-e2e/Matrix/SG32/element_wise_ops.cpp index b0522ed6656ff..1a09518e65ffb 100644 --- a/sycl/test-e2e/Matrix/SG32/element_wise_ops.cpp +++ b/sycl/test-e2e/Matrix/SG32/element_wise_ops.cpp @@ -5,19 +5,16 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// REQUIRES: matrix +// REQUIRES: aspect-ext_intel_matrix // REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943 +// SG size = 32 is not currently supported for SYCL Joint Matrix by IGC on DG2 +// UNSUPPORTED: gpu-intel-dg2 // RUN: %{build} -o %t.out // RUN: %{run} %t.out -#include -#include +#include "../common.hpp" -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -constexpr size_t SG_SZ = 32; -constexpr size_t TN = 16; +#define SG_SZ 32 #include "../element_wise_ops_impl.hpp" diff --git a/sycl/test-e2e/Matrix/XMX8/element_wise_ops.cpp b/sycl/test-e2e/Matrix/XMX8/element_wise_ops.cpp deleted file mode 100644 index 8fa0a2bf5094a..0000000000000 --- a/sycl/test-e2e/Matrix/XMX8/element_wise_ops.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//==----------- element_wise_ops.cpp - DPC++ joint_matrix------------- ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// REQUIRES: matrix-xmx8 - -// RUN: %{build} -o %t.out -// RUN: %{run} %t.out - -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 8 -constexpr size_t TN = 8; - -#include "../element_wise_ops_impl.hpp" diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index 1e27fe6d7d989..d65a2a5c11d33 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -151,6 +151,15 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) { } } +template +void matrix_apply(unsigned int rows, unsigned int cols, T *src, F &&lambda) { + for (unsigned int i = 0; i < rows; i++) { + for (unsigned int j = 0; j < cols; j++) { + lambda(src[i * cols + j]); + } + } +} + template bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { for (int i = 0; i < rows; i++) { @@ -170,8 +179,9 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { } } else if constexpr (exact || std::is_same_v) { if (src[i * cols + j] != ref[i * cols + j]) { - std::cout << "Incorrect result in matrix." << "i: " << i - << ", j: " << j << ", Ref: " << ref[i * cols + j] + std::cout << "Incorrect result in matrix." + << "i: " << i << ", j: " << j + << ", Ref: " << ref[i * cols + j] << ", Val: " << src[i * cols + j] << "\n"; return false; } diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops.cpp b/sycl/test-e2e/Matrix/element_wise_all_ops.cpp index 7593adaf23a62..7aa1c2d3f7b3c 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops.cpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops.cpp @@ -14,5 +14,4 @@ // RUN: %{run} %t.out #include "common.hpp" - #include "element_wise_all_ops_impl.hpp" diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index 55d1162ebd3af..3986d05eaea65 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// + template void assert_ops_ref(host_accessor mat, const float ref) { @@ -105,8 +106,11 @@ void verify_op_c(const T l, const T r, const float ref, OP op) { // Avoid same kernel name for different types template class ewops_a {}; -template -void test_ewops_a() { +template void test_ewops_a() { + std::cout << "Test A " << SROWS << "x" << SCOLS << "\n"; + + static constexpr size_t NROWS = SROWS * 2; + static constexpr size_t NCOLS = SCOLS * 2; verify_op_a>( T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; }); @@ -135,64 +139,87 @@ void test_ewops_a() { T(5.0), T(2.0), 2.0, [](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); }); } + // Avoid same kernel name for different types and numbers of columns -template class ewops_c {}; -template -void test_ewops_c() { +template class ewops_c {}; +template void test_ewops_c() { + std::cout << "Test C " << SROWS << "x" << SCOLS << "\n"; - verify_op_c>( + static constexpr size_t NROWS = SROWS * 2; + static constexpr size_t NCOLS = SCOLS * 2; + + verify_op_c>( T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; }); verify_op_c>( + ewops_c>( T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); }); - verify_op_c>( + verify_op_c>( T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; }); - verify_op_c>( + verify_op_c>( T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l > r ? T(3.0) : T(2.0); }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 2.0, [](auto l, auto r) { return l < r ? T(3.0) : T(2.0); }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); }); - verify_op_c>( + verify_op_c>( T(5.0), T(2.0), 2.0, [](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); }); } int main() { - static constexpr size_t TM = 8; - - static constexpr size_t MATRIX_M = TM * 2; - static constexpr size_t MATRIX_N = 32; - static constexpr size_t MATRIX_K = 32; queue q; std::vector combinations = q.get_device() .get_info(); + for (unsigned int i = 0; i < combinations.size(); i++) { if (combinations[i].atype == matrix_type::bf16) { - if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { - test_ewops_a(); - test_ewops_c(); - break; + + if (combinations[i].nsize == 0 || + (combinations[i].msize == 0 && combinations[i].nsize == 16)) { + test_ewops_a(); + test_ewops_c(); + } + + if (combinations[i].msize == 16 && combinations[i].nsize == 16) { + test_ewops_c(); + } + +// This combination is not currently supported for sub group size = 32 in IGC +#if (!defined(SG_SZ) || SG_SZ != 32) + if (combinations[i].msize == 32 && combinations[i].nsize == 64) { + test_ewops_c(); } +#endif + if (combinations[i].nsize == 8) { - test_ewops_a(); - test_ewops_c(); - break; + test_ewops_a(); + test_ewops_c(); } } } + return 0; } diff --git a/sycl/test-e2e/Matrix/element_wise_ops.cpp b/sycl/test-e2e/Matrix/element_wise_ops.cpp index a87ad3ab17999..855ed65900a6f 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops.cpp +++ b/sycl/test-e2e/Matrix/element_wise_ops.cpp @@ -5,18 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// REQUIRES: matrix +// REQUIRES: aspect-ext_intel_matrix // RUN: %{build} -o %t.out // RUN: %{run} %t.out -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; - -#define SG_SZ 16 -constexpr size_t TN = 16; - +#include "common.hpp" #include "element_wise_ops_impl.hpp" diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 1dd9779aa0b56..5027b009a1ae5 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -1,44 +1,36 @@ -#define TM 8 -#define TK 32 +//==----------- element_wise_ops_impl.hpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - -template -void matrix_multiply(big_matrix &C, - big_matrix &A, - big_matrix &B) { - size_t M = NUM_ROWS_C; - size_t N = NUM_COLS_C; - size_t K = NUM_COLS_A; - // B => K/4 x N*4, A => M x K, C => M, N +template +void matrix_multiply(big_matrix &C, big_matrix &A, + big_matrix &B) { // stride should be X's cols, e.g., B's stirde = N*4 - assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4); size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); - buffer bufC(C.get_data(), range<2>(M, N)); + buffer bufA(A.get_data(), range<2>(M, K)); + buffer bufB(B.get_data(), range<2>(K, N)); + buffer bufC(C.get_data(), range<2>(M, N)); queue q; + size_t sg_size = get_sg_size(q); q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); + auto accC = bufC.template get_access(cgh); + auto accA = bufA.template get_access(cgh); + auto accB = bufB.template get_access(cgh); - cgh.parallel_for( - nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}), - [accA, accB, accC, M, N, - K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { + cgh.parallel_for( + nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}), + [=](nd_item<2> spmd_item) +#ifdef SG_SZ + [[intel::reqd_sub_group_size(SG_SZ)]] +#endif + { // The submatrix API has to be accessed by all the workitems in a // subgroup these functions will be called once by the subgroup no // code divergence between the workitems @@ -48,18 +40,16 @@ void matrix_multiply(big_matrix &C, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sycl::sub_group sg = spmd_item.get_sub_group(); - joint_matrix - sub_a; + joint_matrix sub_a; // For B, we assume B has been already VNNIed. - joint_matrix + joint_matrix sub_b; - joint_matrix sub_c; + joint_matrix sub_c; joint_matrix_load( sg, sub_c, accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + (sg_startx * TM) * N + sg_starty / sg_size * TN, N, layout::row_major); for (int k = 0; k < K / TK; k += 1) { joint_matrix_load( @@ -70,79 +60,78 @@ void matrix_multiply(big_matrix &C, joint_matrix_load( sg, sub_b, accB.template get_multi_ptr() + - (k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4, - N * 4); + (k * TK / VF) * (N * VF) + sg_starty / sg_size * TN * VF, + N * VF); joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); } - joint_matrix_apply(sg, sub_c, [](int32_t &x) { x = x * 2; }); + joint_matrix_apply(sg, sub_c, [](Tc &x) { x = x * 2; }); joint_matrix_store( sg, sub_c, accC.template get_multi_ptr() + - (sg_startx * TM) * N + sg_starty / SG_SZ * TN, + (sg_startx * TM) * N + sg_starty / sg_size * TN, N, layout::row_major); }); // parallel for }).wait(); } -static constexpr size_t MATRIX_M = TM * 2; -static constexpr size_t MATRIX_N = TN * 2; -static constexpr size_t MATRIX_K = TK * 2; -int8_t A[MATRIX_M][MATRIX_K]; -int8_t B[MATRIX_K / 4][MATRIX_N * 4]; -int32_t C[MATRIX_M][MATRIX_N]; -int32_t D[MATRIX_M][MATRIX_N]; +template +bool test() { -void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M, - int N, int K) { - // tiling - for (int m = 0; m < M; m++) - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k++) { - char *va = (char *)(A_mem + m * K + k); - char *vb = (char *)(B_mem + k * N + n); - int acc = *(C_mem + m * N + n); - for (int i = 0; i < 4; i++) { - acc += (va[i] * vb[i]); - } - *(C_mem + m * N + n) = acc; - } - *(C_mem + m * N + n) *= 2; - } + static constexpr size_t MATRIX_M = TM * 2; + static constexpr size_t MATRIX_N = TN * 2; + static constexpr size_t MATRIX_K = TK * 2; + + Ta A[MATRIX_M][MATRIX_K]; + Ta B[MATRIX_K / VF][MATRIX_N * VF]; + Tc C[MATRIX_M][MATRIX_N]; + Tc D[MATRIX_M][MATRIX_N]; + + matrix_rand(MATRIX_M, MATRIX_K, (Ta *)A, (Ta)100); + matrix_rand(MATRIX_K / VF, MATRIX_N * VF, (Ta *)B, (Ta)100); + matrix_fill(MATRIX_M, MATRIX_N, (Tc *)C, (Tc)1); + matrix_fill(MATRIX_M, MATRIX_N, (Tc *)D, (Tc)1); + + big_matrix MC((Tc *)&C); + big_matrix MD((Tc *)&D); + big_matrix MA((Ta *)&A); + big_matrix MB((Ta *)&B); + matrix_multiply(MC, MA, MB); + matrix_multiply_ref((Ta *)A, (Ta *)B, (Tc *)D, MATRIX_M, + MATRIX_N, MATRIX_K / VF); + matrix_apply(MATRIX_M, MATRIX_N, (Tc *)D, [](Tc &x) { x = x * 2; }); + + bool res = matrix_compare(MATRIX_M, MATRIX_N, (Tc *)C, (Tc *)D); + std::cout << TM << "x" << TN << "x" << TK << ": " + << (res ? "passed" : "failed") << std::endl; + return res; } int main() { - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_K; j++) { - A[i][j] = i + 2 * j; - } - } - for (int i = 0; i < MATRIX_K / 4; i++) { - for (int j = 0; j < MATRIX_N * 4; j++) { - B[i][j] = i + j; - } - } - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - C[i][j] = 1; - D[i][j] = 1; - } - } - - big_matrix MC((int32_t *)&C); - big_matrix MD((int32_t *)&D); - big_matrix MA((int8_t *)&A); - big_matrix MB((int8_t *)&B); - matrix_multiply(MC, MA, MB); - matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M, - MATRIX_N, MATRIX_K / 4); + queue q; + std::vector combinations = + q.get_device() + .get_info(); - bool res = true; - for (int i = 0; i < MATRIX_M; i++) { - for (int j = 0; j < MATRIX_N; j++) { - if (C[i][j] != D[i][j]) - res = false; + bool passed = true; + for (unsigned int i = 0; i < combinations.size(); i++) { + if (combinations[i].atype == matrix_type::sint8 && + combinations[i].btype == matrix_type::sint8) { + if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { + passed &= test(); + } + if (combinations[i].nsize == 8) { + passed &= test(); + } + } +// This combination is not currently supported for sub group size = 32 in IGC +#if (!defined(SG_SZ) || SG_SZ != 32) + if (combinations[i].atype == matrix_type::bf16 && + combinations[i].msize == 32 && combinations[i].nsize == 64) { + passed &= test(); } +#endif } - std::cout << (res ? "passed" : "failed") << std::endl; - return !res; + return !passed; } From 14dbfbc3efa50e664e12ed6d4ff753a21e9da1b9 Mon Sep 17 00:00:00 2001 From: Yury Plyakhin Date: Thu, 21 Mar 2024 15:24:46 -0700 Subject: [PATCH 2/5] more combinations --- sycl/test-e2e/Matrix/common.hpp | 3 +- .../test-e2e/Matrix/element_wise_ops_impl.hpp | 37 ++++++++++++------- .../joint_matrix_rowmajorA_rowmajorB_impl.hpp | 3 +- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index d65a2a5c11d33..0f1848fa9e337 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -132,8 +132,7 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { src[i * cols + j] = T(fdistr(dev)); - } else if constexpr (std::is_same_v || - std::is_same_v) { + } else if constexpr (std::is_integral_v) { src[i * cols + j] = T(idistr(dev)); } else { assert(false && "Unsupported type in matrix_rand."); diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 5027b009a1ae5..8e010b897e2df 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -116,22 +116,33 @@ int main() { bool passed = true; for (unsigned int i = 0; i < combinations.size(); i++) { - if (combinations[i].atype == matrix_type::sint8 && - combinations[i].btype == matrix_type::sint8) { - if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { - passed &= test(); - } - if (combinations[i].nsize == 8) { - passed &= test(); - } + if (combinations[i].nsize == 0) { // Intel AMX + passed &= test(); + passed &= test(); + passed &= test(); + break; } -// This combination is not currently supported for sub group size = 32 in IGC + + if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc + passed &= test(); + passed &= test(); + passed &= test(); #if (!defined(SG_SZ) || SG_SZ != 32) - if (combinations[i].atype == matrix_type::bf16 && - combinations[i].msize == 32 && combinations[i].nsize == 64) { - passed &= test(); - } + // These combination are not currently supported for subgroup size = 32 in + // IGC + passed &= test(); + passed &= test(); #endif + break; + } + + if (combinations[i].nsize == 8) { // architecture::intel_gpu_dg2* + passed &= test(); + passed &= test(); + passed &= test(); + break; + } } + return !passed; } diff --git a/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp index 036f75abb9c97..89b1da275fe0d 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp @@ -122,7 +122,8 @@ int main() { float>(); } } - if (combinations[i].atype == matrix_type::sint8) { + if (combinations[i].atype == matrix_type::sint8 && combinations[i].btype == + matrix_type::sint8) { if (combinations[i].nsize == 0 || (combinations[i].nsize == 16 && combinations[i].max_msize == 8 && combinations[i].ksize == 32)) { From 64fd2808cd1d8445f015db133b60a5ed07bf7d0f Mon Sep 17 00:00:00 2001 From: Yury Plyakhin Date: Thu, 21 Mar 2024 16:31:31 -0700 Subject: [PATCH 3/5] matrix multiply fused with apply --- sycl/test-e2e/Matrix/common.hpp | 26 +++++++------------ .../test-e2e/Matrix/element_wise_ops_impl.hpp | 7 ++--- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index 0f1848fa9e337..4ce66cfc5bbfd 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -42,21 +42,21 @@ float make_fp32(bfloat16 x) { return *res; } -template +template void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K, bool transpose_c = false, bool colmajor_a = false, - bool colmajor_b = false) { + bool colmajor_b = false, F &&lambda = {}) { for (unsigned int m = 0; m < M; m++) { for (unsigned int n = 0; n < N; n++) { - for (unsigned int k = 0; k < K; k++) { + int c_ind = transpose_c ? (n * M + m) : m * N + n; + Tc acc = *(C + c_ind); + for (unsigned int k = 0; k < K; k++) { int a_ind = colmajor_a ? (k * M + m) : m * K + k; int b_ind = colmajor_b ? (n * K + k) : k * N + n; - int c_ind = transpose_c ? (n * M + m) : m * N + n; - Ta *va = (Ta *)(A + a_ind * VF); Tb *vb = (Tb *)(B + b_ind * VF); - Tc acc = *(C + c_ind); for (unsigned int i = 0; i < VF; i++) { if constexpr (std::is_same_v && @@ -74,9 +74,12 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K, else assert(false && "Unsupported type in matrix_multiply_ref."); } - *(C + c_ind) = acc; } + + if constexpr (!std::is_same_v) { + lambda(*(C + c_ind)); + } } } } @@ -150,15 +153,6 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) { } } -template -void matrix_apply(unsigned int rows, unsigned int cols, T *src, F &&lambda) { - for (unsigned int i = 0; i < rows; i++) { - for (unsigned int j = 0; j < cols; j++) { - lambda(src[i * cols + j]); - } - } -} - template bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { for (int i = 0; i < rows; i++) { diff --git a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp index 8e010b897e2df..edde026ed877e 100644 --- a/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_ops_impl.hpp @@ -96,12 +96,13 @@ bool test() { big_matrix MD((Tc *)&D); big_matrix MA((Ta *)&A); big_matrix MB((Ta *)&B); + matrix_multiply(MC, MA, MB); matrix_multiply_ref((Ta *)A, (Ta *)B, (Tc *)D, MATRIX_M, - MATRIX_N, MATRIX_K / VF); - matrix_apply(MATRIX_M, MATRIX_N, (Tc *)D, [](Tc &x) { x = x * 2; }); - + MATRIX_N, MATRIX_K / VF, false, false, + false, [](Tc &x) { x = x * 2; }); bool res = matrix_compare(MATRIX_M, MATRIX_N, (Tc *)C, (Tc *)D); + std::cout << TM << "x" << TN << "x" << TK << ": " << (res ? "passed" : "failed") << std::endl; return res; From f15d74a6460f01dd2adebbc6d7c63f0e97e3104a Mon Sep 17 00:00:00 2001 From: Yury Plyakhin Date: Thu, 21 Mar 2024 16:40:25 -0700 Subject: [PATCH 4/5] format --- .../test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp index 89b1da275fe0d..20b93ed46cc12 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp @@ -122,8 +122,8 @@ int main() { float>(); } } - if (combinations[i].atype == matrix_type::sint8 && combinations[i].btype == - matrix_type::sint8) { + if (combinations[i].atype == matrix_type::sint8 && + combinations[i].btype == matrix_type::sint8) { if (combinations[i].nsize == 0 || (combinations[i].nsize == 16 && combinations[i].max_msize == 8 && combinations[i].ksize == 32)) { From 81102bbe3e03d79a6388fcb98e05e8f883c01395 Mon Sep 17 00:00:00 2001 From: Yury Plyakhin Date: Thu, 21 Mar 2024 16:53:20 -0700 Subject: [PATCH 5/5] optimization --- sycl/test-e2e/Matrix/common.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index 4ce66cfc5bbfd..cbdc847e68c5e 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -74,12 +74,12 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K, else assert(false && "Unsupported type in matrix_multiply_ref."); } - *(C + c_ind) = acc; } if constexpr (!std::is_same_v) { - lambda(*(C + c_ind)); + lambda(acc); } + *(C + c_ind) = acc; } } }