From 11aa43c26883b62205c89e01ae9f9f9eb7f677d4 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Fri, 20 Sep 2024 08:11:10 -0700 Subject: [PATCH] reuse joint_matmul from joint_matrix_bf16_fill_k_cache_impl.hpp in joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp --- .../joint_matrix_bf16_fill_k_cache_impl.hpp | 10 +- ...t_matrix_bf16_fill_k_cache_runtime_dim.cpp | 6 +- ...rix_bf16_fill_k_cache_runtime_dim_impl.hpp | 127 ++---------------- 3 files changed, 21 insertions(+), 122 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp index 47cfab5506187..b4a902a5b2281 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp @@ -14,9 +14,9 @@ #endif // number of test iterations -constexpr unsigned int testIterations = 100; +extern constexpr unsigned int testIterations = 100; // start recording time after X iterations -constexpr unsigned int recordThresh = 10; +extern constexpr unsigned int recordThresh = 10; #ifndef MATRIX_SIZE #define MATRIX_SIZE 256 @@ -46,9 +46,9 @@ template < double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i #ifdef ARG_DIM - , size_t rowsA, size_t colsA, size_t rowsB, size_t colsB + , size_t rowsA, size_t colsA, size_t rowsB, size_t colsB #endif // ARG_DIM - ) { + ) { size_t sgSize = get_sg_size>(q); range<2> global{rowsA / MCache1, (colsB / NCache1) * sgSize}; @@ -355,6 +355,7 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i return duration.count(); } +#ifndef EXCLUDE_MAIN_TEST template @@ -482,3 +483,4 @@ int main() { } return 0; } +#endif //EXCLUDE_MAIN_TEST diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim.cpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim.cpp index 30ba7db3d227e..4c9766bc1751f 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim.cpp @@ -18,5 +18,9 @@ // -ffp-model=precise is added to not depend on compiler defaults. +#define EXCLUDE_MAIN_TEST 1 +#define ARG_DIM 1 + #include "common.hpp" -#include "joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp" \ No newline at end of file +#include "joint_matrix_bf16_fill_k_cache_impl.hpp" +#include "joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp" diff --git a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp index cbfb4de8f18aa..495185f46fbe4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp @@ -8,123 +8,8 @@ #include #include - -// number of test iterations -constexpr unsigned int testIterations = 100; -// start recording time after X iterations -constexpr unsigned int recordThresh = 10; - template class MatMul; -template -double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i, - size_t rowsA, size_t colsA, size_t rowsB, size_t colsB) { - - size_t sgSize = get_sg_size>(q); - range<2> global{rowsA / MCache1, (colsB / NCache1) * sgSize}; - range<2> cachelocal{MCache2 / MCache1, NCache2 / NCache1 * sgSize}; - - // throw error if padding needed - assert(colsA == rowsB); - assert(rowsA % TM == 0); - assert(colsA % TK == 0); - assert(colsB % TN == 0); - // submit main kernel - std::chrono::high_resolution_clock::time_point start = - std::chrono::high_resolution_clock::now(); - - q.submit([&](handler &h) { - h.parallel_for>( // cache layer#1 - nd_range<2>{global, cachelocal}, - // loop global - // loop localrange - [=](nd_item<2> it) - { - // sg::load and sg::store expect decorations to be ON - auto pA = - address_space_cast(A); - auto pB = - address_space_cast(B); - auto pC = - address_space_cast(C); - auto m2 = it.get_group(0); - auto n2 = it.get_group(1); - auto m1 = it.get_local_id(0); - auto n1 = it.get_local_id(1) / sgSize; - auto sg = it.get_sub_group(); - - joint_matrix - tC[MCache1 / TM][NCache1 / TN]; - - for (unsigned int m = 0; m < MCache1 / TM; m++) { - for (unsigned int n = 0; n < NCache1 / TN; n++) { - joint_matrix_fill(sg, tC[m][n], 0); - } - } - - for (unsigned int k2 = 0; k2 < colsA / KCache2; k2++) { - joint_matrix - tA[MCache1 / TM][KCache2 / KCache1]; -#ifdef VNNI - joint_matrix - tB[NCache1 / TN][KCache2 / KCache1]; -#else // VNNI - joint_matrix - tB[NCache1 / TN][KCache2 / KCache1]; -#endif // VNNI - - for (unsigned int k1 = 0; k1 < KCache2 / KCache1; k1++) { - unsigned int k = (k2 * KCache2 + k1 * KCache1) / TK; - for (unsigned int m = 0; m < MCache1 / TM; m++) { - joint_matrix_load( - sg, tA[m][k1], - pA + (m2 * MCache2 + m1 * MCache1 + m * TM) * colsA + - k * TK, - colsA); - } - for (unsigned int n = 0; n < NCache1 / TN; n++) { - joint_matrix_load( - sg, tB[n][k1], - pB + (k * TK / vnniFactor) * (colsB * vnniFactor) + - (n2 * NCache2 + n1 * NCache1 + n * TN) * vnniFactor, - colsB * vnniFactor); - } // n - for (unsigned int m = 0; m < MCache1 / TM; m++) { - for (unsigned int n = 0; n < NCache1 / TN; n++) { - joint_matrix_mad(sg, tC[m][n], tA[m][k1], tB[n][k1], - tC[m][n]); - } // n - } // m - } // k1 - } // for k2 - - for (unsigned int m = 0; m < MCache1 / TM; m++) { - for (unsigned int n = 0; n < NCache1 / TN; n++) { - joint_matrix_store( - sg, tC[m][n], - pC + (m2 * MCache2 + m1 * MCache1 + m * TM) * colsB + - (n2 * NCache2 + n1 * NCache1 + n * TN), - colsB, layout::row_major); - } // n - } // m - }); // parallel_for - }); // queue.submit - - if (i == testIterations - 1) - q.wait(); - std::chrono::duration duration = - std::chrono::high_resolution_clock::now() - start; - - return duration.count(); -} - template @@ -164,7 +49,8 @@ void test(size_t matrix_size) { double duration = joint_matmul - (A, B, C, q, i, matrix_size, matrix_size, matrix_size, matrix_size); + (A, B, C, q, i, + matrix_size, matrix_size, matrix_size, matrix_size); if (i >= recordThresh) { totalDuration += duration; @@ -189,7 +75,14 @@ void test(size_t matrix_size) { int main(int argc, char *argv[]) { size_t matrix_size; - matrix_size = std::stoul(argv[1]); + + // Check for command line argument + if (argc == 2) { + matrix_size = std::stoul(argv[1]); + } else { + std::cerr << "Usage: ./program matrix_size\n"; + return 1; // Error if no argument + } queue q; std::vector combinations =