
Commit

Changed parensing
konradkusiak97 committed Feb 6, 2024
1 parent 0bf4b6f commit f2e1966
Showing 4 changed files with 93 additions and 94 deletions.
129 changes: 64 additions & 65 deletions sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp
@@ -490,25 +490,25 @@ struct matrix_params<
template <typename Ta, typename Tc>
constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN,
size_t sK) {
return ((std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
(sM == 16 && sN == 16 && sK == 4)));
return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
((sM == 32 && sN == 32 && sK == 8) ||
(sM == 16 && sN == 16 && sK == 16))) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
(sM == 16 && sN == 16 && sK == 4));
}

template <typename Ta, typename Tc>
constexpr bool are_types_valid_amd_gfx90a() {
return ((std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double>));
return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float>) ||
(std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t>) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float>) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double>);
}

// Default-values query:
@@ -521,7 +521,7 @@ struct matrix_params<
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Td>)>> {
static_assert(
(are_types_valid_amd_gfx90a<Ta, Tc>()),
are_types_valid_amd_gfx90a<Ta, Tc>(),
"Invalid types for AMD gfx90a, supported types are half, float, "
"int8_t, int32_t, double and bfloat16 ");

@@ -577,60 +577,60 @@ struct matrix_params<

template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm70() {
return (((std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
std::is_same_v<Td, half>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, half>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
std::is_same_v<Td, float>)));
return (std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
std::is_same_v<Td, half>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, half>) ||
(std::is_same_v<Ta, half> && std::is_same_v<Tc, half> &&
std::is_same_v<Td, float>);
}

template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm72() {
return (((std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
std::is_same_v<Td, int32_t>) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tc, int32_t> &&
std::is_same_v<Td, int32_t>)));
return (std::is_same_v<Ta, int8_t> && std::is_same_v<Tc, int32_t> &&
std::is_same_v<Td, int32_t>) ||
(std::is_same_v<Ta, uint8_t> && std::is_same_v<Tc, int32_t> &&
std::is_same_v<Td, int32_t>);
}

template <typename Ta, typename Tc, typename Td>
constexpr bool are_types_valid_cuda_sm80() {
return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
std::is_same_v<Td, double>));
return (std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>) ||
(std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
std::is_same_v<Td, double>);
}

template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) {
return (are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
((sM == 8 && sN == 32 && sK == 16) ||
(sM == 16 && sN == 16 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16)));
return are_types_valid_cuda_sm70<Ta, Tc, Td>() &&
((sM == 8 && sN == 32 && sK == 16) ||
(sM == 16 && sN == 16 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16));
}

template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) {
return (are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
((sM == 8 && sN == 32 && sK == 16) ||
(sM == 16 && sN == 16 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16)));
return are_types_valid_cuda_sm72<Ta, Tc, Td>() &&
((sM == 8 && sN == 32 && sK == 16) ||
(sM == 16 && sN == 16 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16));
}

template <typename Ta, typename Tc, typename Td>
constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) {
return (((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>)&&(sM == 16 && sN == 16 && sK == 8)) ||
((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>)&&((sM == 16 && sN == 16 && sK == 16) ||
(sM == 8 && sN == 32 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16))) ||
((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
std::is_same_v<Td, double>)&&(sM == 8 && sN == 8 && sK == 4)));
return ((std::is_same_v<Ta, precision::tf32> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>)&&(sM == 16 && sN == 16 && sK == 8)) ||
((std::is_same_v<Ta, bfloat16> && std::is_same_v<Tc, float> &&
std::is_same_v<Td, float>)&&((sM == 16 && sN == 16 && sK == 16) ||
(sM == 8 && sN == 32 && sK == 16) ||
(sM == 32 && sN == 8 && sK == 16))) ||
((std::is_same_v<Ta, double> && std::is_same_v<Tc, double> &&
std::is_same_v<Td, double>)&&(sM == 8 && sN == 8 && sK == 4));
}

// Default-values query (nvidia sm70):
@@ -643,7 +643,7 @@ struct matrix_params<
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb>)>> {
static_assert(
(are_types_valid_cuda_sm70<Ta, Tc, Td>()),
are_types_valid_cuda_sm70<Ta, Tc, Td>(),
"Invalid types for nvidia sm70, supported types are half and float ");

// Default sizes for nvidia sm70 were chosen to represent a square matrix
@@ -671,8 +671,8 @@ struct matrix_params<
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb>)>::type> {
static_assert(
(are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
are_types_valid_cuda_sm72<Ta, Tc, Td>()),
are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
are_types_valid_cuda_sm72<Ta, Tc, Td>(),
"Invalid types for nvidia sm72, supported types are half, float "
"int8_t, uint8_t and int32_t ");

@@ -700,17 +700,16 @@ struct matrix_params<
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb>)>> {
static_assert(
(are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
are_types_valid_cuda_sm80<Ta, Tc, Td>()),
are_types_valid_cuda_sm70<Ta, Tc, Td>() ||
are_types_valid_cuda_sm72<Ta, Tc, Td>() ||
are_types_valid_cuda_sm80<Ta, Tc, Td>(),
"Invalid types for nvidia sm80, supported types are half, float "
"int8_t, uint8_t, int32_t, double, tf32 and bfloat16 ");

static constexpr std::size_t M = ((sizeof(Ta) == 8) ? 8 : 16);
static constexpr std::size_t N = ((sizeof(Ta) == 8) ? 8 : 16);
static constexpr std::size_t M = (sizeof(Ta) == 8) ? 8 : 16;
static constexpr std::size_t N = (sizeof(Ta) == 8) ? 8 : 16;
static constexpr std::size_t K =
((std::is_same_v<Ta, precision::tf32>) ? 8
: ((sizeof(Ta) == 8) ? 4 : 16));
std::is_same_v<Ta, precision::tf32> ? 8 : (sizeof(Ta) == 8 ? 4 : 16);

template <typename Group, layout Layout>
using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;
@@ -763,8 +763,8 @@ struct matrix_params<
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
static_assert(
(is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK)),
is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK),
"Invalid parameters for nvidia sm72, query valid combinations "
"using: "
"q.get_device().get_info<sycl::info::device::matrix::combinations>()");
@@ -794,9 +794,9 @@ struct matrix_params<
!std::is_same_v<Tc, void> && !std::is_same_v<Td, void> &&
std::is_same_v<Ta, Tb> && sM != 0 && sN != 0 && sK != 0)>> {
static_assert(
(is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK)),
is_combination_valid_cuda_sm70<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm72<Ta, Tc, Td>(sM, sN, sK) ||
is_combination_valid_cuda_sm80<Ta, Tc, Td>(sM, sN, sK),
"Invalid parameters for nvidia sm80, query valid combinations "
"using: "
"q.get_device().get_info<sycl::info::device::matrix::combinations>()");
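For context, the default-values query whose helpers are reformatted above is consumed roughly as in the sketch below. This is illustrative only and not part of the commit: the `<architecture, Ta, Tb, Tc, Td>` parameter order of matrix_params is assumed from the sycl_ext_oneapi_matrix extension specification, `architecture::amd_gpu_gfx90a` is taken from the device_info.hpp hunk below, and `<sycl/sycl.hpp>` is assumed to pull in the matrix extension headers.

#include <cstddef>
#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental;
using namespace syclex::matrix;

// Default-values query: ask for a gfx90a-supported tile shape for bfloat16
// inputs accumulating into float (one of the type pairs accepted by
// are_types_valid_amd_gfx90a above). Sizes are left at their defaults (0),
// so the specialization picks them.
using gfx90a_params =
    matrix_params<syclex::architecture::amd_gpu_gfx90a,
                  sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
                  float, float>;

// The chosen shape is exposed as compile-time constants, and the
// joint_matrix_a/b/accumulator aliases seen in the header use them directly.
constexpr std::size_t tile_m = gfx90a_params::M;
constexpr std::size_t tile_n = gfx90a_params::N;
constexpr std::size_t tile_k = gfx90a_params::K;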
2 changes: 1 addition & 1 deletion sycl/source/detail/device_info.hpp
@@ -792,7 +792,7 @@ struct get_device_info_impl<
{8, 0, 0, 0, 8, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
};
else if ((architecture::amd_gpu_gfx90a == DeviceArch))
else if (architecture::amd_gpu_gfx90a == DeviceArch)
return {
{0, 0, 0, 32, 32, 8, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp32},
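The braced rows in the gfx90a table above are aggregate initializers for the runtime-query combination struct. The field order is not spelled out in this hunk; the sketch below infers it from the member-by-member comparison in the tests that follow, and the reading of a zero max_* field as "exact shape rather than upper bound" follows the matrix extension's runtime-query convention — treat both as assumptions.

#include <sycl/sycl.hpp>

using namespace sycl::ext::oneapi::experimental::matrix;

// Assumed field order, matching the checks in find_combination() below:
// {max_msize, max_nsize, max_ksize, msize, nsize, ksize,
//  atype, btype, ctype, dtype}
// Zero max_* with nonzero msize/nsize/ksize reads as an exact supported
// shape: fp16 x fp16 -> fp32 accumulate at 32x32x8 on gfx90a.
combination gfx90a_fp16_entry = {0, 0, 0, 32, 32, 8,
                                 matrix_type::fp16, matrix_type::fp16,
                                 matrix_type::fp32, matrix_type::fp32};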
28 changes: 14 additions & 14 deletions sycl/test-e2e/Matrix/runtime_query_hip_gfx90a.cpp
@@ -8,20 +8,20 @@ using namespace sycl::ext::oneapi::experimental::matrix;

bool find_combination(const combination &comb,
const std::vector<combination> &expected_combinations) {
return (std::find_if(expected_combinations.begin(),
expected_combinations.end(),
[&comb](const auto &expected_comb) {
return (comb.max_msize == expected_comb.max_msize &&
comb.max_nsize == expected_comb.max_nsize &&
comb.max_ksize == expected_comb.max_ksize &&
comb.msize == expected_comb.msize &&
comb.nsize == expected_comb.nsize &&
comb.ksize == expected_comb.ksize &&
comb.atype == expected_comb.atype &&
comb.btype == expected_comb.btype &&
comb.ctype == expected_comb.ctype &&
comb.dtype == expected_comb.dtype);
}) != expected_combinations.end());
return std::find_if(expected_combinations.begin(),
expected_combinations.end(),
[&comb](const auto &expected_comb) {
return (comb.max_msize == expected_comb.max_msize &&
comb.max_nsize == expected_comb.max_nsize &&
comb.max_ksize == expected_comb.max_ksize &&
comb.msize == expected_comb.msize &&
comb.nsize == expected_comb.nsize &&
comb.ksize == expected_comb.ksize &&
comb.atype == expected_comb.atype &&
comb.btype == expected_comb.btype &&
comb.ctype == expected_comb.ctype &&
comb.dtype == expected_comb.dtype);
}) != expected_combinations.end();
}

int main() {
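Building on the helper just shown, here is a hedged usage sketch that is not part of this commit. The `expected` list mirrors the gfx90a fp16 shapes validated by is_combination_valid_amd_gfx90a earlier (the 16x16x16 entry is inferred from that function, not from the visible device_info.hpp hunk), `reported` is a stand-in for one entry the runtime would return, and `check_one_entry` is a hypothetical name; the actual device query is omitted because its info descriptor is only alluded to in the assertion messages above.

#include <vector>
// Assumes the declarations from runtime_query_hip_gfx90a.cpp above
// (combination, matrix_type, find_combination) are in scope.

int check_one_entry() {
  const std::vector<combination> expected = {
      {0, 0, 0, 32, 32, 8, matrix_type::fp16, matrix_type::fp16,
       matrix_type::fp32, matrix_type::fp32},
      {0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
       matrix_type::fp32, matrix_type::fp32},
  };
  // Stand-in for a single combination reported by the runtime query.
  const combination reported = {0, 0, 0, 32, 32, 8,
                                matrix_type::fp16, matrix_type::fp16,
                                matrix_type::fp32, matrix_type::fp32};
  // 0 if the reported entry is one of the expected gfx90a shapes, 1 otherwise.
  return find_combination(reported, expected) ? 0 : 1;
}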
28 changes: 14 additions & 14 deletions sycl/test-e2e/Matrix/runtime_query_tensorcores.cpp
@@ -8,20 +8,20 @@ using namespace sycl::ext::oneapi::experimental::matrix;

bool find_combination(const combination &comb,
const std::vector<combination> &expected_combinations) {
return (std::find_if(expected_combinations.begin(),
expected_combinations.end(),
[&comb](const auto &expected_comb) {
return (comb.max_msize == expected_comb.max_msize &&
comb.max_nsize == expected_comb.max_nsize &&
comb.max_ksize == expected_comb.max_ksize &&
comb.msize == expected_comb.msize &&
comb.nsize == expected_comb.nsize &&
comb.ksize == expected_comb.ksize &&
comb.atype == expected_comb.atype &&
comb.btype == expected_comb.btype &&
comb.ctype == expected_comb.ctype &&
comb.dtype == expected_comb.dtype);
}) != expected_combinations.end());
return std::find_if(expected_combinations.begin(),
expected_combinations.end(),
[&comb](const auto &expected_comb) {
return (comb.max_msize == expected_comb.max_msize &&
comb.max_nsize == expected_comb.max_nsize &&
comb.max_ksize == expected_comb.max_ksize &&
comb.msize == expected_comb.msize &&
comb.nsize == expected_comb.nsize &&
comb.ksize == expected_comb.ksize &&
comb.atype == expected_comb.atype &&
comb.btype == expected_comb.btype &&
comb.ctype == expected_comb.ctype &&
comb.dtype == expected_comb.dtype);
}) != expected_combinations.end();
}

int main() {
