Skip to content

Commit

Permalink
[SYCL][HIP] Update amd joint_matrix tests to reflect changes to joint_matrix_mad API. (#13250)
Browse files Browse the repository at this point in the history

- The `joint_matrix_mad` API has been modified to accept the output as
an argument to the function. This pull request updates the relevant
tests to accommodate this change for AMD GPUs.
- Minor update to check joint_matrix parameters at compile time.
  • Loading branch information
mmoadeli committed Apr 4, 2024
1 parent 75afc83 commit 0bcabae
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// REQUIRES: hip

// RUN: %clangxx -fsycl -fsycl-targets=amd_gpu_gfx90a %s -o compile-query-hip

#include <iostream>
Expand All @@ -14,21 +13,16 @@ int main() {
using myparams = matrix_params<architecture::amd_gpu_gfx90a, int8_t, int8_t,
int32_t, int32_t, 32, 32, 8>;

size_t dmsize = myparams::M;
size_t dnsize = myparams::N;
size_t dksize = myparams::K;
std::cout
<< "sizes of AMD gpu gfx90a matrix_params chosen by the user are: M "
<< dmsize << " N " << dnsize << " K " << dksize << std::endl;
static_assert(myparams::M == 32);
static_assert(myparams::N == 32);
static_assert(myparams::K == 8);

// Sizes-only compile-time query: types are given, generate default sizes
using myparams2 = matrix_params<architecture::amd_gpu_gfx90a, int8_t, int8_t,
int32_t, int32_t>;
myparams2 p;
dmsize = myparams2::M;
dnsize = myparams2::N;
dksize = myparams2::K;
std::cout << "default AMD gpu gfx90a sizes matrix_params are: M " << dmsize
<< " N " << dnsize << " K " << dksize << std::endl;
static_assert(myparams2::M == 16);
static_assert(myparams2::N == 16);
static_assert(myparams2::K == 4);

return 0;
};
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// REQUIRES: hip
// XFAIL: hip

// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <sycl/sycl.hpp>
Expand All @@ -10,12 +8,10 @@ using namespace sycl::ext::oneapi::experimental::matrix;
using sycl::ext::oneapi::bfloat16;

int main() {

buffer<bfloat16, 1> bufA(nullptr, range<1>(1));
buffer<bfloat16, 1> bufB(nullptr, range<1>(1));
buffer<float, 1> bufC(nullptr, range<1>(1));
buffer<float, 1> bufD(nullptr, range<1>(1));

queue q;

q.submit([&](handler &cgh) {
Expand All @@ -42,9 +38,8 @@ int main() {
sub_a{};
joint_matrix<sub_group, bfloat16, use::b, 16, 16, layout::row_major>
sub_b{};

// CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16bf16.1k(<4 x i16> %{{.*}}, <4 x i16> %{{.*}} <4 x float> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16bf16.1k(<4 x i16> zeroinitializer, <4 x i16> zeroinitializer, <4 x float> zeroinitializer, i32 0, i32 0, i32 0)
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
16, layout::row_major);
Expand All @@ -61,8 +56,8 @@ int main() {
joint_matrix<sub_group, bfloat16, use::b, 8, 32, layout::col_major>
sub_b{};

// CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8bf16.1k(<4 x i16> {{.*}}, <4 x i16> {{.*}}, <16 x float> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8bf16.1k(<4 x i16> zeroinitializer, <4 x i16> zeroinitializer, <16 x float> zeroinitializer, i32 0, i32 0, i32 0)
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
32, layout::row_major);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// REQUIRES: hip
// XFAIL: hip

// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <sycl/sycl.hpp>
Expand All @@ -9,12 +7,10 @@ using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

int main() {

buffer<double, 1> bufA(nullptr, range<1>(1));
buffer<double, 1> bufB(nullptr, range<1>(1));
buffer<double, 1> bufC(nullptr, range<1>(1));
buffer<double, 1> bufD(nullptr, range<1>(1));

queue q;

q.submit([&](handler &cgh) {
Expand Down Expand Up @@ -42,8 +38,8 @@ int main() {
joint_matrix<sub_group, double, use::b, 4, 16, layout::row_major>
sub_b{};

// CHECK: tail call <4 x double> @llvm.amdgcn.mfma.f64.16x16x4f64(double %{{.*}}, double %{{.*}}, <4 x double> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call <4 x double> @llvm.amdgcn.mfma.f64.16x16x4f64(double {{.*}}, double {{.*}}, <4 x double> zeroinitializer, i32 0, i32 0, i32 0)
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
16, layout::row_major);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// REQUIRES: hip
// XFAIL: hip

// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <sycl/sycl.hpp>
Expand All @@ -9,12 +7,10 @@ using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

int main() {

buffer<half, 1> bufA(nullptr, range<1>(1));
buffer<half, 1> bufB(nullptr, range<1>(1));
buffer<float, 1> bufC(nullptr, range<1>(1));
buffer<float, 1> bufD(nullptr, range<1>(1));

queue q;

q.submit([&](handler &cgh) {
Expand Down Expand Up @@ -42,8 +38,8 @@ int main() {
joint_matrix<sub_group, half, use::b, 16, 16, layout::row_major>
sub_b{};

// CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> zeroinitializer, <4 x half> zeroinitializer, <4 x float> zeroinitializer, i32 0, i32 0, i32 0)
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
16, layout::row_major);
Expand All @@ -60,8 +56,8 @@ int main() {
joint_matrix<sub_group, half, use::b, 8, 32, layout::col_major>
sub_b{};

// CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> {{.*}}, <4 x half> {{.*}}, <16 x float> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> zeroinitializer, <4 x half> zeroinitializer, <16 x float> zeroinitializer, i32 0, i32 0, i32 0)
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
32, layout::row_major);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// REQUIRES: hip
// XFAIL: hip

// RUN: %clangxx -fsycl-device-only -fsycl-targets=amd_gpu_gfx90a -S -Xclang -emit-llvm %s -o -| FileCheck %s

#include <sycl/sycl.hpp>
Expand All @@ -9,12 +7,10 @@ using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

int main() {

buffer<int8_t, 1> bufA(nullptr, range<1>(1));
buffer<int8_t, 1> bufB(nullptr, range<1>(1));
buffer<int32_t, 1> bufC(nullptr, range<1>(1));
buffer<int32_t, 1> bufD(nullptr, range<1>(1));

queue q;

q.submit([&](handler &cgh) {
Expand Down Expand Up @@ -42,8 +38,8 @@ int main() {
joint_matrix<sub_group, int8_t, use::b, 16, 16, layout::row_major>
sub_b{};

// CHECK: tail call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 {{.*}}, i32 {{.*}}, <4 x i32> zeroinitializer, i32 0, i32 0, i32 0)
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
16, layout::row_major);
Expand All @@ -61,7 +57,7 @@ int main() {
sub_b{};

// CHECK: tail call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 {{.*}}, i32 {{.*}}, <16 x i32> zeroinitializer, i32 0, i32 0, i32 0)
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
joint_matrix_store(
sg, sub_c, accD.template get_multi_ptr<access::decorated::yes>(),
32, layout::row_major);
Expand Down

0 comments on commit 0bcabae

Please sign in to comment.