Skip to content

Commit

Permalink
[SYCL][Joint Matrix] Test stores A and B for bfloat16 16x16x16, 32x64x16, 1x64x16 (#13572)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin authored Apr 29, 2024
1 parent 3756fd1 commit cfd0d41
Showing 1 changed file with 67 additions and 44 deletions.
111 changes: 67 additions & 44 deletions sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
}

template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
size_t SUB_COLS, class kernel_name, typename OP>
void verify_op_a(const T l, const T r, const float ref, OP op) {
T mat[NUM_ROWS][NUM_COLS];
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat((T *)&mat);
size_t SUB_COLS, use Use, layout Layout, size_t VF, class kernel_name,
typename OP>
void verify_op_ab(const T l, const T r, const float ref, OP op) {
T mat[NUM_ROWS / VF][NUM_COLS * VF];
big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> big_mat((T *)&mat);

buffer<T, 2> bufMat(big_mat.get_data(), range<2>(NUM_ROWS, NUM_COLS));
buffer<T, 2> bufMat(big_mat.get_data(),
range<2>(NUM_ROWS / VF, NUM_COLS * VF));

queue q;
size_t sg_size = get_sg_size<kernel_name>(q);
Expand All @@ -47,20 +49,19 @@ void verify_op_a(const T l, const T r, const float ref, OP op) {
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, T, use::a, SUB_ROWS, SUB_COLS,
layout::row_major>
sub_mat;
joint_matrix<sub_group, T, Use, SUB_ROWS, SUB_COLS, Layout> sub_mat;
joint_matrix_fill(sg, sub_mat, l);
joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); });
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_mat,
accessMat.template get_multi_ptr<access::decorated::no>() +
(sg_startx * SUB_ROWS) * NUM_COLS +
sg_starty / sg_size * SUB_COLS,
NUM_COLS);
(sg_startx * SUB_ROWS / VF) * NUM_COLS * VF +
sg_starty / sg_size * SUB_COLS * VF,
NUM_COLS * VF);
}); // parallel for
}).wait();
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access(read_only), ref);
assert_ops_ref<T, NUM_ROWS / VF, NUM_COLS * VF>(
bufMat.get_host_access(read_only), ref);
}

template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
Expand Down Expand Up @@ -105,37 +106,55 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
}

// Avoid same kernel name for different types
template <typename T, class name> class ewops_a {};
template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_a() {
std::cout << "Test A " << SROWS << "x" << SCOLS << "\n";
template <typename T, size_t SROWS, size_t SCOLS, use Use, class name>
class ewops_ab {};
template <typename T, size_t SROWS, size_t SCOLS, use Use, layout Layout,
size_t VF>
void test_ewops_ab() {
if constexpr (Use == use::a)
std::cout << "Test A ";
else
std::cout << "Test B ";
std::cout << SROWS << "x" << SCOLS << "\n";

static constexpr size_t NROWS = SROWS * 2;
static constexpr size_t NCOLS = SCOLS * 2;

verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_add>>(
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_sub>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_sub>>(
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_mul>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_mul>>(
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_div>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_div>>(
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_logical>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_logical>>(
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_eq>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_eq>>(
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ne>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_ne>>(
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_gt>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_gt>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_lt>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_lt>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ge>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_ge>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_le>>(
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
ewops_ab<T, SROWS, SCOLS, Use, class ab_le>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
}
Expand Down Expand Up @@ -194,30 +213,34 @@ int main() {
.get_info<sycl::ext::oneapi::experimental::info::device::
matrix_combinations>();

for (unsigned int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == matrix_type::bf16) {

if (combinations[i].nsize == 0 ||
(combinations[i].msize == 0 && combinations[i].nsize == 16)) {
test_ewops_a<bfloat16, 8, 16>();
test_ewops_c<float, 8, 16>();
}

if (combinations[i].msize == 16 && combinations[i].nsize == 16) {
for (auto &combination : combinations) {
if (combination.nsize == 0 ||
combination.nsize == 16) { // Intel AMX or architecture::intel_gpu_pvc
test_ewops_ab<bfloat16, 1, 16, use::a, layout::row_major, 1>();
test_ewops_ab<bfloat16, 8, 16, use::a, layout::row_major, 1>();
test_ewops_ab<bfloat16, 16, 16, use::b, layout::ext_intel_packed, 2>();
test_ewops_c<float, 1, 16>();
test_ewops_c<float, 8, 16>();

if (combination.nsize == 16) { // architecture::intel_gpu_pvc
test_ewops_ab<bfloat16, 16, 16, use::a, layout::row_major, 1>();
test_ewops_c<float, 16, 16>();
}

// This combination is not currently supported for sub group size = 32 in IGC
#if (!defined(SG_SZ) || SG_SZ != 32)
if (combinations[i].msize == 32 && combinations[i].nsize == 64) {
test_ewops_ab<bfloat16, 32, 16, use::a, layout::row_major, 1>();
test_ewops_ab<bfloat16, 16, 64, use::b, layout::ext_intel_packed, 2>();
test_ewops_c<float, 1, 64>();
test_ewops_c<float, 32, 64>();
}
#endif

if (combinations[i].nsize == 8) {
test_ewops_a<bfloat16, 8, 16>();
test_ewops_c<float, 8, 8>();
}
break;
}

if (combination.nsize == 8) { // architecture::intel_gpu_dg2*
test_ewops_ab<bfloat16, 8, 16, use::a, layout::row_major, 1>();
test_ewops_ab<bfloat16, 16, 8, use::b, layout::ext_intel_packed, 2>();
test_ewops_c<float, 8, 8>();
break;
}
}

Expand Down

0 comments on commit cfd0d41

Please sign in to comment.