add support for larger K values for some tiled kernels
bashbaug committed Jan 19, 2024
1 parent d09b982 commit 031e076
Showing 1 changed file with 71 additions and 40 deletions.
samples/99_matrixexperiments/matrix_kernel_tiled.cl (111 changes: 71 additions & 40 deletions)
@@ -10,6 +10,10 @@
 #error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension."
 #endif
 
+#if !defined(KK)
+#define KK 1
+#endif
+
 #if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS)
 #if !defined(cl_intel_split_work_group_barrier)
 #warning "Unexpected: cl_intel_split_work_group_barrier is not supported?"
@@ -49,20 +53,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
 
     split_barrier_arrive();
 
-    for (int k = 0; k < K; k += tK) {
-        int8 aData[MM];
-        for (int mm = 0; mm < MM; mm++) {
-            aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K);
+    for (int k = 0; k < K; k += tK * KK) {
+        int8 aData[KK][MM];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
+            }
         }
 
-        int8 bData[NN];
-        for (int nn = 0; nn < NN; nn++) {
-            bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N);
+        int8 bData[KK][NN];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+            }
         }
 
-        for (int mm = 0; mm < MM; mm++) {
-            for (int nn = 0; nn < NN; nn++) {
-                sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]);
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                for (int nn = 0; nn < NN; nn++) {
+                    sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
+                }
             }
         }
 
@@ -97,20 +107,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
 
     split_barrier_arrive();
 
-    for (int k = 0; k < K; k += tK) {
-        int8 aData[MM];
-        for (int mm = 0; mm < MM; mm++) {
-            aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K);
+    for (int k = 0; k < K; k += tK * KK) {
+        int8 aData[KK][MM];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
+            }
         }
 
-        int8 bData[NN];
-        for (int nn = 0; nn < NN; nn++) {
-            bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N);
+        int8 bData[KK][NN];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+            }
         }
 
-        for (int nn = 0; nn < NN; nn++) {
-            for (int mm = 0; mm < MM; mm++) {
-                sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]);
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
+                }
             }
         }
 
@@ -147,20 +163,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
 
     split_barrier_arrive();
 
-    for (int k = 0; k < K; k += tK) {
-        short8 aData[MM];
-        for (int mm = 0; mm < MM; mm++) {
-            aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K);
+    for (int k = 0; k < K; k += tK * KK) {
+        short8 aData[KK][MM];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
+            }
         }
 
-        int8 bData[NN];
-        for (int nn = 0; nn < NN; nn++) {
-            bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N);
+        int8 bData[KK][NN];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+            }
         }
 
-        for (int nn = 0; nn < NN; nn++) {
-            for (int mm = 0; mm < MM; mm++) {
-                sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
+                }
             }
         }
 
@@ -195,20 +217,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
 
     split_barrier_arrive();
 
-    for (int k = 0; k < K; k += tK) {
-        short8 aData[MM];
-        for (int mm = 0; mm < MM; mm++) {
-            aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K);
+    for (int k = 0; k < K; k += tK * KK) {
+        short8 aData[KK][MM];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
+            }
         }
 
-        int8 bData[NN];
-        for (int nn = 0; nn < NN; nn++) {
-            bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N);
+        int8 bData[KK][NN];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+            }
         }
 
-        for (int nn = 0; nn < NN; nn++) {
-            for (int mm = 0; mm < MM; mm++) {
-                sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
+                }
             }
         }
 
@@ -342,3 +370,6 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
 }
 
 #endif // cl_intel_subgroup_extended_block_read
+
+#undef KK
+#undef SGS_PER_WG

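A note on usage: KK, like MM and NN, is a preprocessor macro, so the K step per loop iteration is fixed when the kernel source is compiled, and the rewritten loops as written assume K is a multiple of tK * KK. The snippet below is a minimal host-side sketch of one way the macro could be passed as an OpenCL build option; the function, variable names, and macro values are illustrative assumptions and are not part of this commit.

    #include <CL/cl.h>

    // Minimal sketch: build matrix_kernel_tiled.cl with two K tiles per loop
    // iteration. 'program' and 'device' are assumed to already exist, and the
    // MM/NN/KK values here are examples only; the sample's real option handling
    // is outside this diff.
    static cl_int build_tiled_kernels(cl_program program, cl_device_id device)
    {
        const char* options = "-D MM=2 -D NN=4 -D KK=2";
        return clBuildProgram(program, 1, &device, options, NULL, NULL);
    }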