From 1d1dd93b2e73c3e67c43d3af069bc432c0fc4fb0 Mon Sep 17 00:00:00 2001 From: Dounia Date: Mon, 12 Feb 2024 13:24:02 -0800 Subject: [PATCH] remove breaks and make the combination more specific --- .../joint_matrix_rowmajorA_rowmajorB_impl.hpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp index d43fe9f9be793..036f75abb9c97 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp @@ -110,19 +110,22 @@ int main() { matrix_combinations>(); for (unsigned int i = 0; i < combinations.size(); i++) { if (combinations[i].atype == matrix_type::bf16) { - if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { + if (combinations[i].nsize == 0 || + (combinations[i].nsize == 16 && combinations[i].max_msize == 8 && + combinations[i].ksize == 16)) { gemm_row_major<16, 16, class gemm_bfloat16_16, bfloat16, bfloat16, float>(); - break; } - if (combinations[i].nsize == 8) { + if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 && + combinations[i].ksize == 16) { gemm_row_major<8, 16, class gemm_bfloat16_8, bfloat16, bfloat16, float>(); - break; } } if (combinations[i].atype == matrix_type::sint8) { - if (combinations[i].nsize == 0 || combinations[i].nsize == 16) { + if (combinations[i].nsize == 0 || + (combinations[i].nsize == 16 && combinations[i].max_msize == 8 && + combinations[i].ksize == 32)) { gemm_row_major<16, 32, class gemm_int8_16, int8_t, int8_t, int32_t>(); gemm_row_major<16, 32, class gemm_us_int8_16, uint8_t, int8_t, int32_t>(); @@ -130,15 +133,14 @@ int main() { int32_t>(); gemm_row_major<16, 32, class gemm_uu_int8_16, uint8_t, uint8_t, int32_t>(); - break; } - if (combinations[i].nsize == 8) { + if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 && + combinations[i].ksize == 32) { gemm_row_major<8, 32, class gemm_int8_8, int8_t, int8_t, int32_t>(); gemm_row_major<8, 32, class gemm_us_int8_8, uint8_t, int8_t, int32_t>(); gemm_row_major<8, 32, class gemm_su_int8_8, int8_t, uint8_t, int32_t>(); gemm_row_major<8, 32, class gemm_uu_int8_8, uint8_t, uint8_t, int32_t>(); - break; } } }