From 2ea4d56f138fad051a96191553cb18329862bd6c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 6 Mar 2024 10:12:47 -0800 Subject: [PATCH] switch the prefetch order back for now At least for now, we will follow the order: 1. Load B Matrix Tile 2. Load A matrix Tile 3. Prefetch Next B Matrix Tile 4. Prefetch Next A Matrix Tile 5. Compute 6. Loop back to 1. --- .../matrix_kernel_tiled.cl | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 5d94faf..3421ee4 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -605,8 +605,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -620,25 +620,22 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // TODO: skip prefetch on the last iterations. - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } - if (kk == 0) { - HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); - prefetch_k += tK * KK; - } } split_barrier_wait(); @@ -667,8 +664,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -682,25 +679,23 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // TODO: skip prefetch on the last iterations. - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + // TODO: skip prefetch on the last iterations. + HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } - if (kk == 0) { - HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); - prefetch_k += tK * KK; - } } split_barrier_wait();