Skip to content

Commit

Permalink
[SYCL] [Joint Matrix] Modularized tests (#13785)
Browse files Browse the repository at this point in the history
Some tests had a lot of duplicate code. In many cases functions could be
collapsed to a general `matrix_verify_op()` function that takes in a
lambda, instead of having a unique function for each operation.

Although it may sees that the implementations of `matrix_verify_op()` is
itself is repeated between multiple files that is by design. The
implementations differ by the `use` parameter, which dimensions are used
to traverse the matrix, and packing behaviour of tiles. Although this
too can be abstracted out to template parameters or function arguments,
this will grow already complicated interface to unreasonable sizes.
  • Loading branch information
artemrad committed May 15, 2024
1 parent c630064 commit c173fbf
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 693 deletions.
231 changes: 51 additions & 180 deletions sycl/test-e2e/Matrix/element_wise_all_ops_half_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,221 +9,92 @@
#define TM 8
#define TK 16

template <typename T, size_t M, size_t N>
template <typename T, size_t M, size_t N, typename R>
void assert_ops_ref(host_accessor<T, 2, access::mode::read> C,
const float ref) {
for (size_t i = 0; i < M; i++)
for (size_t j = 0; j < N; j++) {
auto diff = C[i][j] - ref;
assert(std::fabs(static_cast<float>(diff)) <
std::numeric_limits<float>::epsilon());
assert(std::fabs(static_cast<R>(diff)) <
std::numeric_limits<R>::epsilon());
}
}
template <typename T, size_t M, size_t N>
void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
const float ref) {
buffer<half, 2> bufA(A.get_data(), range<2>(M, N));

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class add_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;

joint_matrix_fill(sg, sub_a, 5);

joint_matrix_apply(sg, sub_a,
[=](T &x) { x = x + static_cast<half>(2); });
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
N);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
}

template <typename T, size_t M, size_t N>
void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
const float ref) {
buffer<half, 2> bufA(A.get_data(), range<2>(M, N));

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class sub_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;

joint_matrix_fill(sg, sub_a, 5);

joint_matrix_apply(sg, sub_a,
[=](T &x) { x = x - static_cast<half>(2); });
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
N);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
}

template <typename T, size_t M, size_t N>
void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
const float ref) {
buffer<half, 2> bufA(A.get_data(), range<2>(M, N));

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class mul_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;

joint_matrix_fill(sg, sub_a, 5);

joint_matrix_apply(sg, sub_a,
[=](T &x) { x = x * static_cast<half>(3.0); });
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
N);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
}

template <typename T, size_t M, size_t N>
void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
const float ref) {
template <typename T, size_t M, size_t N, size_t TileM, size_t TileN,
size_t TileK, class kernel_name, typename R, typename OP>
void matrix_verify_op(big_matrix<T, M, N> &A, const R ref, OP op) {
buffer<half, 2> bufA(A.get_data(), range<2>(M, N));

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class div_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;

joint_matrix_fill(sg, sub_a, 4);

joint_matrix_apply(sg, sub_a,
[=](T &x) { x = x / static_cast<half>(2.0); });
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
N);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
}

template <typename T, size_t M, size_t N>
void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
const float ref) {
buffer<half, 2> bufA(A.get_data(), range<2>(M, N));
queue q;
size_t sg_size = get_sg_size<kernel_name>(q);
nd_range<2> r({M / TileM, N / TileN * sg_size}, {1, 1 * sg_size});

q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class logic_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
cgh.parallel_for<kernel_name>(
r, [=](nd_item<2> spmd_item)
#ifdef SG_SZ
[[intel::reqd_sub_group_size(SG_SZ)]]
#endif
{
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
joint_matrix<sub_group, T, use::a, TileM, TileK, layout::row_major>
sub_a;

joint_matrix_fill(sg, sub_a, 5);

joint_matrix_apply(sg, sub_a, [](T &x) {
if (x) {
if (x > static_cast<half>(2.0) || x >= static_cast<half>(2.0) ||
x < static_cast<half>(2.0) || x <= static_cast<half>(2.0)) {
T val =
(x != static_cast<half>(2.0)) ? x : static_cast<half>(2.0);
val--;
val++;
if (x == static_cast<half>(2.0)) {
val -= 2;
val *= 3;
val /= 2;
} else {
val += 2;
}
x = val;
}
}
});
joint_matrix_apply(sg, sub_a, op);
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a,
accA.template get_multi_ptr<access::decorated::no>() +
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
(sg_startx * TileM) * N + sg_starty / sg_size * TileN,
N);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
}

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
half A[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];

void matrix_ops_ref(float *D, int M, int N) {
for (int m = 0; m < M; m++)
for (int n = 0; n < N; n++) {
*(D + m * N + n) = 0;
*(D + m * N + n) *= 2;
}
assert_ops_ref<T, M, N, R>(bufA.get_host_access(read_only), ref);
}

int main() {

big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
half A[MATRIX_M][MATRIX_N];
big_matrix<half, MATRIX_M, MATRIX_N> MA((half *)&A);

size_t NDRangeM = MATRIX_M / TM;
size_t NDRangeN = MATRIX_N / TN;
queue q;
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});

matrix_verify_add<half, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
matrix_verify_sub<half, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
matrix_verify_mul<half, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
matrix_verify_div<half, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
matrix_verify_logic<half, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
matrix_verify_op<half, MATRIX_M, MATRIX_N, TM, TN, TK, class add, float>(
MA, 7.0, [=](auto &x) { x = x + static_cast<half>(2); });
matrix_verify_op<half, MATRIX_M, MATRIX_N, TM, TN, TK, class sub, float>(
MA, 3.0, [=](auto &x) { x = x - static_cast<half>(2); });
matrix_verify_op<half, MATRIX_M, MATRIX_N, TM, TN, TK, class mult, float>(
MA, 15.0, [=](auto &x) { x = x * static_cast<half>(3.0); });
matrix_verify_op<half, MATRIX_M, MATRIX_N, TM, TN, TK, class div, float>(
MA, 2.5, [=](auto &x) { x = x / static_cast<half>(2.0); });
matrix_verify_op<half, MATRIX_M, MATRIX_N, TM, TN, TK, class logic, float>(
MA, 7.0, [=](auto &x) {
if (x) {
if (x > static_cast<half>(2.0) || x >= static_cast<half>(2.0) ||
x < static_cast<half>(2.0) || x <= static_cast<half>(2.0)) {
half val =
(x != static_cast<half>(2.0)) ? x : static_cast<half>(2.0);
val--;
val++;
if (x == static_cast<half>(2.0)) {
val -= 2;
val *= 3;
val /= 2;
} else {
val += 2;
}
x = val;
}
}
});

return 0;
}
Loading

0 comments on commit c173fbf

Please sign in to comment.