Skip to content

Commit

Permalink
switch the prefetch order back for now
Browse files Browse the repository at this point in the history
At least for now, we will follow the order:
1. Load B Matrix Tile
2. Load A Matrix Tile
3. Prefetch Next B Matrix Tile
4. Prefetch Next A Matrix Tile
5. Compute
6. Loop back to 1.
  • Loading branch information
bashbaug committed Mar 6, 2024
1 parent ef205e3 commit 2ea4d56
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN

int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;
}

Expand All @@ -620,25 +620,22 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
split_barrier_arrive();

for (int k = 0; k < K; k += tK * KK) {
// TODO: skip prefetch on the last iterations.
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);

int8 bData[NN][KK];
HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData);

short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);

HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]);
}
}
if (kk == 0) {
HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n);
prefetch_k += tK * KK;
}
}

split_barrier_wait();
Expand Down Expand Up @@ -667,8 +664,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl

int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;
}

Expand All @@ -682,25 +679,23 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
split_barrier_arrive();

for (int k = 0; k < K; k += tK * KK) {
// TODO: skip prefetch on the last iterations.
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);

int8 bData[NN][KK];
HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData);

short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);

// TODO: skip prefetch on the last iterations.
HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]);
}
}
if (kk == 0) {
HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n);
prefetch_k += tK * KK;
}
}

split_barrier_wait();
Expand Down

0 comments on commit 2ea4d56

Please sign in to comment.