Skip to content

Commit

Permalink
Reuse joint_matmul from joint_matrix_bf16_fill_k_cache_impl.hpp in joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
YixingZhang007 committed Sep 20, 2024
1 parent a184cbc commit 11aa43c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 122 deletions.
10 changes: 6 additions & 4 deletions sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
#endif

// number of test iterations
constexpr unsigned int testIterations = 100;
extern constexpr unsigned int testIterations = 100;
// start recording time after X iterations
constexpr unsigned int recordThresh = 10;
extern constexpr unsigned int recordThresh = 10;

#ifndef MATRIX_SIZE
#define MATRIX_SIZE 256
Expand Down Expand Up @@ -46,9 +46,9 @@ template <

double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
#ifdef ARG_DIM
, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
#endif // ARG_DIM
) {
) {

size_t sgSize = get_sg_size<MatMul<TM, TN, TK>>(q);
range<2> global{rowsA / MCache1, (colsB / NCache1) * sgSize};
Expand Down Expand Up @@ -355,6 +355,7 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i
return duration.count();
}

#ifndef EXCLUDE_MAIN_TEST
template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
size_t MCache2, size_t NCache2, size_t KCache2>
Expand Down Expand Up @@ -482,3 +483,4 @@ int main() {
}
return 0;
}
#endif //EXCLUDE_MAIN_TEST
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@

// -ffp-model=precise is added to not depend on compiler defaults.

#define EXCLUDE_MAIN_TEST 1
#define ARG_DIM 1

#include "common.hpp"
#include "joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp"
#include "joint_matrix_bf16_fill_k_cache_impl.hpp"
#include "joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp"
127 changes: 10 additions & 117 deletions sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_runtime_dim_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,123 +8,8 @@

#include <random>
#include <sycl/usm.hpp>

// number of test iterations
constexpr unsigned int testIterations = 100;
// start recording time after X iterations
constexpr unsigned int recordThresh = 10;

template <size_t TM, size_t TN, size_t TK> class MatMul;

// Submits one tiled matrix-multiply kernel (C = A * B accumulated from zero)
// built on joint_matrix, and returns the elapsed host time in milliseconds.
//
// Template parameters:
//   vnniFactor       - scaling applied to B's row index and row stride (see
//                      the B load below); with vnniFactor == 1 the B indexing
//                      reduces to plain row-major addressing.
//   TOperand/TResult - element types of the A/B tiles and of the accumulator.
//   TM, TN, TK       - joint_matrix tile shape (A: TM x TK, B: TK x TN,
//                      C: TM x TN).
//   MCache1/NCache1/KCache1 - per-sub-group blocking sizes (a sub-group owns
//                      MCache1/TM x NCache1/TN accumulator tiles).
//   MCache2/NCache2/KCache2 - per-work-group blocking sizes.
//
// Parameters:
//   A, B, C - input and output pointers; cast to the global address space
//             inside the kernel (presumably USM allocations - confirm against
//             the caller).
//   q       - queue the kernel is submitted to.
//   i       - index of the current test iteration; only the last iteration
//             (i == testIterations - 1) waits on the queue.
//   rowsA, colsA, rowsB, colsB - runtime matrix dimensions.
//
// Returns the wall-clock time in milliseconds from just before the submit
// until this function returns. Because q.wait() is only called on the last
// iteration, earlier iterations' values do not include a wait here.
template <size_t vnniFactor, typename TOperand, typename TResult, size_t TM,
size_t TN, size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
size_t MCache2, size_t NCache2, size_t KCache2>
double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i,
size_t rowsA, size_t colsA, size_t rowsB, size_t colsB) {

// Sub-group size of the kernel, queried through a helper defined elsewhere.
size_t sgSize = get_sg_size<MatMul<TM, TN, TK>>(q);
// One sub-group (sgSize work-items along dimension 1) per MCache1 x NCache1
// block of C; a work-group covers an MCache2 x NCache2 block.
range<2> global{rowsA / MCache1, (colsB / NCache1) * sgSize};
range<2> cachelocal{MCache2 / MCache1, NCache2 / NCache1 * sgSize};

// Reject shapes that would need padding. These are asserts, so they only
// fire when NDEBUG is not defined.
// NOTE(review): only divisibility by the tile sizes TM/TK/TN is checked here;
// divisibility by the cache-block sizes used in the ranges above and in the
// k2 loop below is not asserted - confirm callers guarantee it.
assert(colsA == rowsB);
assert(rowsA % TM == 0);
assert(colsA % TK == 0);
assert(colsB % TN == 0);
// submit main kernel
// The timer starts before the submit, so the measurement includes submission.
std::chrono::high_resolution_clock::time_point start =
std::chrono::high_resolution_clock::now();

q.submit([&](handler &h) {
h.parallel_for<MatMul<TM, TN, TK>>( // cache layer#1
nd_range<2>{global, cachelocal},
// loop global
// loop localrange
[=](nd_item<2> it)
{
// sg::load and sg::store expect decorations to be ON
auto pA =
address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::yes>(A);
auto pB =
address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::yes>(B);
auto pC =
address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::yes>(C);
// m2/n2: work-group coordinates; m1/n1: sub-group coordinates inside the
// work-group (dimension 1 is divided by sgSize because each sub-group
// spans sgSize work-items along it).
auto m2 = it.get_group(0);
auto n2 = it.get_group(1);
auto m1 = it.get_local_id(0);
auto n1 = it.get_local_id(1) / sgSize;
auto sg = it.get_sub_group();

// Accumulator tiles owned by this sub-group.
joint_matrix<sub_group, TResult, use::accumulator, TM, TN>
tC[MCache1 / TM][NCache1 / TN];

// Zero every accumulator tile before accumulating over K.
for (unsigned int m = 0; m < MCache1 / TM; m++) {
for (unsigned int n = 0; n < NCache1 / TN; n++) {
joint_matrix_fill(sg, tC[m][n], 0);
}
}

// Outer K loop: one step per KCache2-wide slab of the shared dimension.
for (unsigned int k2 = 0; k2 < colsA / KCache2; k2++) {
joint_matrix<sub_group, TOperand, use::a, TM, TK, layout::row_major>
tA[MCache1 / TM][KCache2 / KCache1];
// The B tile layout is selected at compile time by the VNNI macro.
#ifdef VNNI
joint_matrix<sub_group, TOperand, use::b, TK, TN,
layout::ext_intel_packed>
tB[NCache1 / TN][KCache2 / KCache1];
#else // VNNI
joint_matrix<sub_group, TOperand, use::b, TK, TN,
layout::row_major>
tB[NCache1 / TN][KCache2 / KCache1];
#endif // VNNI

// Inner K loop: KCache1-sized steps within the current slab.
for (unsigned int k1 = 0; k1 < KCache2 / KCache1; k1++) {
// k is the offset along the shared dimension, in units of TK.
unsigned int k = (k2 * KCache2 + k1 * KCache1) / TK;
// Load this sub-group's A tiles; the row stride is colsA.
for (unsigned int m = 0; m < MCache1 / TM; m++) {
joint_matrix_load(
sg, tA[m][k1],
pA + (m2 * MCache2 + m1 * MCache1 + m * TM) * colsA +
k * TK,
colsA);
}
// Load this sub-group's B tiles. B is addressed with row index
// (k * TK / vnniFactor) and row stride (colsB * vnniFactor); presumably
// this is the VNNI-packed layout when vnniFactor > 1 - confirm.
for (unsigned int n = 0; n < NCache1 / TN; n++) {
joint_matrix_load(
sg, tB[n][k1],
pB + (k * TK / vnniFactor) * (colsB * vnniFactor) +
(n2 * NCache2 + n1 * NCache1 + n * TN) * vnniFactor,
colsB * vnniFactor);
} // n
// Multiply-accumulate every (m, n) tile pair for this k1 step.
for (unsigned int m = 0; m < MCache1 / TM; m++) {
for (unsigned int n = 0; n < NCache1 / TN; n++) {
joint_matrix_mad(sg, tC[m][n], tA[m][k1], tB[n][k1],
tC[m][n]);
} // n
} // m
} // k1
} // for k2

// Write the accumulated tiles back to C in row-major order with row stride
// colsB.
for (unsigned int m = 0; m < MCache1 / TM; m++) {
for (unsigned int n = 0; n < NCache1 / TN; n++) {
joint_matrix_store(
sg, tC[m][n],
pC + (m2 * MCache2 + m1 * MCache1 + m * TM) * colsB +
(n2 * NCache2 + n1 * NCache1 + n * TN),
colsB, layout::row_major);
} // n
} // m
}); // parallel_for
}); // queue.submit

// Only the final test iteration waits on the queue.
if (i == testIterations - 1)
q.wait();
// Elapsed host time since just before the submit, in milliseconds.
std::chrono::duration<double, std::milli> duration =
std::chrono::high_resolution_clock::now() - start;

return duration.count();
}

template <typename T, typename TResult, size_t vnniFactor, size_t TM, size_t TN,
size_t TK, size_t MCache1, size_t NCache1, size_t KCache1,
size_t MCache2, size_t NCache2, size_t KCache2>
Expand Down Expand Up @@ -164,7 +49,8 @@ void test(size_t matrix_size) {
double duration =
joint_matmul<vnniFactor, T, TResult, TM, TN, TK, MCache1, NCache1,
KCache1, MCache2, NCache2, KCache2>
(A, B, C, q, i, matrix_size, matrix_size, matrix_size, matrix_size);
(A, B, C, q, i,
matrix_size, matrix_size, matrix_size, matrix_size);

if (i >= recordThresh) {
totalDuration += duration;
Expand All @@ -189,7 +75,14 @@ void test(size_t matrix_size) {

int main(int argc, char *argv[]) {
size_t matrix_size;
matrix_size = std::stoul(argv[1]);

// Check for command line argument
if (argc == 2) {
matrix_size = std::stoul(argv[1]);
} else {
std::cerr << "Usage: ./program matrix_size\n";
return 1; // Error if no argument
}

queue q;
std::vector<combination> combinations =
Expand Down

0 comments on commit 11aa43c

Please sign in to comment.