From 52b9550042b1754421ac84e50bb70155dae4e937 Mon Sep 17 00:00:00 2001
From: Ben Ashbaugh
Date: Sat, 6 Jan 2024 11:14:38 -0800
Subject: [PATCH] add SIMD16 versions and emulation

---
 include/CL/opencl.hpp                      |   5 +
 samples/99_matrixexperiments/main.cpp      |  88 +++-
 .../99_matrixexperiments/matrix_kernels.cl | 395 +++++++++++++++++-
 3 files changed, 443 insertions(+), 45 deletions(-)

diff --git a/include/CL/opencl.hpp b/include/CL/opencl.hpp
index 1c43ae0..c14c81c 100644
--- a/include/CL/opencl.hpp
+++ b/include/CL/opencl.hpp
@@ -1654,6 +1654,11 @@ CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_NUM_THREADS_PER_EU_INTEL,
 CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_FEATURE_CAPABILITIES_INTEL, cl_device_feature_capabilities_intel)
 #endif // cl_intel_device_attribute_query
 
+#if defined(cl_intel_required_subgroup_size)
+CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_SUB_GROUP_SIZES_INTEL, cl::vector<size_t>)
+CL_HPP_DECLARE_PARAM_TRAITS_(cl_kernel_work_group_info, CL_KERNEL_SPILL_MEM_SIZE_INTEL, cl_ulong)
+#endif // cl_intel_required_subgroup_size
+
 // Convenience functions
 
 template
diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp
index 662457d..0d78b42 100644
--- a/samples/99_matrixexperiments/main.cpp
+++ b/samples/99_matrixexperiments/main.cpp
@@ -10,6 +10,7 @@
 #include
 #include
+#include <algorithm>
 #include
 #include
 #include
@@ -21,6 +22,7 @@ using test_clock = std::chrono::high_resolution_clock;
 
 bool fixedData = false;
 bool validate = false;
+bool emulate = false;
 int testIterations = 16;
 float threshold = 0.01f;
 
@@ -46,6 +48,16 @@ std::string makeTestName(
     return ret.str();
 }
 
+static size_t findMinSubGroupSize(cl::Device& device)
+{
+    auto s = device.getInfo<CL_DEVICE_SUB_GROUP_SIZES_INTEL>();
+    auto it = std::min_element(std::begin(s), std::end(s));
+    if (it != std::end(s)) {
+        return *it;
+    }
+    return 0;
+}
+
 template <typename T>
 static void fill_matrix(std::vector<T>& M, size_t numRows, size_t numCols)
 {
@@ -163,7 +175,9 @@ static void go_dpas_rowmajor(
 {
     printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str());
     fflush(stdout);
-    std::string kernelName = "bfloat16_dpas_rowmajor_m" + std::to_string(tM);
+    std::string kernelName = "bfloat16_dpas_rowmajor";
+    kernelName += "_m" + std::to_string(tM);
+    kernelName += "_n" + std::to_string(tN);
     cl::Kernel kernel{program, kernelName.c_str()};
     if (kernel()) {
         kernel.setArg(0, C);
@@ -204,7 +218,9 @@ static void go_dpas_vnni(
 {
     printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str());
     fflush(stdout);
-    std::string kernelName = "bfloat16_dpas_vnni_m" + std::to_string(tM);
+    std::string kernelName = "bfloat16_dpas_vnni";
+    kernelName += "_m" + std::to_string(tM);
+    kernelName += "_n" + std::to_string(tN);
     cl::Kernel kernel{program, kernelName.c_str()};
     if (kernel()) {
         kernel.setArg(0, C);
@@ -257,6 +273,7 @@ int main(int argc, char** argv)
     op.add<popl::Value<int>>("i", "iterations", "Test Iterations", testIterations, &testIterations);
     op.add<popl::Switch>("", "validate", "Validate Results", &validate);
     op.add<popl::Switch>("", "fixed", "Use Fixed Data", &fixedData);
+    op.add<popl::Switch>("", "emulate", "Unconditionally Emulate dpas", &emulate);
     op.add<popl::Value<float>>("", "threshold", "Local Error Threshold", threshold, &threshold);
     bool printUsage = false;
     try {
@@ -287,10 +304,27 @@ int main(int argc, char** argv)
     printf("Running on device: %s\n",
         device.getInfo<CL_DEVICE_NAME>().c_str() );
 
+    bool emulate_tN8 = true;
+    bool emulate_tN16 = true;
+    if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) {
+        auto minSubGroupSize = findMinSubGroupSize(device);
+        printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize);
+        switch(minSubGroupSize) {
+        case 8:  emulate_tN8 = false; break;
+        case 16: emulate_tN16 = false; break;
+        default: break;
+        }
+    }
+
+    buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8);
+    buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16);
+
     printf("Config:\n");
     printf("\tTest Iterations: %d\n", testIterations);
     printf("\tValidating data?: %s\n", validate ? "true" : "false");
     printf("\tFixed data?: %s\n", fixedData ? "true" : "false");
+    printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false");
+    printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false");
 
     cl::Context context{device};
     cl::CommandQueue queue{context, device};
@@ -314,42 +348,52 @@ int main(int argc, char** argv)
     const auto N = matrixSize;
     const auto K = matrixSize;
 
-    std::vector<uint16_t> A(M * K);
-    std::vector<uint16_t> B(K * N);
-    std::vector<uint16_t> B_vnni(K * N);
+    std::vector<uint16_t> A_vec(M * K);
+    std::vector<uint16_t> B_vec(K * N);
+    std::vector<uint16_t> Bvnni_vec(K * N);
 
     std::vector<float> C_ref(M * N);
 
     printf("Initializing source matrices...\n");
-    fill_matrix(A, M, K);
-    fill_matrix(B, K, N);
+    fill_matrix(A_vec, M, K);
+    fill_matrix(B_vec, K, N);
 
-    vnni_matrix(B_vnni, B, K, N, 2);
+    vnni_matrix(Bvnni_vec, B_vec, K, N, 2);
 
     if (validate) {
         printf("Computing reference...\n");
-        compute_reference(C_ref, A, B, M, N, K);
+        compute_reference(C_ref, A_vec, B_vec, M, N, K);
     }
 
     printf("Creating source buffers...\n");
-    cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()};
-    cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()};
-    cl::Buffer Bbuf_vnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vnni.size() * sizeof(B_vnni[0]), B_vnni.data()};
-    cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])};
+    cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()};
+    cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()};
+    cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()};
+    cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])};
 
     printf("Running tests...\n");
 
-    go_naive(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
+    go_naive(context, program, queue, C, A, B, M, N, K, C_ref);
+
+    go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+    go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+    go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+    go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+
+    go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
+    go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
+    go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
+    go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
 
-    go_dpas_rowmajor<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
-    go_dpas_rowmajor<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
-    go_dpas_rowmajor<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
-    go_dpas_rowmajor<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
+    go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+    go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+    go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
+    go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
 
-    go_dpas_vnni<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
-    go_dpas_vnni<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
-    go_dpas_vnni<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
-    go_dpas_vnni<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
+    go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
+    go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
+    go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
+    go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
 
     printf("Done.\n");
 
diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl
index 76ce675..7121021 100644
--- a/samples/99_matrixexperiments/matrix_kernels.cl
+++ b/samples/99_matrixexperiments/matrix_kernels.cl
@@ -1,4 +1,16 @@
-float bfloat16_to_float(ushort u)
+#if EMULATE_tN8 == 0
+#define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16
+#else
+#define mat_mul_x8 my_sub_group_bf16_bf16_matrix_mad_k16
+#endif
+
+#if EMULATE_tN16 == 0
+#define mat_mul_x16 intel_sub_group_bf16_bf16_matrix_mad_k16
+#else
+#define mat_mul_x16 my_sub_group_bf16_bf16_matrix_mad_k16
+#endif
+
+float bf16_to_fp32(ushort u)
 {
 #if defined(cl_intel_bfloat16_conversions)
     return intel_convert_as_bfloat16_float(u);
@@ -15,15 +27,146 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B,
     float sum = 0;
     for (int k = 0; k < K; k++) {
-        sum = fma(bfloat16_to_float(A[m * K + k]), bfloat16_to_float(B[k * N + n]), sum);
+        sum = fma(bf16_to_fp32(A[m * K + k]), bf16_to_fp32(B[k * N + n]), sum);
     }
 
     C[m * N + n] = sum;
 }
 
-#if defined(cl_intel_subgroup_matrix_multiply_accumulate)
+#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)
+
+#define OVLD __attribute__((overloadable))
+
+// SIMD8 versions:
+static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc)
+{
+    float res = acc;
+
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).x), bf16_to_fp32(as_ushort2(b.s0).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).y), bf16_to_fp32(as_ushort2(b.s0).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).x), bf16_to_fp32(as_ushort2(b.s1).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).y), bf16_to_fp32(as_ushort2(b.s1).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).x), bf16_to_fp32(as_ushort2(b.s2).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).y), bf16_to_fp32(as_ushort2(b.s2).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).x), bf16_to_fp32(as_ushort2(b.s3).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).y), bf16_to_fp32(as_ushort2(b.s3).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).x), bf16_to_fp32(as_ushort2(b.s4).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).y), bf16_to_fp32(as_ushort2(b.s4).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).x), bf16_to_fp32(as_ushort2(b.s5).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).y), bf16_to_fp32(as_ushort2(b.s5).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).x), bf16_to_fp32(as_ushort2(b.s6).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).y), bf16_to_fp32(as_ushort2(b.s6).y), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).x), bf16_to_fp32(as_ushort2(b.s7).x), res);
+    res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).y), bf16_to_fp32(as_ushort2(b.s7).y), res);
+
+    return res;
+}
+
+static float2 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc)
+{
+    float2 res;
+
+    res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0);
+    res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1);
+
+    return res;
+}
+
+static float4 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc)
+{
+    float4 res;
+
+    res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0);
+    res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1);
+    res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2);
+    res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3);
+
+    return res;
+}
+
+static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc)
+{
+    float8 res;
+
+    res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0);
+    res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1);
+    res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2);
+    res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3);
+    res.s4 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4);
+    res.s5 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5);
+    res.s6 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6);
+    res.s7 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7);
+
+    return res;
+}
+
+// SIMD16 versions:
+static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc)
+{
+    float res = acc;
+
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res);
+    res = fma(bf16_to_fp32(sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res);
+
+    return res;
+}
+
+static float2 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc)
+{
+    float2 res;
+
+    res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0);
+    res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1);
+
+    return res;
+}
+
+static float4 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc)
+{
+    float4 res;
+
+    res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0);
+    res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1);
+    res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2);
+    res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3);
+
+    return res;
+}
+
+static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc)
+{
+    float8 res;
+
+    res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0);
+    res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1);
+    res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2);
+    res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3);
+    res.s4 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4);
+    res.s5 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5);
+    res.s6 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6);
+    res.s7 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7);
+
+    return res;
+}
+
+#undef OVLD
 
 // M rows x K columns
+// This is the SIMD8 version, where each work-item loads two values.
 static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride)
 {
     int ret;
@@ -32,10 +175,11 @@ static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, in
     int offset_ui = rowStart * stride / 2 + colStart / 2;
     ret = intel_sub_group_block_read(A_ui + offset_ui);
 
-    return ret;
+    return ret;
 }
 
 // M rows x K columns
+// This is the SIMD8 version, where each work-item loads two values.
 static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, int colStart, int stride)
 {
     int2 ret;
@@ -46,10 +190,11 @@ static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, in
     ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
     ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
 
-    return ret;
+    return ret;
 }
 
 // M rows x K columns
+// This is the SIMD8 version, where each work-item loads two values.
 static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, int colStart, int stride)
 {
     int4 ret;
@@ -66,6 +211,7 @@ static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, in
 }
 
 // M rows x K columns
+// This is the SIMD8 version, where each work-item loads two values.
 static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, int colStart, int stride)
 {
     int8 ret;
@@ -82,7 +228,66 @@ static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, in
     ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
     ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
 
-    return ret;
+    return ret;
+}
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads one value.
+static short __load_a_row_major_bf16_k16_m1_x16(global ushort* A, int rowStart, int colStart, int stride)
+{
+    ushort ret;
+
+    int offset = rowStart * stride + colStart;
+    ret = intel_sub_group_block_read_us(A + offset);
+
+    return as_short(ret);
+}
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads one value.
+static short2 __load_a_row_major_bf16_k16_m2_x16(global ushort* A, int rowStart, int colStart, int stride)
+{
+    ushort2 ret;
+
+    int offset = rowStart * stride + colStart;
+    ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride;
+
+    return as_short2(ret);
+}
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads one value.
+static short4 __load_a_row_major_bf16_k16_m4_x16(global ushort* A, int rowStart, int colStart, int stride)
+{
+    ushort4 ret;
+
+    int offset = rowStart * stride + colStart;
+    ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride;
+
+    return as_short4(ret);
+}
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads one value.
+static short8 __load_a_row_major_bf16_k16_m8_x16(global ushort* A, int rowStart, int colStart, int stride)
+{
+    ushort8 ret;
+
+    int offset = rowStart * stride + colStart;
+    ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s4 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s5 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s6 = intel_sub_group_block_read_us(A + offset); offset += stride;
+    ret.s7 = intel_sub_group_block_read_us(A + offset); offset += stride;
+
+    return as_short8(ret);
 }
 
 // K rows x N columns:
@@ -198,7 +403,7 @@ static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart,
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1);
@@ -208,7 +413,7 @@ kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global
     for (int k = 0; k < K; k += 16) {
         int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K);
         int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m1(C, sum, m, n, N);
@@ -216,7 +421,7 @@
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_rowmajor_m2(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1) * 2;
@@ -226,7 +431,7 @@ kernel void bfloat16_dpas_rowmajor_m2(global float* C, global
     for (int k = 0; k < K; k += 16) {
         int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K);
         int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m2(C, sum, m, n, N);
@@ -234,7 +439,7 @@
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1) * 4;
@@ -244,7 +449,7 @@ kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global
     for (int k = 0; k < K; k += 16) {
         int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K);
         int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m4(C, sum, m, n, N);
@@ -252,7 +457,7 @@
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1) * 8;
@@ -262,7 +467,79 @@ kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global
     for (int k = 0; k < K; k += 16) {
         int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K);
         int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m8(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1);
+    int n = get_group_id(0) * get_local_size(0);
+
+    float sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m1(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 2;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float2 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m2(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 4;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float4 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m4(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 8;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float8 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m8(C, sum, m, n, N);
 }
@@ -270,7 +547,7 @@
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1);
@@ -280,7 +557,7 @@ kernel void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global usho
     for (int k = 0; k < K; k += 16) {
         int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K);
        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m1(C, sum, m, n, N);
@@ -288,7 +565,7 @@
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1) * 2;
@@ -298,7 +575,7 @@ kernel void bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global usho
     for (int k = 0; k < K; k += 16) {
         int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K);
         int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m2(C, sum, m, n, N);
@@ -306,7 +583,7 @@
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1) * 4;
@@ -316,7 +593,7 @@ kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global usho
     for (int k = 0; k < K; k += 16) {
         int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K);
         int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m4(C, sum, m, n, N);
@@ -324,7 +601,7 @@ kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global usho
 
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1) * 8;
@@ -334,10 +611,82 @@ kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global usho
     for (int k = 0; k < K; k += 16) {
         int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K);
         int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
-        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+        sum = mat_mul_x8(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m8(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1);
+    int n = get_group_id(0) * get_local_size(0);
+
+    float sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m1(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 2;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float2 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m2(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 4;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float4 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m4(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(16)))
+__attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 8;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float8 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = mat_mul_x16(aData, bData, sum);
     }
 
     __store_c_row_major_fp32_m8(C, sum, m, n, N);
 }
 
-#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate)
+#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)
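Note for readers applying this patch: the host side keys everything off the device's minimum
supported sub-group size. A minimal sketch of how a caller might pick between the _n8 and
_n16 kernel variants added above (pickRowMajorKernel is a hypothetical helper name, not part
of the patch; it assumes the CL_DEVICE_SUB_GROUP_SIZES_INTEL trait declared in opencl.hpp
above):

    #include <CL/opencl.hpp>
    #include <algorithm>
    #include <string>

    // A minimum sub-group size of 8 selects the SIMD8 (tN=8) kernels; anything else
    // falls back to the SIMD16 (tN=16) kernels, which run the emulated
    // my_sub_group_bf16_bf16_matrix_mad_k16 path when built with -DEMULATE_tN16=1.
    static std::string pickRowMajorKernel(cl::Device& device, int tM)
    {
        auto sizes = device.getInfo<CL_DEVICE_SUB_GROUP_SIZES_INTEL>();
        auto it = std::min_element(sizes.begin(), sizes.end());
        size_t minSubGroupSize = (it != sizes.end()) ? *it : 0;

        int tN = (minSubGroupSize == 8) ? 8 : 16;
        return "bfloat16_dpas_rowmajor_m" + std::to_string(tM) +
               "_n" + std::to_string(tN);
    }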
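On the data layout: the _vnni kernels consume B pre-packed in VNNI order, where pairs of
consecutive K rows are interleaved column by column, so each 32-bit element of B holds the
two bf16 values that one dpas product term consumes together. The vnni_matrix() helper that
main.cpp calls with a packing factor of 2 is outside this diff; a packing routine consistent
with that layout could look roughly like the following (pack_bf16_vnni2 is a hypothetical
name, not the actual helper):

    #include <cstdint>
    #include <vector>

    // Interleave rows k and k+1 of a K x N bf16 matrix, column by column, so that
    // source element (k, n) lands at pair position k % 2 of column n in packed
    // row k / 2. The packed matrix has K/2 rows of N*2 16-bit elements.
    static void pack_bf16_vnni2(std::vector<uint16_t>& dst,
                                const std::vector<uint16_t>& src,
                                size_t K, size_t N)
    {
        for (size_t k = 0; k < K; k++) {
            for (size_t n = 0; n < N; n++) {
                dst[(k / 2) * (N * 2) + n * 2 + (k % 2)] = src[k * N + n];
            }
        }
    }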