From 697754d45f8109124e68f9e7d051ecf129dc78c3 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 17:31:09 -0800 Subject: [PATCH] use more helper functions for DG2 tiled kernels --- .../matrix_kernel_tiled.cl | 147 ++++++------------ 1 file changed, 49 insertions(+), 98 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 75ab95b..4e7db7b 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -40,6 +40,24 @@ #define PREFETCH_DISTANCE 1 #endif +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + #if HAS_SIMD8 void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) @@ -69,6 +87,25 @@ void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int } } +void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { @@ -81,16 +118,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -106,41 +135,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; int8 aData[KK][MM]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); - } - } - } + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { @@ -176,16 +179,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -201,41 +196,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; int8 aData[KK][MM]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); - } - } - } + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -307,24 +276,6 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i } } -void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } -} - __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) {