Skip to content

Commit

Permalink
[SYCL][Joint Matrix] Support 1x64x16 bf16 combination (#13391)
Browse files Browse the repository at this point in the history
- add support in device_info
- add support in tests
  • Loading branch information
YuriPlyakhin authored Apr 24, 2024
1 parent 1f72a47 commit ed0619b
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 123 deletions.
2 changes: 2 additions & 0 deletions sycl/source/detail/device_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,8 @@ struct get_device_info_impl<
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
Expand Down
26 changes: 0 additions & 26 deletions sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_32x64x16.cpp

This file was deleted.

26 changes: 0 additions & 26 deletions sycl/test-e2e/Matrix/SG32/joint_matrix_bfloat16_32x64x32.cpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//==----- joint_matrix_bfloat16_16x16x16.cpp - DPC++ joint_matrix----------==//
//==----- joint_matrix_bfloat16_packedB.cpp - DPC++ joint_matrix----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -15,12 +15,5 @@

#include "../common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 32
constexpr size_t TM = 16;
constexpr size_t TN = 16;
constexpr size_t TK = 16;

#include "../joint_matrix_bfloat16_packedB_impl.hpp"
26 changes: 0 additions & 26 deletions sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64x16.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions sycl/test-e2e/Matrix/joint_matrix_bfloat16_32x64x32.cpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//==----- joint_matrix_bfloat16_16x16x16.cpp - DPC++ joint_matrix----------==//
//==----- joint_matrix_bfloat16_packedB.cpp - DPC++ joint_matrix----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -13,12 +13,5 @@

#include "common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 16
constexpr size_t TM = 16;
constexpr size_t TN = 16;
constexpr size_t TK = 16;

#include "joint_matrix_bfloat16_packedB_impl.hpp"
43 changes: 39 additions & 4 deletions sycl/test-e2e/Matrix/joint_matrix_bfloat16_packedB_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
template <typename T1, typename T2, size_t M, size_t N, size_t K>
//=----- joint_matrix_bfloat16_packedB_impl.hpp - DPC++ joint_matrix -------=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//=-------------------------------------------------------------------------=//

template <size_t TM, size_t TN, size_t TK, class kernel_name, typename T1,
typename T2, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
big_matrix<T2, K / 2, N * 2> &B) {
size_t NDRangeM = M / TM;
Expand All @@ -13,7 +22,7 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class imatrix>(
cgh.parallel_for<kernel_name>(
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]

Expand Down Expand Up @@ -63,7 +72,7 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
}).wait();
}

int main() {
template <size_t TM, size_t TN, size_t TK, class kernel_name> int test() {
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
Expand All @@ -83,12 +92,38 @@ int main() {
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
matrix_multiply(MC, MA, MB);
matrix_multiply<TM, TN, TK, kernel_name>(MC, MA, MB);
matrix_multiply_ref<bfloat16, bfloat16, float, 2>(
(bfloat16 *)A, (bfloat16 *)B, (float *)D, MATRIX_M, MATRIX_N,
MATRIX_K / 2);

bool res = matrix_compare(MATRIX_M, MATRIX_N, (float *)C, (float *)D);
std::cout << TM << "x" << TN << "x" << TK << " ";
std::cout << (res ? "passed" : "failed") << std::endl;
return !res;
}

int main() {
queue q;
std::vector<combination> combinations =
q.get_device()
.get_info<sycl::ext::oneapi::experimental::info::device::
matrix_combinations>();

int ret = 0;
for (auto &combination : combinations) {
if (combination.nsize == 0) { // Intel AMX
ret += test<16, 16, 16, class amx16x16x16>();
break;
}

if (combination.nsize == 16) { // architecture::intel_gpu_pvc
ret += test<16, 16, 16, class pvc16x16x16>();
ret += test<32, 64, 16, class pvc32x64x16>();
ret += test<1, 64, 16, class pvc1x64x16>();
break;
}
}

return ret;
}
2 changes: 2 additions & 0 deletions sycl/test-e2e/Matrix/runtime_query_pvc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ int main() {
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
Expand Down

0 comments on commit ed0619b

Please sign in to comment.