diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index d3808be..26d177f 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -79,9 +79,9 @@ int check_results(const std::vector& C, { float err = 0.f; for (int i = 0; i < C.size(); ++i) { - float localErr = std::fabs(C[i] - C_ref[i]) / - std::max(std::fabs(C[i]), - std::fabs(C_ref[i])); + auto localErr = std::fabs(C[i] - C_ref[i]) / + std::max(std::fabs(C[i]), + std::fabs(C_ref[i])); err = std::max(localErr, err); if (localErr >= threshold) { std::cerr << "Error at index " << i << " (local error " << localErr @@ -94,12 +94,11 @@ int check_results(const std::vector& C, return err < 0.001f; } -template static void go_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -123,19 +122,18 @@ static void go_naive( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } } -template static void go_dpas_rowmajor_m1( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -160,8 +158,8 @@ static void go_dpas_rowmajor_m1( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -170,12 +168,11 @@ static void go_dpas_rowmajor_m1( } } -template static void go_dpas_rowmajor_m2( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -200,8 +197,8 @@ static void go_dpas_rowmajor_m2( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -210,12 +207,11 @@ static void go_dpas_rowmajor_m2( } } -template static void go_dpas_rowmajor_m4( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -240,8 +236,8 @@ static void go_dpas_rowmajor_m4( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -250,12 +246,11 @@ static void go_dpas_rowmajor_m4( } } -template static void go_dpas_rowmajor_m8( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -280,8 +275,8 @@ static void go_dpas_rowmajor_m8( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -290,12 +285,11 @@ static void go_dpas_rowmajor_m8( } } -template static void go_dpas_vnni_m1( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -320,8 +314,8 @@ static void go_dpas_vnni_m1( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -330,12 +324,11 @@ static void go_dpas_vnni_m1( } } -template static void go_dpas_vnni_m2( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -360,8 +353,8 @@ static void go_dpas_vnni_m2( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -370,12 +363,11 @@ static void go_dpas_vnni_m2( } } -template static void go_dpas_vnni_m4( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -400,8 +392,8 @@ static void go_dpas_vnni_m4( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -410,12 +402,11 @@ static void go_dpas_vnni_m4( } } -template static void go_dpas_vnni_m8( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -440,8 +431,8 @@ static void go_dpas_vnni_m8( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -450,9 +441,7 @@ static void go_dpas_vnni_m8( } } -int main( - int argc, - char** argv ) +int main(int argc, char** argv) { int platformIndex = 0; int deviceIndex = 0; @@ -551,6 +540,7 @@ int main( printf("Running tests...\n"); go_naive(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_rowmajor_m1(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); go_dpas_rowmajor_m2(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); go_dpas_rowmajor_m4(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index ed4900f..76ce675 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -15,7 +15,7 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, float sum = 0; for (int k = 0; k < K; k++) { - sum += bfloat16_to_float(A[m * K + k]) * bfloat16_to_float(B[k * N + n]); + sum = fma(bfloat16_to_float(A[m * K + k]), bfloat16_to_float(B[k * N + n]), sum); } C[m * N + n] = sum; @@ -340,4 +340,4 @@ kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global usho __store_c_row_major_fp32_m8(C, sum, m, n, N); } -#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate) \ No newline at end of file +#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate)