Skip to content

Commit

Permalink
try a different order for prefetches and loads
Browse files Browse the repository at this point in the history
For the block read kernels, the order is now:
1. Prefetch A
2. Load B
3. Load A
4. Prefetch B
  • Loading branch information
bashbaug committed Mar 5, 2024
1 parent a0c2e53 commit 873c6ab
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN

int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;
}

Expand All @@ -575,23 +575,24 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN

for (int k = 0; k < K; k += tK * KK) {
// TODO: skip prefetch on the last iterations.
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);

short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);

int8 bData[NN][KK];
HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData);

HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;
short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]);
}
}
if (kk == 0) {
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;
}
}

split_barrier_wait();
Expand Down Expand Up @@ -636,23 +637,24 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl

for (int k = 0; k < K; k += tK * KK) {
// TODO: skip prefetch on the last iterations.
HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);

short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);

int8 bData[NN][KK];
HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData);

HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;
short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]);
}
}
if (kk == 0) {
HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;
}
}

split_barrier_wait();
Expand Down

0 comments on commit 873c6ab

Please sign in to comment.