diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index fae01b3..a002514 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -881,12 +881,12 @@ int main(int argc, char** argv) } if (mask & 0x400) { - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); } diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 39d23ad..a8d6b4d 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -465,6 +465,9 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { + //if (get_sub_group_local_id() == 0) { + // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); + //} int8 tmp[2][2]; intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); for (int tnn = 0; tnn < 2; tnn++) { @@ -555,11 +558,14 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 2) { - const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 2 * 2; - for (int kk = 0; kk < KK; kk+=2) { - intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ... + const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ... + //if (get_sub_group_local_id() == 0) { + // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { @@ -589,11 +595,11 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 4) { - const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 4; - for (int kk = 0; kk < KK; kk+=2) { - intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); - } + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 + const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) {