diff --git a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp index af2f37d152ca0..faad677343611 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/static-query-use.hpp @@ -490,25 +490,25 @@ struct matrix_params< template constexpr bool is_combination_valid_amd_gfx90a(size_t sM, size_t sN, size_t sK) { - return ((std::is_same_v && std::is_same_v && - ((sM == 32 && sN == 32 && sK == 8) || - (sM == 16 && sN == 16 && sK == 16))) || - (std::is_same_v && std::is_same_v && - ((sM == 32 && sN == 32 && sK == 8) || - (sM == 16 && sN == 16 && sK == 16))) || - (std::is_same_v && std::is_same_v && - ((sM == 32 && sN == 32 && sK == 8) || - (sM == 16 && sN == 16 && sK == 16))) || - (std::is_same_v && std::is_same_v && - (sM == 16 && sN == 16 && sK == 4))); + return (std::is_same_v && std::is_same_v && + ((sM == 32 && sN == 32 && sK == 8) || + (sM == 16 && sN == 16 && sK == 16))) || + (std::is_same_v && std::is_same_v && + ((sM == 32 && sN == 32 && sK == 8) || + (sM == 16 && sN == 16 && sK == 16))) || + (std::is_same_v && std::is_same_v && + ((sM == 32 && sN == 32 && sK == 8) || + (sM == 16 && sN == 16 && sK == 16))) || + (std::is_same_v && std::is_same_v && + (sM == 16 && sN == 16 && sK == 4)); } template constexpr bool are_types_valid_amd_gfx90a() { - return ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)); + return (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v); } // Default-values query: @@ -521,7 +521,7 @@ struct matrix_params< !std::is_same_v && !std::is_same_v && std::is_same_v && std::is_same_v)>> { static_assert( - (are_types_valid_amd_gfx90a()), + are_types_valid_amd_gfx90a(), "Invalid types for AMD gfx90a, supported types are half, float, " "int8_t, int32_t, double and bfloat16 "); @@ -577,60 +577,60 @@ struct matrix_params< template constexpr bool are_types_valid_cuda_sm70() { - return (((std::is_same_v && std::is_same_v && - std::is_same_v) || - (std::is_same_v && std::is_same_v && - std::is_same_v) || - (std::is_same_v && std::is_same_v && - std::is_same_v) || - (std::is_same_v && std::is_same_v && - std::is_same_v))); + return (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v); } template constexpr bool are_types_valid_cuda_sm72() { - return (((std::is_same_v && std::is_same_v && - std::is_same_v) || - (std::is_same_v && std::is_same_v && - std::is_same_v))); + return (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v); } template constexpr bool are_types_valid_cuda_sm80() { - return ((std::is_same_v && std::is_same_v && - std::is_same_v) || - (std::is_same_v && std::is_same_v && - std::is_same_v) || - (std::is_same_v && std::is_same_v && - std::is_same_v)); + return (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v); } template constexpr bool is_combination_valid_cuda_sm70(size_t sM, size_t sN, size_t sK) { - return (are_types_valid_cuda_sm70() && - ((sM == 8 && sN == 32 && sK == 16) || - (sM == 16 && sN == 16 && sK == 16) || - (sM == 32 && sN == 8 && sK == 16))); + return are_types_valid_cuda_sm70() && + ((sM == 8 && sN == 32 && sK == 16) || + (sM == 16 && sN == 16 && sK == 16) || + (sM == 32 && sN == 8 && sK == 16)); } template constexpr bool is_combination_valid_cuda_sm72(size_t sM, size_t sN, size_t sK) { - return (are_types_valid_cuda_sm72() && - ((sM == 8 && sN == 32 && sK == 16) || - (sM == 16 && sN == 16 && sK == 16) || - (sM == 32 && sN == 8 && sK == 16))); + return are_types_valid_cuda_sm72() && + ((sM == 8 && sN == 32 && sK == 16) || + (sM == 16 && sN == 16 && sK == 16) || + (sM == 32 && sN == 8 && sK == 16)); } template constexpr bool is_combination_valid_cuda_sm80(size_t sM, size_t sN, size_t sK) { - return (((std::is_same_v && std::is_same_v && - std::is_same_v)&&(sM == 16 && sN == 16 && sK == 8)) || - ((std::is_same_v && std::is_same_v && - std::is_same_v)&&((sM == 16 && sN == 16 && sK == 16) || - (sM == 8 && sN == 32 && sK == 16) || - (sM == 32 && sN == 8 && sK == 16))) || - ((std::is_same_v && std::is_same_v && - std::is_same_v)&&(sM == 8 && sN == 8 && sK == 4))); + return ((std::is_same_v && std::is_same_v && + std::is_same_v)&&(sM == 16 && sN == 16 && sK == 8)) || + ((std::is_same_v && std::is_same_v && + std::is_same_v)&&((sM == 16 && sN == 16 && sK == 16) || + (sM == 8 && sN == 32 && sK == 16) || + (sM == 32 && sN == 8 && sK == 16))) || + ((std::is_same_v && std::is_same_v && + std::is_same_v)&&(sM == 8 && sN == 8 && sK == 4)); } // Default-values query (nvidia sm70): @@ -643,7 +643,7 @@ struct matrix_params< !std::is_same_v && !std::is_same_v && std::is_same_v)>> { static_assert( - (are_types_valid_cuda_sm70()), + are_types_valid_cuda_sm70(), "Invalid types for nvidia sm70, supported types are half and float "); // Default sizes for nvidia sm70 were chosen to represent a square matrix @@ -671,8 +671,8 @@ struct matrix_params< !std::is_same_v && !std::is_same_v && std::is_same_v)>::type> { static_assert( - (are_types_valid_cuda_sm70() || - are_types_valid_cuda_sm72()), + are_types_valid_cuda_sm70() || + are_types_valid_cuda_sm72(), "Invalid types for nvidia sm72, supported types are half, float " "int8_t, uint8_t and int32_t "); @@ -700,17 +700,16 @@ struct matrix_params< !std::is_same_v && !std::is_same_v && std::is_same_v)>> { static_assert( - (are_types_valid_cuda_sm70() || - are_types_valid_cuda_sm72() || - are_types_valid_cuda_sm80()), + are_types_valid_cuda_sm70() || + are_types_valid_cuda_sm72() || + are_types_valid_cuda_sm80(), "Invalid types for nvidia sm80, supported types are half, float " "int8_t, uint8_t, int32_t, double, tf32 and bfloat16 "); - static constexpr std::size_t M = ((sizeof(Ta) == 8) ? 8 : 16); - static constexpr std::size_t N = ((sizeof(Ta) == 8) ? 8 : 16); + static constexpr std::size_t M = (sizeof(Ta) == 8) ? 8 : 16; + static constexpr std::size_t N = (sizeof(Ta) == 8) ? 8 : 16; static constexpr std::size_t K = - ((std::is_same_v) ? 8 - : ((sizeof(Ta) == 8) ? 4 : 16)); + std::is_same_v ? 8 : (sizeof(Ta) == 8 ? 4 : 16); template using joint_matrix_a = joint_matrix; @@ -763,8 +762,8 @@ struct matrix_params< !std::is_same_v && !std::is_same_v && std::is_same_v && sM != 0 && sN != 0 && sK != 0)>> { static_assert( - (is_combination_valid_cuda_sm70(sM, sN, sK) || - is_combination_valid_cuda_sm72(sM, sN, sK)), + is_combination_valid_cuda_sm70(sM, sN, sK) || + is_combination_valid_cuda_sm72(sM, sN, sK), "Invalid parameters for nvidia sm72, query valid combinations " "using: " "q.get_device().get_info()"); @@ -794,9 +793,9 @@ struct matrix_params< !std::is_same_v && !std::is_same_v && std::is_same_v && sM != 0 && sN != 0 && sK != 0)>> { static_assert( - (is_combination_valid_cuda_sm70(sM, sN, sK) || - is_combination_valid_cuda_sm72(sM, sN, sK) || - is_combination_valid_cuda_sm80(sM, sN, sK)), + is_combination_valid_cuda_sm70(sM, sN, sK) || + is_combination_valid_cuda_sm72(sM, sN, sK) || + is_combination_valid_cuda_sm80(sM, sN, sK), "Invalid parameters for nvidia sm80, query valid combinations " "using: " "q.get_device().get_info()"); diff --git a/sycl/source/detail/device_info.hpp b/sycl/source/detail/device_info.hpp index 8f2329acb1671..e9dbd52212ce8 100644 --- a/sycl/source/detail/device_info.hpp +++ b/sycl/source/detail/device_info.hpp @@ -792,7 +792,7 @@ struct get_device_info_impl< {8, 0, 0, 0, 8, 16, matrix_type::bf16, matrix_type::bf16, matrix_type::fp32, matrix_type::fp32}, }; - else if ((architecture::amd_gpu_gfx90a == DeviceArch)) + else if (architecture::amd_gpu_gfx90a == DeviceArch) return { {0, 0, 0, 32, 32, 8, matrix_type::fp16, matrix_type::fp16, matrix_type::fp32, matrix_type::fp32}, diff --git a/sycl/test-e2e/Matrix/runtime_query_hip_gfx90a.cpp b/sycl/test-e2e/Matrix/runtime_query_hip_gfx90a.cpp index 1d554ee7bc9eb..2eef5ee1ef933 100644 --- a/sycl/test-e2e/Matrix/runtime_query_hip_gfx90a.cpp +++ b/sycl/test-e2e/Matrix/runtime_query_hip_gfx90a.cpp @@ -8,20 +8,20 @@ using namespace sycl::ext::oneapi::experimental::matrix; bool find_combination(const combination &comb, const std::vector &expected_combinations) { - return (std::find_if(expected_combinations.begin(), - expected_combinations.end(), - [&comb](const auto &expected_comb) { - return (comb.max_msize == expected_comb.max_msize && - comb.max_nsize == expected_comb.max_nsize && - comb.max_ksize == expected_comb.max_ksize && - comb.msize == expected_comb.msize && - comb.nsize == expected_comb.nsize && - comb.ksize == expected_comb.ksize && - comb.atype == expected_comb.atype && - comb.btype == expected_comb.btype && - comb.ctype == expected_comb.ctype && - comb.dtype == expected_comb.dtype); - }) != expected_combinations.end()); + return std::find_if(expected_combinations.begin(), + expected_combinations.end(), + [&comb](const auto &expected_comb) { + return (comb.max_msize == expected_comb.max_msize && + comb.max_nsize == expected_comb.max_nsize && + comb.max_ksize == expected_comb.max_ksize && + comb.msize == expected_comb.msize && + comb.nsize == expected_comb.nsize && + comb.ksize == expected_comb.ksize && + comb.atype == expected_comb.atype && + comb.btype == expected_comb.btype && + comb.ctype == expected_comb.ctype && + comb.dtype == expected_comb.dtype); + }) != expected_combinations.end(); } int main() { diff --git a/sycl/test-e2e/Matrix/runtime_query_tensorcores.cpp b/sycl/test-e2e/Matrix/runtime_query_tensorcores.cpp index ad247dfc6cf0d..e69c512ac400e 100644 --- a/sycl/test-e2e/Matrix/runtime_query_tensorcores.cpp +++ b/sycl/test-e2e/Matrix/runtime_query_tensorcores.cpp @@ -8,20 +8,20 @@ using namespace sycl::ext::oneapi::experimental::matrix; bool find_combination(const combination &comb, const std::vector &expected_combinations) { - return (std::find_if(expected_combinations.begin(), - expected_combinations.end(), - [&comb](const auto &expected_comb) { - return (comb.max_msize == expected_comb.max_msize && - comb.max_nsize == expected_comb.max_nsize && - comb.max_ksize == expected_comb.max_ksize && - comb.msize == expected_comb.msize && - comb.nsize == expected_comb.nsize && - comb.ksize == expected_comb.ksize && - comb.atype == expected_comb.atype && - comb.btype == expected_comb.btype && - comb.ctype == expected_comb.ctype && - comb.dtype == expected_comb.dtype); - }) != expected_combinations.end()); + return std::find_if(expected_combinations.begin(), + expected_combinations.end(), + [&comb](const auto &expected_comb) { + return (comb.max_msize == expected_comb.max_msize && + comb.max_nsize == expected_comb.max_nsize && + comb.max_ksize == expected_comb.max_ksize && + comb.msize == expected_comb.msize && + comb.nsize == expected_comb.nsize && + comb.ksize == expected_comb.ksize && + comb.atype == expected_comb.atype && + comb.btype == expected_comb.btype && + comb.ctype == expected_comb.ctype && + comb.dtype == expected_comb.dtype); + }) != expected_combinations.end(); } int main() {