Skip to content

Commit

Permalink
[SYCL][joint matrix] add missing licence to test and add combination-…
Browse files Browse the repository at this point in the history
…based query
  • Loading branch information
dkhaldi committed Jan 24, 2024
1 parent d2463c6 commit 1e3af56
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 37 deletions.
10 changes: 2 additions & 8 deletions sycl/test-e2e/Matrix/SG32/element_wise_all_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,13 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// REQUIRES: cpu, gpu
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <random>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;
#include "../common.hpp"

constexpr size_t SG_SZ = 32;
constexpr size_t TN = 16;
Expand Down
20 changes: 20 additions & 0 deletions sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
//==------------------ common.hpp - DPC++ joint_matrix---------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <cmath>
#include <iostream>
#include <random>
Expand Down Expand Up @@ -173,3 +180,16 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
}
return true;
}

bool is_type_supported_by_device(queue q, matrix_type type) {
std::vector<combination> combinations =
q.get_device()
.get_info<sycl::ext::oneapi::experimental::info::device::
matrix_combinations>();
for (int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == type) {
return true;
}
}
return false;
}
10 changes: 2 additions & 8 deletions sycl/test-e2e/Matrix/element_wise_all_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// REQUIRES: cpu, gpu

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <random>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;
#include "common.hpp"

#define SG_SZ 16
constexpr size_t TN = 16;
Expand Down
26 changes: 5 additions & 21 deletions sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

static float make_fp32(bfloat16 x) {
unsigned int y = *((int *)&x);
y = y << 16;
float *res = reinterpret_cast<float *>(&y);
return *res;
}

template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
T *mat;

public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};

template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
const float ref) {
Expand Down Expand Up @@ -181,9 +163,11 @@ int main() {
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;

test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, TK>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, TN>();
queue q;
if (is_type_supported_by_device(q, matrix_type::bf16)) {
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, TK>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, TN>();
}

return 0;
}

0 comments on commit 1e3af56

Please sign in to comment.