Skip to content

Commit

Permalink
rename tester host functions to match kernel names more closely
Browse files Browse the repository at this point in the history
Also, remove tK from all host function output, since it is only
used internally within the kernels.
  • Loading branch information
bashbaug committed Jan 19, 2024
1 parent 031e076 commit 16b7cda
Showing 1 changed file with 102 additions and 102 deletions.
204 changes: 102 additions & 102 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,25 @@ std::string makeTestName(

std::string makeTestName(
const std::string &func,
int tM, int tN, int tK,
int tM, int tN,
size_t M, size_t N, size_t K)
{
std::ostringstream ret;
ret << func;
ret << "<tM:" << tM << ", tN:" << tN << ", tK:" << tK << ">";
ret << "<tM:" << tM << ", tN:" << tN << ">";
ret << " (M=" << M << ", N=" << N << ", K=" << K << ")";
return ret.str();
}

std::string makeTestName(
const std::string &func,
int tM, int tN, int tK,
int tM, int tN,
int MM, int NN,
size_t M, size_t N, size_t K)
{
std::ostringstream ret;
ret << func;
ret << "<tM:" << tM << "x" << MM << ", tN:" << tN << "x" << NN << ", tK:" << tK << ">";
ret << "<tM:" << tM << "x" << MM << ", tN:" << tN << "x" << NN << ">";
ret << " (M=" << M << ", N=" << N << ", K=" << K << ")";
return ret.str();
}
Expand Down Expand Up @@ -157,7 +157,7 @@ static float hw_time(cl::Event& event)
return ns / 1e9f;
}

static void go_naive(
static void bfloat16_naive(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
Expand Down Expand Up @@ -201,14 +201,14 @@ static void go_naive(
}
}

template<int tM, int tN, int tK>
static void go_dpas_rowmajor(
template<int tM, int tN>
static void bfloat16_dpas_rowmajor(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_rowmajor";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -249,14 +249,14 @@ static void go_dpas_rowmajor(
}
}

template<int tM, int tN, int tK, int MM, int NN>
static void go_dpas_rowmajor_tiled(
template<int tM, int tN, int MM, int NN>
static void bfloat16_dpas_rowmajor_tiled(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_rowmajor_tiled";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -303,14 +303,14 @@ static void go_dpas_rowmajor_tiled(
}
}

template<int tM, int tN, int tK>
static void go_dpas_vnni(
template<int tM, int tN>
static void bfloat16_dpas_vnni(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_vnni";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -351,14 +351,14 @@ static void go_dpas_vnni(
}
}

template<int tM, int tN, int tK, int MM, int NN>
static void go_dpas_vnni_tiled(
template<int tM, int tN, int MM, int NN>
static void bfloat16_dpas_vnni_tiled(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_vnni_tiled";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -405,14 +405,14 @@ static void go_dpas_vnni_tiled(
}
}

template<int tM, int tN, int tK>
static void go_dpas_blockread_rowmajor(
template<int tM, int tN>
static void bfloat16_dpas_blockread_rowmajor(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_rowmajor";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -453,14 +453,14 @@ static void go_dpas_blockread_rowmajor(
}
}

template<int tM, int tN, int tK, int MM, int NN>
static void go_dpas_blockread_rowmajor_tiled(
template<int tM, int tN, int MM, int NN>
static void bfloat16_dpas_blockread_rowmajor_tiled(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_rowmajor_tiled";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -507,14 +507,14 @@ static void go_dpas_blockread_rowmajor_tiled(
}
}

template<int tM, int tN, int tK>
static void go_dpas_blockread_vnni(
template<int tM, int tN>
static void bfloat16_dpas_blockread_vnni(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_vnni";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -555,14 +555,14 @@ static void go_dpas_blockread_vnni(
}
}

template<int tM, int tN, int tK, int MM, int NN>
static void go_dpas_blockread_vnni_tiled(
template<int tM, int tN, int MM, int NN>
static void bfloat16_dpas_blockread_vnni_tiled(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout);
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_vnni_tiled";
kernelName += "_m" + std::to_string(tM);
Expand Down Expand Up @@ -734,79 +734,79 @@ int main(int argc, char** argv)

printf("Running tests...\n");

go_naive(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_rowmajor_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_vnni_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_naive(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

printf("Done.\n");

Expand Down

0 comments on commit 16b7cda

Please sign in to comment.