diff --git a/sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp b/sycl/test-e2e/Matrix/SG32/joint_matrix_rowmajorA_rowmajorB.cpp similarity index 84% rename from sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp rename to sycl/test-e2e/Matrix/SG32/joint_matrix_rowmajorA_rowmajorB.cpp index 1f9cd11065c4c..b88399514f51c 100644 --- a/sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp +++ b/sycl/test-e2e/Matrix/SG32/joint_matrix_rowmajorA_rowmajorB.cpp @@ -1,4 +1,4 @@ -//==--joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix---==// +//==--------joint_matrix_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -25,4 +25,4 @@ using namespace sycl::ext::oneapi::experimental::matrix; #define SG_SZ 32 -#include "../joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp" +#include "../joint_matrix_rowmajorA_rowmajorB_impl.hpp" diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB.cpp similarity index 83% rename from sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp rename to sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB.cpp index dcde38e2035c2..958bd94fe0cd3 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB.cpp @@ -1,4 +1,4 @@ -//==--joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix---==// +//==-------joint_matrix_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix-------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -20,4 +20,4 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; -#include "joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp" +#include "joint_matrix_rowmajorA_rowmajorB_impl.hpp" diff --git a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp similarity index 51% rename from sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp rename to sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp index 56c24d68d2545..d43fe9f9be793 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp @@ -1,4 +1,4 @@ -//==joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp - DPC++ joint_matrix-==// +//==-----joint_matrix_rowmajorA_rowmajorB_impl.hpp - DPC++ joint_matrix----==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,22 +6,22 @@ // //===----------------------------------------------------------------------===// -template -void matrix_multiply(big_matrix &C, big_matrix &A, - big_matrix &B) { +template +void matrix_multiply(big_matrix &C, big_matrix &A, + big_matrix &B) { size_t NDRangeM = M / TM; size_t NDRangeN = N / TN; - buffer bufA(A.get_data(), range<2>(M, K)); - buffer bufB(B.get_data(), range<2>(K, N)); - buffer bufC((float *)C.get_data(), range<2>(M, N)); + buffer bufA((TA *)A.get_data(), range<2>(M, K)); + buffer bufB((TB *)B.get_data(), range<2>(K, N)); + buffer bufC((TC *)C.get_data(), range<2>(M, N)); queue q; size_t sg_size = get_sg_size(q); q.submit([&](handler &cgh) { - auto accC = bufC.get_access(cgh); - auto accA = bufA.get_access(cgh); - auto accB = bufB.get_access(cgh); + sycl::accessor accC{bufC, cgh, sycl::read_write}; + sycl::accessor accA{bufA, cgh, sycl::read_only}; + sycl::accessor accB{bufB, cgh, sycl::read_only}; cgh.parallel_for( nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}), @@ -39,11 +39,9 @@ void matrix_multiply(big_matrix &C, big_matrix &A, const auto sg_starty = global_idy - spmd_item.get_local_id(1); sub_group sg = spmd_item.get_sub_group(); - joint_matrix - sub_a; - joint_matrix - sub_b; - joint_matrix sub_c; + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; joint_matrix_load( sg, sub_c, @@ -72,34 +70,34 @@ void matrix_multiply(big_matrix &C, big_matrix &A, }).wait(); } -template int gemm_row_major() { +template +int gemm_row_major() { static constexpr size_t TM = 8; - static constexpr size_t TK = 16; static constexpr size_t MATRIX_M = TM * 2; static constexpr size_t MATRIX_N = TN * 2; static constexpr size_t MATRIX_K = TK * 2; - bfloat16 A[MATRIX_M][MATRIX_K]; - bfloat16 B[MATRIX_K][MATRIX_N]; - float C[MATRIX_M][MATRIX_N]; - float D[MATRIX_M][MATRIX_N]; + TA A[MATRIX_M][MATRIX_K]; + TB B[MATRIX_K][MATRIX_N]; + TC C[MATRIX_M][MATRIX_N]; + TC D[MATRIX_M][MATRIX_N]; - matrix_fill(MATRIX_M, MATRIX_K, (bfloat16 *)A, - [](int i, int j) { return 1.0f * (i + j); }); - matrix_fill(MATRIX_K, MATRIX_N, (bfloat16 *)B, - [](int i, int j) { return 2.0f * i + 3.0f * j; }); - matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f); - matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f); + matrix_fill(MATRIX_M, MATRIX_K, (TA *)A, + [](int i, int j) { return 1 * (i + j); }); + matrix_fill(MATRIX_K, MATRIX_N, (TB *)B, + [](int i, int j) { return 2 * i + 3 * j; }); + matrix_fill(MATRIX_M, MATRIX_N, (TC *)C, (TC)1); + matrix_fill(MATRIX_M, MATRIX_N, (TC *)D, (TC)1); - big_matrix MC((float *)&C); - big_matrix MD((float *)&D); - big_matrix MA((bfloat16 *)&A); - big_matrix MB((bfloat16 *)&B); + big_matrix MC((TC *)&C); + big_matrix MD((TC *)&D); + big_matrix MA((TA *)&A); + big_matrix MB((TB *)&B); matrix_multiply(MC, MA, MB); - matrix_multiply_ref((bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M, - MATRIX_N, MATRIX_K); + matrix_multiply_ref((TA *)A, (TB *)B, (TC *)D, MATRIX_M, MATRIX_N, MATRIX_K); - bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D); + bool res = matrix_compare(MATRIX_M, MATRIX_N, (TC *)C, (TC *)D); std::cout << (res ? "passed" : "failed") << std::endl; return !res; } @@ -113,11 +111,33 @@ int main() { for (unsigned int i = 0; i < combinations.size(); i++) { if (combinations[i].atype == matrix_type::bf16) { if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { - gemm_row_major<16, class gemm_16>(); + gemm_row_major<16, 16, class gemm_bfloat16_16, bfloat16, bfloat16, + float>(); break; } if (combinations[i].nsize == 8) { - gemm_row_major<8, class gemm_8>(); + gemm_row_major<8, 16, class gemm_bfloat16_8, bfloat16, bfloat16, + float>(); + break; + } + } + if (combinations[i].atype == matrix_type::sint8) { + if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { + gemm_row_major<16, 32, class gemm_int8_16, int8_t, int8_t, int32_t>(); + gemm_row_major<16, 32, class gemm_us_int8_16, uint8_t, int8_t, + int32_t>(); + gemm_row_major<16, 32, class gemm_su_int8_16, int8_t, uint8_t, + int32_t>(); + gemm_row_major<16, 32, class gemm_uu_int8_16, uint8_t, uint8_t, + int32_t>(); + break; + } + if (combinations[i].nsize == 8) { + gemm_row_major<8, 32, class gemm_int8_8, int8_t, int8_t, int32_t>(); + gemm_row_major<8, 32, class gemm_us_int8_8, uint8_t, int8_t, int32_t>(); + gemm_row_major<8, 32, class gemm_su_int8_8, int8_t, uint8_t, int32_t>(); + gemm_row_major<8, 32, class gemm_uu_int8_8, uint8_t, uint8_t, + int32_t>(); break; } }