From 4ae2d9505d7f3fbf0a31a83b3277848c1536f5f7 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 24 Jan 2024 12:12:13 -0800 Subject: [PATCH] add support for wide K block reads --- .../matrix_kernel_tiled.cl | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index c932235..47a5c95 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -316,9 +316,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + } } } @@ -371,9 +381,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + } } } @@ -406,6 +426,3 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } #endif // cl_intel_subgroup_extended_block_read - -#undef KK -#undef SGS_PER_WG