From 7b89cfe1c60af25be196655ea02f697a6098e3c5 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 8 Jan 2024 21:39:24 -0800 Subject: [PATCH] add vnni block read variants --- samples/99_matrixexperiments/main.cpp | 55 +++++++++++- .../99_matrixexperiments/matrix_kernels.cl | 83 +++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 4fd60b8..aeb132c 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -20,6 +20,7 @@ using test_clock = std::chrono::high_resolution_clock; +bool identityData = false; bool fixedData = false; bool validate = false; bool emulate = false; @@ -61,10 +62,11 @@ static size_t findMinSubGroupSize(cl::Device& device) template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { - if (fixedData) { + if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return 1.0f; }); + } else if (fixedData) { for (size_t r = 0; r < numRows; r++) { for (size_t c = 0; c < numCols; c++) { - //M[r * numCols + c] = 1.0f; M[r * numCols + c] = static_cast(r + c); } } @@ -298,6 +300,49 @@ static void go_dpas_blockread_rowmajor( } } +template +static void go_dpas_blockread_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + int main(int argc, char** argv) { int platformIndex = 0; @@ -316,6 +361,7 @@ int main(int argc, char** argv) op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); op.add("", "validate", "Validate Results", &validate); + op.add("", "identity", "Use Identity Data", &identityData); op.add("", "fixed", "Use Fixed Data", &fixedData); op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); @@ -447,6 +493,11 @@ int main(int argc, char** argv) go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + printf("Done.\n"); return 0; diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 56086a8..53c618f 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -752,6 +752,8 @@ ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int w ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); @@ -774,6 +776,11 @@ ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, i return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); @@ -867,6 +874,82 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1); + const int N = get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); + + float sum = 0; + for (int k = 0; k < K; k += 16) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 2; + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 4; + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + float4 sum = 0; + for (int k = 0; k < K; k += 16) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 8; + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + #endif // cl_intel_subgroup_extended_block_read #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)