From 1e3af564106f211b01b931969ba9a35e31af9eec Mon Sep 17 00:00:00 2001 From: Dounia Khaldi Date: Wed, 24 Jan 2024 13:40:48 -0800 Subject: [PATCH] [SYCL][joint matrix] add missing licence to test and add combination-based query --- .../Matrix/SG32/element_wise_all_ops.cpp | 10 ++----- sycl/test-e2e/Matrix/common.hpp | 20 ++++++++++++++ sycl/test-e2e/Matrix/element_wise_all_ops.cpp | 10 ++----- .../Matrix/element_wise_all_ops_impl.hpp | 26 ++++--------------- 4 files changed, 29 insertions(+), 37 deletions(-) diff --git a/sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp b/sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp index 91b36ee032e27..7b90389af548b 100644 --- a/sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp +++ b/sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp @@ -5,19 +5,13 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// REQUIRES: matrix +// REQUIRES: cpu, gpu // REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943 // RUN: %{build} -o %t.out // RUN: %{run} %t.out -#include -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; -using bfloat16 = sycl::ext::oneapi::bfloat16; +#include "../common.hpp" constexpr size_t SG_SZ = 32; constexpr size_t TN = 16; diff --git a/sycl/test-e2e/Matrix/common.hpp b/sycl/test-e2e/Matrix/common.hpp index 675261a17f3cb..93a52a4c9bfeb 100644 --- a/sycl/test-e2e/Matrix/common.hpp +++ b/sycl/test-e2e/Matrix/common.hpp @@ -1,3 +1,10 @@ +//==------------------ common.hpp - DPC++ joint_matrix---------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// #include #include #include @@ -173,3 +180,16 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) { } return true; } + +bool is_type_supported_by_device(queue q, matrix_type type) { + std::vector combinations = + q.get_device() + .get_info(); + for (int i = 0; i < combinations.size(); i++) { + if (combinations[i].atype == type) { + return true; + } + } + return false; +} diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops.cpp b/sycl/test-e2e/Matrix/element_wise_all_ops.cpp index fd3648664a52c..c4a9967a658db 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops.cpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops.cpp @@ -5,18 +5,12 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// REQUIRES: matrix +// REQUIRES: cpu, gpu // RUN: %{build} -o %t.out // RUN: %{run} %t.out -#include -#include -#include - -using namespace sycl; -using namespace sycl::ext::oneapi::experimental::matrix; -using bfloat16 = sycl::ext::oneapi::bfloat16; +#include "common.hpp" #define SG_SZ 16 constexpr size_t TN = 16; diff --git a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp index b11d3093bf08d..b0e4b51cbbd72 100644 --- a/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp +++ b/sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp @@ -5,24 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// - -static float make_fp32(bfloat16 x) { - unsigned int y = *((int *)&x); - y = y << 16; - float *res = reinterpret_cast(&y); - return *res; -} - -template struct big_matrix { -public: - T *mat; - -public: - T *get_data() { return mat; } - void set_data(T *data) { mat = data; } - big_matrix(T *data) : mat(data) {} -}; - template void assert_ops_ref(host_accessor mat, const float ref) { @@ -181,9 +163,11 @@ int main() { static constexpr size_t MATRIX_M = TM * 2; static constexpr size_t MATRIX_N = TN * 2; static constexpr size_t MATRIX_K = TK * 2; - - test_ewops_a(); - test_ewops_c(); + queue q; + if (is_type_supported_by_device(q, matrix_type::bf16)) { + test_ewops_a(); + test_ewops_c(); + } return 0; }