diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index da5c8e4..39d23ad 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -407,7 +407,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float #ifdef cl_intel_subgroup_extended_block_read -void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) +void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) { if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -460,7 +460,7 @@ void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM } } -void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -499,7 +499,7 @@ void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN } } -void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -611,7 +611,6 @@ void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, in __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) - { __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; @@ -639,10 +638,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { int8 bData[NN][KK]; - HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); @@ -699,10 +698,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { int8 bData[NN][KK]; - HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(btile_block_load_vnni, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); // TODO: skip prefetch on the last iterations. HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n);