diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp
index 6042073..d3808be 100644
--- a/samples/99_matrixexperiments/main.cpp
+++ b/samples/99_matrixexperiments/main.cpp
@@ -18,6 +18,7 @@
 using test_clock = std::chrono::high_resolution_clock;
 
+bool fixedData = false;
 bool validate = false;
 int testIterations = 16;
 float threshold = 0.01f;
@@ -25,18 +26,18 @@ float threshold = 0.01f;
 template <typename T>
 static void fill_matrix(std::vector<T>& M, size_t numRows, size_t numCols)
 {
-#if 1
-    std::random_device dev;
-    std::mt19937 rng(dev());
-    std::uniform_real_distribution<float> dist(-1.0, 1.0);
-    std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); });
-#else
-    for (size_t r = 0; r < numRows; r++) {
-        for (size_t c = 0; c < numCols; c++) {
-            M[r * numCols + c] = c; //1.0f; // + (float)r / numRows + (float)c / numCols;
+    if (fixedData) {
+        for (size_t r = 0; r < numRows; r++) {
+            for (size_t c = 0; c < numCols; c++) {
+                M[r * numCols + c] = r + c;
+            }
         }
+    } else {
+        std::random_device dev;
+        std::mt19937 rng(dev());
+        std::uniform_real_distribution<float> dist(-1.0, 1.0);
+        std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); });
     }
-#endif
 }
 
 template <typename T>
@@ -100,7 +101,7 @@ static void go_naive(
     size_t M, size_t N, size_t K,
     const std::vector<T>& C_ref)
 {
-    printf("%s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
 
     cl::Kernel kernel{program, "bfloat16_naive"};
     kernel.setArg(0, C);
@@ -117,7 +118,8 @@ static void go_naive(
         std::chrono::duration<float> elapsed_seconds = end - start;
         best = std::min(best, elapsed_seconds.count());
     }
-    printf("Finished in %f seconds\n", best);
+    auto gops = 2.0 * M * N * K / best / 1e9;
+    printf("Finished in %f seconds (%f gops)\n", best, gops);
 
     if (validate) {
         printf("Checking results... "); fflush(stdout);
@@ -129,37 +131,322 @@ static void go_naive(
 }
 
 template <typename T>
-static void go_dpas_basic(
+static void go_dpas_rowmajor_m1(
     cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
     cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
     size_t M, size_t N, size_t K,
     const std::vector<T>& C_ref)
 {
-    printf("%s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m1"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
+    }
+}
 
-    cl::Kernel kernel{program, "bfloat16_dpas_basic"};
-    kernel.setArg(0, C);
-    kernel.setArg(1, A);
-    kernel.setArg(2, B);
-    kernel.setArg(3, static_cast<cl_int>(K));
+template <typename T>
+static void go_dpas_rowmajor_m2(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m2"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/2});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
+    }
+}
 
-    float best = 999.0f;
-    for (int test = 0; test < testIterations; test++) {
-        auto start = test_clock::now();
-        queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M});
-        queue.finish();
-        auto end = test_clock::now();
-        std::chrono::duration<float> elapsed_seconds = end - start;
-        best = std::min(best, elapsed_seconds.count());
+template <typename T>
+static void go_dpas_rowmajor_m4(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m4"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/4});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
     }
-    printf("Finished in %f seconds\n", best);
+}
 
-    if (validate) {
-        printf("Checking results... "); fflush(stdout);
-        std::vector<T> C_check(C_ref.size());
-        queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
-        check_results(C_check, C_ref);
-        printf(" done!\n");
+template <typename T>
+static void go_dpas_rowmajor_m8(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m8"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/8});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
+    }
+}
+
+template <typename T>
+static void go_dpas_vnni_m1(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_vnni_m1"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
+    }
+}
+
+template <typename T>
+static void go_dpas_vnni_m2(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_vnni_m2"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/2});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
+    }
+}
+
+template <typename T>
+static void go_dpas_vnni_m4(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_vnni_m4"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/4});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
+    }
+}
+
+template <typename T>
+static void go_dpas_vnni_m8(
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
+    size_t M, size_t N, size_t K,
+    const std::vector<T>& C_ref)
+{
+    printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout);
+
+    cl::Kernel kernel{program, "bfloat16_dpas_vnni_m8"};
+    if (kernel()) {
+        kernel.setArg(0, C);
+        kernel.setArg(1, A);
+        kernel.setArg(2, B);
+        kernel.setArg(3, static_cast<cl_int>(K));
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/8});
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> elapsed_seconds = end - start;
+            best = std::min(best, elapsed_seconds.count());
+        }
+        auto gops = 2.0 * M * N * K / best / 1e9;
+        printf("Finished in %f seconds (%f gops)\n", best, gops);
+
+        if (validate) {
+            printf("Checking results... "); fflush(stdout);
+            std::vector<T> C_check(C_ref.size());
+            queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data());
+            check_results(C_check, C_ref);
+            printf(" done!\n");
+        }
+    } else {
+        printf("unsupported.\n");
     }
 }
@@ -183,6 +470,7 @@ int main(
     op.add<popl::Value<size_t>>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize);
     op.add<popl::Value<int>>("i", "iterations", "Test Iterations", testIterations, &testIterations);
     op.add<popl::Switch>("", "validate", "Validate Results", &validate);
+    op.add<popl::Switch>("", "fixed", "Use Fixed Data", &fixedData);
    op.add<popl::Value<float>>("", "threshold", "Local Error Threshold", threshold, &threshold);
     bool printUsage = false;
     try {
@@ -213,6 +501,11 @@ int main(
     printf("Running on device: %s\n",
         device.getInfo<CL_DEVICE_NAME>().c_str() );
 
+    printf("Config:\n");
+    printf("\tTest Iterations: %d\n", testIterations);
+    printf("\tValidating data?: %s\n", validate ? "true" : "false");
+    printf("\tFixed data?: %s\n", fixedData ? "true" : "false");
+
     cl::Context context{device};
     cl::CommandQueue queue{context, device};
@@ -242,6 +535,8 @@ int main(
     fill_matrix(A, matrixSize, matrixSize);
     fill_matrix(B, matrixSize, matrixSize);
 
+    vnni_matrix(B_vnni, B, matrixSize, matrixSize, 2);
+
     if (validate) {
         printf("Computing reference...\n");
         compute_reference(C_ref, A, B, matrixSize, matrixSize, matrixSize);
@@ -250,13 +545,21 @@ int main(
     printf("Creating source buffers...\n");
     cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()};
     cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()};
+    cl::Buffer Bbuf_vnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vnni.size() * sizeof(B_vnni[0]), B_vnni.data()};
     cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C.size() * sizeof(C[0])};
 
     printf("Running tests...\n");
     go_naive(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref);
-
-    go_dpas_basic(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_rowmajor_m1(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_rowmajor_m2(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_rowmajor_m4(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_rowmajor_m8(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref);
+
+    go_dpas_vnni_m1(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_vnni_m2(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_vnni_m4(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref);
+    go_dpas_vnni_m8(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref);
 
     printf("Done.\n");
diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl
index 1646f33..ed4900f 100644
--- a/samples/99_matrixexperiments/matrix_kernels.cl
+++ b/samples/99_matrixexperiments/matrix_kernels.cl
@@ -24,13 +24,63 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B,
 #if defined(cl_intel_subgroup_matrix_multiply_accumulate)
 
 // M rows x K columns
-static int __load_a_row_major_bf16_m1(global ushort* A, int rowStart, int colStart, int stride)
+static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride)
 {
     int ret;
 
-    int offset = rowStart * stride + colStart + get_sub_group_local_id() * 2;
+    global uint* A_ui = (global uint*)A;
+    int offset_ui = rowStart * stride / 2 + colStart / 2;
+
+    ret = intel_sub_group_block_read(A_ui + offset_ui);
 
-    ret = as_int(vload2(0, A + offset));
+    return ret;
+}
+
+// M rows x K columns
+static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, int colStart, int stride)
+{
+    int2 ret;
+
+    global uint* A_ui = (global uint*)A;
+    int offset_ui = rowStart * stride / 2 + colStart / 2;
+
+    ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+
+    return ret;
+}
+
+// M rows x K columns
+static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, int colStart, int stride)
+{
+    int4 ret;
+
+    global uint* A_ui = (global uint*)A;
+    int offset_ui = rowStart * stride / 2 + colStart / 2;
+
+    ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+
+    return ret;
+}
+
+// M rows x K columns
+static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, int colStart, int stride)
+{
+    int8 ret;
+
+    global uint* A_ui = (global uint*)A;
+    int offset_ui = rowStart * stride / 2 + colStart / 2;
+
+    ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
 
     return ret;
 }
@@ -42,25 +92,24 @@ static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colS
 {
     int8 ret;
 
-    int offset = rowStart * stride + colStart + get_sub_group_local_id();
-
-    // Note: this could probably use block loads?
-    ushort row0  = B[offset]; offset += stride;
-    ushort row1  = B[offset]; offset += stride;
-    ushort row2  = B[offset]; offset += stride;
-    ushort row3  = B[offset]; offset += stride;
-    ushort row4  = B[offset]; offset += stride;
-    ushort row5  = B[offset]; offset += stride;
-    ushort row6  = B[offset]; offset += stride;
-    ushort row7  = B[offset]; offset += stride;
-    ushort row8  = B[offset]; offset += stride;
-    ushort row9  = B[offset]; offset += stride;
-    ushort row10 = B[offset]; offset += stride;
-    ushort row11 = B[offset]; offset += stride;
-    ushort row12 = B[offset]; offset += stride;
-    ushort row13 = B[offset]; offset += stride;
-    ushort row14 = B[offset]; offset += stride;
-    ushort row15 = B[offset]; offset += stride;
+    int offset = rowStart * stride + colStart;
+
+    ushort row0  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row1  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row2  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row3  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row4  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row5  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row6  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row7  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row8  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row9  = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row10 = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row11 = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row12 = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row13 = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row14 = intel_sub_group_block_read_us(B + offset); offset += stride;
+    ushort row15 = intel_sub_group_block_read_us(B + offset); offset += stride;
 
     ret.s0 = as_int((ushort2)(row0,  row1 ));
     ret.s1 = as_int((ushort2)(row2,  row3 ));
@@ -74,9 +123,82 @@ static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colS
     return ret;
 }
 
+// K rows x N columns:
+// Each work-item loads K values that has already been converted to VNNI.
+// Stride is in units of elements.
+static int8 __load_b_vnni_bf16_k16(global ushort* B, int rowStart, int colStart, int stride)
+{
+    int8 ret;
+
+    global uint* B_ui = (global uint*)B;
+    int offset_ui = rowStart / 2 * stride + colStart;
+
+    ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+    ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
+
+    return ret;
+}
+
+static void __store_c_row_major_fp32_m1(global float* C, float v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint v_ui = as_uint(v);
+
+    int offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride;
+}
+
+static void __store_c_row_major_fp32_m2(global float* C, float2 v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint2 v_ui = as_uint2(v);
+
+    int offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
+}
+
+static void __store_c_row_major_fp32_m4(global float* C, float4 v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint4 v_ui = as_uint4(v);
+
+    int offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride;
+}
+
+static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint8 v_ui = as_uint8(v);
+
+    int offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride;
+}
+
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
-kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort* B, int K)
+kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global ushort* B, int K)
 {
     const int N = get_global_size(0);
     int m = get_group_id(1);
@@ -84,18 +206,138 @@ kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort
 
     float sum = 0;
     for (int k = 0; k < K; k += 16) {
-        int aData = __load_a_row_major_bf16_m1(A, m, k, K);
+        int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K);
         int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
         sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
     }
 
-    C[m * N + n + get_sub_group_local_id()] = sum;
+    __store_c_row_major_fp32_m1(C, sum, m, n, N);
 }
 
-#else
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m2(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 2;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float2 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m2(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 4;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float4 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m4(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 8;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float8 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K);
+        int8 bData = __load_b_row_major_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m8(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1);
+    int n = get_group_id(0) * get_local_size(0);
 
-#pragma message("cl_intel_subgroup_matrix_multiply_accumulate is unsupported!")
+    float sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m1(C, sum, m, n, N);
+}
 
-kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort* B, int K) {}
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 2;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float2 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m2(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 4;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float4 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m4(C, sum, m, n, N);
+}
+
+__attribute__((intel_reqd_sub_group_size(8)))
+__attribute__((reqd_work_group_size(8, 1, 1)))
+kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int N = get_global_size(0);
+    int m = get_group_id(1) * 8;
+    int n = get_group_id(0) * get_local_size(0);
+
+    float8 sum = 0;
+    for (int k = 0; k < K; k += 16) {
+        int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K);
+        int8 bData = __load_b_vnni_bf16_k16(B, k, n, N);
+        sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum);
+    }
+
+    __store_c_row_major_fp32_m8(C, sum, m, n, N);
+}
 
-#endif
\ No newline at end of file
+#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate)
\ No newline at end of file
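Note on the kernels above: every kernel runs with a sub-group of size 8, and lane i of the sub-group accumulates output column n + i. As a host-side model (hypothetical, for intuition only; the helper names below are not part of the patch), one k-step of the m1 kernels computes, per lane:

    #include <cstdint>
    #include <cstring>

    // Hypothetical helper: widen a bfloat16 bit pattern to float by shifting it
    // into the upper 16 bits of an IEEE-754 single.
    static float bf16_to_float(uint16_t v)
    {
        uint32_t bits = static_cast<uint32_t>(v) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // Scalar model of one k-step for sub-group lane i of the m1 kernels:
    // sum accumulates dot(A[m][k..k+15], B[k..k+15][n+i]) in float, which is
    // what the k16 matrix multiply-accumulate produces for that lane.
    static float mad_k16_lane(const uint16_t* A, const uint16_t* B,
                              int m, int n, int k, int K, int N, int i, float sum)
    {
        for (int j = 0; j < 16; j++) {
            sum += bf16_to_float(A[m * K + k + j]) * bf16_to_float(B[(k + j) * N + n + i]);
        }
        return sum;
    }

The row-major and VNNI variants differ only in how B is laid out in memory for the loads; the accumulation itself is the same.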