Skip to content

Commit

Permalink
add support for prefetching multiple iterations ahead
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Feb 22, 2024
1 parent c88dc58 commit 083a946
Showing 1 changed file with 64 additions and 44 deletions.
108 changes: 64 additions & 44 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
#define SGS_PER_WG 4
#endif

// Number of k-loop steps (each covering tK * KK elements of K) that the tiled
// kernels prefetch ahead: the initial prefetch loop issues this many steps
// before the main loop starts. Defaults to 1 unless PREFETCH_DISTANCE is
// already defined before this point.
#if !defined(PREFETCH_DISTANCE)
#define PREFETCH_DISTANCE 1
#endif

#if HAS_SIMD8

__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
Expand Down Expand Up @@ -229,16 +233,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
const int n = get_group_id(0) * tN * NN;

// Initial prefetch:
const int init_k = 0;
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K);
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N);
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;
}

float8 sum[MM][NN];
Expand All @@ -252,17 +259,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f

for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
const int next_k = k + tK * KK;
// TODO: skip the prefetch on the final PREFETCH_DISTANCE iterations, where prefetch_k >= K.
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N);
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;

short8 aData[KK][MM];
if (KK % 2 == 0) {
Expand Down Expand Up @@ -320,16 +328,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
const int n = get_group_id(0) * tN * NN;

// Initial prefetch:
const int init_k = 0;
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K);
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
}
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
prefetch_b_vnni_d16_k16v2_n16_sg16(B, init_k + kk * tK, n + nn * tN, N);
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;
}

float8 sum[MM][NN];
Expand All @@ -343,17 +354,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float

for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
const int next_k = k + tK * KK;
// TODO: skip the prefetch on the final PREFETCH_DISTANCE iterations, where prefetch_k >= K.
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
prefetch_b_vnni_d16_k16v2_n16_sg16(B, next_k + kk * tK, n + nn * tN, N);
prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;

short8 aData[KK][MM];
if (KK % 2 == 0) {
Expand Down Expand Up @@ -414,16 +426,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
const int n = get_group_id(0) * tN * NN;

// Initial prefetch:
const int init_k = 0;
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K);
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N);
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;
}

float8 sum[MM][NN];
Expand All @@ -437,17 +452,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN

for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
const int next_k = k + tK * KK;
// TODO: skip the prefetch on the final PREFETCH_DISTANCE iterations, where prefetch_k >= K.
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N);
prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;

short8 aData[KK][MM];
if (KK % 2 == 0) {
Expand Down Expand Up @@ -506,16 +522,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
const int n = get_group_id(0) * tN * NN;

// Initial prefetch:
const int init_k = 0;
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K);
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
}
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
prefetch_b_vnni_d16_k16v2_n16_sg16(B, init_k + kk * tK, n + nn * tN, N);
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;
}

float8 sum[MM][NN];
Expand All @@ -529,17 +548,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl

for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
const int next_k = k + tK * KK;
// TODO: skip the prefetch on the final PREFETCH_DISTANCE iterations, where prefetch_k >= K.
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K);
prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
prefetch_b_vnni_d16_k16v2_n16_sg16(B, next_k + kk * tK, n + nn * tN, N);
prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
prefetch_k += tK * KK;

short8 aData[KK][MM];
if (KK % 2 == 0) {
Expand Down

0 comments on commit 083a946

Please sign in to comment.