diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 2b3e7be..ef065c9 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -559,8 +559,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -575,16 +575,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { // TODO: skip prefetch on the last iterations. - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); - - short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; + short8 aData[KK][MM]; + HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -592,6 +589,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } + if (kk == 0) { + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } } split_barrier_wait(); @@ -636,16 +637,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { // TODO: skip prefetch on the last iterations. - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); - - short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; + short8 aData[KK][MM]; + HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -653,6 +651,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } + if (kk == 0) { + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } } split_barrier_wait();