From ab84bbe442b7b721263dcd714bfa3ac56c9885c8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 23 Jan 2024 15:14:29 -0800 Subject: [PATCH] performance improvements and bugfixes for DG2 - Fixed an error in the stride computation. - Added changes to improve stateless-to-stateful compilation. - Added a wider A matrix block read to load two K tiles at a time. --- samples/99_matrixexperiments/main.cpp | 6 ++ .../99_matrixexperiments/matrix_helpers.cl | 50 ++++++------- .../matrix_kernel_tiled.cl | 72 ++++++++++++++----- .../99_matrixexperiments/matrix_kernels.cl | 6 ++ 4 files changed, 94 insertions(+), 40 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index a3bacda..a38cf27 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -741,6 +741,7 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -753,6 +754,7 @@ int main(int argc, char** argv) bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -765,6 +767,7 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -777,6 +780,7 @@ int main(int argc, char** argv) bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -789,6 +793,7 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -801,6 +806,7 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 0719fdb..b77ede3 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -9,6 +9,8 @@ float bf16_to_fp32(ushort u) #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) +typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); + inline int compute_m(const int num_sgs, const int tM, const int MM) { const int m_start = get_group_id(1) * num_sgs; @@ -157,7 +159,7 @@ int load_a_rowmajor_d16_m1_k16_sg8(global ushort* A, int rowStart, int colStart int ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret = intel_sub_group_block_read(A_ui + offset_ui); return ret; @@ -170,7 +172,7 @@ int2 load_a_rowmajor_d16_m2_k16_sg8(global ushort* A, int rowStart, int colStart int2 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; @@ -185,7 +187,7 @@ int4 load_a_rowmajor_d16_m4_k16_sg8(global ushort* A, int rowStart, int colStart int4 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; @@ -202,7 +204,7 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart int8 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; @@ -224,7 +226,7 @@ int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colSt uint16 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; @@ -244,7 +246,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta { ushort ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret = intel_sub_group_block_read_us(A + offset); return as_short(ret); @@ -256,7 +258,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt { ushort2 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; @@ -269,7 +271,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt { ushort4 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; @@ -284,7 +286,7 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt { ushort8 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; @@ -304,15 +306,15 @@ short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int co { ushort16 ret; - int offset = rowStart * stride + colStart; - ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + uint offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride; return as_short16(ret); } @@ -324,7 +326,7 @@ int8 load_b_rowmajor_d16_k16_nx(global ushort* B, int rowStart, int colStart, in { int8 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride; ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride; @@ -363,7 +365,7 @@ int8 load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int st int8 ret; global uint* B_ui = (global uint*)B; - int offset_ui = rowStart / 2 * stride + colStart; + uint offset_ui = rowStart / 2 * stride + colStart; ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; @@ -382,7 +384,7 @@ void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int col global uint* C_ui = (global uint*)C; uint v_ui = as_uint(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; } @@ -392,7 +394,7 @@ void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int co global uint* C_ui = (global uint*)C; uint2 v_ui = as_uint2(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; @@ -403,7 +405,7 @@ void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int co global uint* C_ui = (global uint*)C; uint4 v_ui = as_uint4(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; @@ -416,7 +418,7 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co global uint* C_ui = (global uint*)C; uint8 v_ui = as_uint8(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 7874d7b..c932235 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -36,7 +36,7 @@ #if HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; @@ -55,9 +55,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int k = 0; k < K; k += tK * KK) { int8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } } @@ -90,7 +100,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; @@ -109,9 +119,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int k = 0; k < K; k += tK * KK) { int8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } } @@ -146,7 +166,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* #endif // HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; @@ -165,9 +185,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } } @@ -200,7 +230,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; @@ -219,9 +249,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } } diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index efc8f4f..f869a3b 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -513,6 +513,12 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* // Tiled matrix multiplication kernels, generated from a template: +#define MM 1 +#define NN 1 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + #define MM 2 #define NN 1 #include "matrix_kernel_tiled.cl"