diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 0655d10..75ab95b 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -46,7 +46,7 @@ void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } } @@ -55,7 +55,7 @@ void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -64,7 +64,7 @@ void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -83,12 +83,12 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int p = 0; p < PREFETCH_DISTANCE; p++) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } prefetch_k += tK * KK; @@ -178,12 +178,12 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int p = 0; p < PREFETCH_DISTANCE; p++) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } prefetch_k += tK * KK;