From 031e076143111aa632eca6dd0b6082502b9c6535 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 22:18:43 -0800 Subject: [PATCH] add support for larger K values for some tiled kernels --- .../matrix_kernel_tiled.cl | 111 +++++++++++------- 1 file changed, 71 insertions(+), 40 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index cb3c6ca..05c2ace 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -10,6 +10,10 @@ #error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." #endif +#if !defined(KK) +#define KK 1 +#endif + #if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) #if !defined(cl_intel_split_work_group_barrier) #warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" @@ -49,20 +53,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + int8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -97,20 +107,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + int8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -147,20 +163,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -195,20 +217,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -342,3 +370,6 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } #endif // cl_intel_subgroup_extended_block_read + +#undef KK +#undef SGS_PER_WG