Skip to content

Commit

Permalink
matrix multiply fused with apply
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Mar 21, 2024
1 parent 14dbfbc commit 64fd280
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
26 changes: 10 additions & 16 deletions sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ float make_fp32(bfloat16 x) {
return *res;
}

template <typename Ta, typename Tb, typename Tc, unsigned int VF = 1>
template <typename Ta, typename Tb, typename Tc, unsigned int VF = 1,
typename F = std::nullptr_t>
void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
bool transpose_c = false, bool colmajor_a = false,
bool colmajor_b = false) {
bool colmajor_b = false, F &&lambda = {}) {
for (unsigned int m = 0; m < M; m++) {
for (unsigned int n = 0; n < N; n++) {
for (unsigned int k = 0; k < K; k++) {
int c_ind = transpose_c ? (n * M + m) : m * N + n;
Tc acc = *(C + c_ind);

for (unsigned int k = 0; k < K; k++) {
int a_ind = colmajor_a ? (k * M + m) : m * K + k;
int b_ind = colmajor_b ? (n * K + k) : k * N + n;
int c_ind = transpose_c ? (n * M + m) : m * N + n;

Ta *va = (Ta *)(A + a_ind * VF);
Tb *vb = (Tb *)(B + b_ind * VF);
Tc acc = *(C + c_ind);

for (unsigned int i = 0; i < VF; i++) {
if constexpr (std::is_same_v<Ta, bfloat16> &&
Expand All @@ -74,9 +74,12 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
else
assert(false && "Unsupported type in matrix_multiply_ref.");
}

*(C + c_ind) = acc;
}

if constexpr (!std::is_same_v<F, std::nullptr_t>) {
lambda(*(C + c_ind));
}
}
}
}
Expand Down Expand Up @@ -150,15 +153,6 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) {
}
}

template <typename T, typename F>
void matrix_apply(unsigned int rows, unsigned int cols, T *src, F &&lambda) {
for (unsigned int i = 0; i < rows; i++) {
for (unsigned int j = 0; j < cols; j++) {
lambda(src[i * cols + j]);
}
}
}

template <typename T1, typename T2, bool exact = false>
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
for (int i = 0; i < rows; i++) {
Expand Down
7 changes: 4 additions & 3 deletions sycl/test-e2e/Matrix/element_wise_ops_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ bool test() {
big_matrix<Tc, MATRIX_M, MATRIX_N> MD((Tc *)&D);
big_matrix<Ta, MATRIX_M, MATRIX_K> MA((Ta *)&A);
big_matrix<Ta, MATRIX_K / VF, MATRIX_N * VF> MB((Ta *)&B);

matrix_multiply<TM, TN, TK, VF, kernel_name>(MC, MA, MB);
matrix_multiply_ref<Ta, Ta, Tc, VF>((Ta *)A, (Ta *)B, (Tc *)D, MATRIX_M,
MATRIX_N, MATRIX_K / VF);
matrix_apply(MATRIX_M, MATRIX_N, (Tc *)D, [](Tc &x) { x = x * 2; });

MATRIX_N, MATRIX_K / VF, false, false,
false, [](Tc &x) { x = x * 2; });
bool res = matrix_compare(MATRIX_M, MATRIX_N, (Tc *)C, (Tc *)D);

std::cout << TM << "x" << TN << "x" << TK << ": "
<< (res ? "passed" : "failed") << std::endl;
return res;
Expand Down

0 comments on commit 64fd280

Please sign in to comment.