diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 044739e..a3bacda 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -40,25 +40,25 @@ std::string makeTestName( std::string makeTestName( const std::string &func, - int tM, int tN, int tK, + int tM, int tN, size_t M, size_t N, size_t K) { std::ostringstream ret; ret << func; - ret << ""; + ret << ""; ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; return ret.str(); } std::string makeTestName( const std::string &func, - int tM, int tN, int tK, + int tM, int tN, int MM, int NN, size_t M, size_t N, size_t K) { std::ostringstream ret; ret << func; - ret << ""; + ret << ""; ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; return ret.str(); } @@ -157,7 +157,7 @@ static float hw_time(cl::Event& event) return ns / 1e9f; } -static void go_naive( +static void bfloat16_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, @@ -201,14 +201,14 @@ static void go_naive( } } -template -static void go_dpas_rowmajor( +template +static void bfloat16_dpas_rowmajor( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_rowmajor"; kernelName += "_m" + std::to_string(tM); @@ -249,14 +249,14 @@ static void go_dpas_rowmajor( } } -template -static void go_dpas_rowmajor_tiled( +template +static void bfloat16_dpas_rowmajor_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_rowmajor_tiled"; kernelName += "_m" + std::to_string(tM); @@ -303,14 +303,14 @@ static void go_dpas_rowmajor_tiled( } } -template -static void go_dpas_vnni( +template +static void bfloat16_dpas_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_vnni"; kernelName += "_m" + std::to_string(tM); @@ -351,14 +351,14 @@ static void go_dpas_vnni( } } -template -static void go_dpas_vnni_tiled( +template +static void bfloat16_dpas_vnni_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_vnni_tiled"; kernelName += "_m" + std::to_string(tM); @@ -405,14 +405,14 @@ static void go_dpas_vnni_tiled( } } -template -static void go_dpas_blockread_rowmajor( +template +static void bfloat16_dpas_blockread_rowmajor( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_rowmajor"; kernelName += "_m" + std::to_string(tM); @@ -453,14 +453,14 @@ static void go_dpas_blockread_rowmajor( } } -template -static void go_dpas_blockread_rowmajor_tiled( +template +static void bfloat16_dpas_blockread_rowmajor_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_rowmajor_tiled"; kernelName += "_m" + std::to_string(tM); @@ -507,14 +507,14 @@ static void go_dpas_blockread_rowmajor_tiled( } } -template -static void go_dpas_blockread_vnni( +template +static void bfloat16_dpas_blockread_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_vnni"; kernelName += "_m" + std::to_string(tM); @@ -555,14 +555,14 @@ static void go_dpas_blockread_vnni( } } -template -static void go_dpas_blockread_vnni_tiled( +template +static void bfloat16_dpas_blockread_vnni_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_vnni_tiled"; kernelName += "_m" + std::to_string(tM); @@ -734,79 +734,79 @@ int main(int argc, char** argv) printf("Running tests...\n"); - go_naive(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_vnni_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_naive(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); printf("Done.\n");