Skip to content

Commit

Permalink
Fix DG2 prefetches: use prefetch_k instead of init_k as the K offset in the prefetch calls
Browse files: browse the repository at this point in the history
  • Loading branch information
bashbaug committed Mar 2, 2024
1 parent 459c109 commit b034c1d
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM,
{
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
}
Expand All @@ -55,7 +55,7 @@ void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN,
{
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=4) {
prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N);
prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
}
Expand All @@ -64,7 +64,7 @@ void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int
{
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N);
prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
}
Expand All @@ -83,12 +83,12 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=4) {
prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N);
prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;
Expand Down Expand Up @@ -178,12 +178,12 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N);
prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;
Expand Down

0 comments on commit b034c1d

Please sign in to comment.