Skip to content

Commit

Permalink
remove breaks and make the combination more specific
Browse files Browse the repository at this point in the history
  • Loading branch information
dkhaldi committed Feb 12, 2024
1 parent da872db commit 1d1dd93
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions sycl/test-e2e/Matrix/joint_matrix_rowmajorA_rowmajorB_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,35 +110,37 @@ int main() {
matrix_combinations>();
for (unsigned int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == matrix_type::bf16) {
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
if (combinations[i].nsize == 0 ||
(combinations[i].nsize == 16 && combinations[i].max_msize == 8 &&
combinations[i].ksize == 16)) {
gemm_row_major<16, 16, class gemm_bfloat16_16, bfloat16, bfloat16,
float>();
break;
}
if (combinations[i].nsize == 8) {
if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 &&
combinations[i].ksize == 16) {
gemm_row_major<8, 16, class gemm_bfloat16_8, bfloat16, bfloat16,
float>();
break;
}
}
if (combinations[i].atype == matrix_type::sint8) {
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
if (combinations[i].nsize == 0 ||
(combinations[i].nsize == 16 && combinations[i].max_msize == 8 &&
combinations[i].ksize == 32)) {
gemm_row_major<16, 32, class gemm_int8_16, int8_t, int8_t, int32_t>();
gemm_row_major<16, 32, class gemm_us_int8_16, uint8_t, int8_t,
int32_t>();
gemm_row_major<16, 32, class gemm_su_int8_16, int8_t, uint8_t,
int32_t>();
gemm_row_major<16, 32, class gemm_uu_int8_16, uint8_t, uint8_t,
int32_t>();
break;
}
if (combinations[i].nsize == 8) {
if (combinations[i].nsize == 8 && combinations[i].max_msize == 8 &&
combinations[i].ksize == 32) {
gemm_row_major<8, 32, class gemm_int8_8, int8_t, int8_t, int32_t>();
gemm_row_major<8, 32, class gemm_us_int8_8, uint8_t, int8_t, int32_t>();
gemm_row_major<8, 32, class gemm_su_int8_8, int8_t, uint8_t, int32_t>();
gemm_row_major<8, 32, class gemm_uu_int8_8, uint8_t, uint8_t,
int32_t>();
break;
}
}
}
Expand Down

0 comments on commit 1d1dd93

Please sign in to comment.