Skip to content

Commit

Permalink
Add int_8 variants to the test
Browse files Browse the repository at this point in the history
  • Loading branch information
dkhaldi committed Feb 9, 2024
1 parent 2b5ca84 commit da872db
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//==--joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix---==//
//==--------joint_matrix_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -25,4 +25,4 @@ using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 32

#include "../joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp"
#include "../joint_matrix_rowmajorA_rowmajorB_impl.hpp"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//==--joint_matrix_bfloat16_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix---==//
//==-------joint_matrix_rowmajorA_rowmajorB.cpp - DPC++ joint_matrix-------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -20,4 +20,4 @@
using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#include "joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp"
#include "joint_matrix_rowmajorA_rowmajorB_impl.hpp"
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
//==joint_matrix_bfloat16_rowmajorA_rowmajorB_impl.hpp - DPC++ joint_matrix-==//
//==-----joint_matrix_rowmajorA_rowmajorB_impl.hpp - DPC++ joint_matrix----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

template <size_t TM, size_t TN, size_t TK, class kernel_name, typename T1,
typename T2, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
big_matrix<T2, K, N> &B) {
template <size_t TM, size_t TN, size_t TK, class kernel_name, typename TA,
typename TB, typename TC, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<TC, M, N> &C, big_matrix<TA, M, K> &A,
big_matrix<TB, K, N> &B) {
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
buffer<TA, 2> bufA((TA *)A.get_data(), range<2>(M, K));
buffer<TB, 2> bufB((TB *)B.get_data(), range<2>(K, N));
buffer<TC, 2> bufC((TC *)C.get_data(), range<2>(M, N));

queue q;
size_t sg_size = get_sg_size<kernel_name>(q);
q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
sycl::accessor accC{bufC, cgh, sycl::read_write};
sycl::accessor accA{bufA, cgh, sycl::read_only};
sycl::accessor accB{bufB, cgh, sycl::read_only};

cgh.parallel_for<kernel_name>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
Expand All @@ -39,11 +39,9 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
sub_a;
joint_matrix<sub_group, bfloat16, use::b, TK, TN, layout::row_major>
sub_b;
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
joint_matrix<sub_group, TA, use::a, TM, TK, layout::row_major> sub_a;
joint_matrix<sub_group, TB, use::b, TK, TN, layout::row_major> sub_b;
joint_matrix<sub_group, TC, use::accumulator, TM, TN> sub_c;

joint_matrix_load(
sg, sub_c,
Expand Down Expand Up @@ -72,34 +70,34 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
}).wait();
}

template <size_t TN, class kernel_name> int gemm_row_major() {
template <size_t TN, size_t TK, class kernel_name, typename TA, typename TB,
typename TC>
int gemm_row_major() {
static constexpr size_t TM = 8;
static constexpr size_t TK = 16;

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K][MATRIX_N];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];
TA A[MATRIX_M][MATRIX_K];
TB B[MATRIX_K][MATRIX_N];
TC C[MATRIX_M][MATRIX_N];
TC D[MATRIX_M][MATRIX_N];

matrix_fill(MATRIX_M, MATRIX_K, (bfloat16 *)A,
[](int i, int j) { return 1.0f * (i + j); });
matrix_fill(MATRIX_K, MATRIX_N, (bfloat16 *)B,
[](int i, int j) { return 2.0f * i + 3.0f * j; });
matrix_fill(MATRIX_M, MATRIX_N, (float *)C, 1.0f);
matrix_fill(MATRIX_M, MATRIX_N, (float *)D, 1.0f);
matrix_fill(MATRIX_M, MATRIX_K, (TA *)A,
[](int i, int j) { return 1 * (i + j); });
matrix_fill(MATRIX_K, MATRIX_N, (TB *)B,
[](int i, int j) { return 2 * i + 3 * j; });
matrix_fill(MATRIX_M, MATRIX_N, (TC *)C, (TC)1);
matrix_fill(MATRIX_M, MATRIX_N, (TC *)D, (TC)1);

big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
big_matrix<bfloat16, MATRIX_K, MATRIX_N> MB((bfloat16 *)&B);
big_matrix<TC, MATRIX_M, MATRIX_N> MC((TC *)&C);
big_matrix<TC, MATRIX_M, MATRIX_N> MD((TC *)&D);
big_matrix<TA, MATRIX_M, MATRIX_K> MA((TA *)&A);
big_matrix<TB, MATRIX_K, MATRIX_N> MB((TB *)&B);
matrix_multiply<TM, TN, TK, kernel_name>(MC, MA, MB);
matrix_multiply_ref((bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M,
MATRIX_N, MATRIX_K);
matrix_multiply_ref((TA *)A, (TB *)B, (TC *)D, MATRIX_M, MATRIX_N, MATRIX_K);

bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
bool res = matrix_compare(MATRIX_M, MATRIX_N, (TC *)C, (TC *)D);
std::cout << (res ? "passed" : "failed") << std::endl;
return !res;
}
Expand All @@ -113,11 +111,33 @@ int main() {
for (unsigned int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == matrix_type::bf16) {
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
gemm_row_major<16, class gemm_16>();
gemm_row_major<16, 16, class gemm_bfloat16_16, bfloat16, bfloat16,
float>();
break;
}
if (combinations[i].nsize == 8) {
gemm_row_major<8, class gemm_8>();
gemm_row_major<8, 16, class gemm_bfloat16_8, bfloat16, bfloat16,
float>();
break;
}
}
if (combinations[i].atype == matrix_type::sint8) {
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
gemm_row_major<16, 32, class gemm_int8_16, int8_t, int8_t, int32_t>();
gemm_row_major<16, 32, class gemm_us_int8_16, uint8_t, int8_t,
int32_t>();
gemm_row_major<16, 32, class gemm_su_int8_16, int8_t, uint8_t,
int32_t>();
gemm_row_major<16, 32, class gemm_uu_int8_16, uint8_t, uint8_t,
int32_t>();
break;
}
if (combinations[i].nsize == 8) {
gemm_row_major<8, 32, class gemm_int8_8, int8_t, int8_t, int32_t>();
gemm_row_major<8, 32, class gemm_us_int8_8, uint8_t, int8_t, int32_t>();
gemm_row_major<8, 32, class gemm_su_int8_8, int8_t, uint8_t, int32_t>();
gemm_row_major<8, 32, class gemm_uu_int8_8, uint8_t, uint8_t,
int32_t>();
break;
}
}
Expand Down

0 comments on commit da872db

Please sign in to comment.