Skip to content

Commit

Permalink
fix block read tiled kernels and execute them
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jan 15, 2024
1 parent c7edcd6 commit a433769
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 8 deletions.
136 changes: 133 additions & 3 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ static void go_dpas_rowmajor(
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
Expand Down Expand Up @@ -262,18 +264,20 @@ static void go_dpas_rowmajor_tiled(
kernelName += "_" + std::to_string(MM);
kernelName += "x" + std::to_string(NN);
cl::Kernel kernel{program, kernelName.c_str()};
if (tM * MM > M) {
if (kernel() == nullptr) {
printf("unsupported.\n");
} else if (tM * MM > M) {
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
Expand Down Expand Up @@ -422,6 +426,8 @@ static void go_dpas_blockread_rowmajor(
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
Expand All @@ -447,6 +453,60 @@ static void go_dpas_blockread_rowmajor(
}
}

template<int tM, int tN, int tK, int MM, int NN>
static void go_dpas_blockread_rowmajor_tiled(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_rowmajor_tiled";
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
kernelName += "_" + std::to_string(MM);
kernelName += "x" + std::to_string(NN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel() == nullptr) {
printf("unsupported.\n");
} else if (tM * MM > M) {
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
auto start = test_clock::now();
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event);
queue.finish();
auto end = test_clock::now();
std::chrono::duration<float> sw_time = end - start;
auto elapsed = wallclock ? sw_time.count() : hw_time(event);
best = std::min(best, elapsed);
}
auto gops = 2.0 * M * N * K / best / 1e9;
printf("Best in %f seconds (%f gops)\n", best, gops);

if (validate) {
printf("Checking results... "); fflush(stdout);
std::vector<float> C_check(C_ref.size());
queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data());
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
}
}

template<int tM, int tN, int tK>
static void go_dpas_blockread_vnni(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
Expand All @@ -468,6 +528,8 @@ static void go_dpas_blockread_vnni(
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
Expand All @@ -493,6 +555,60 @@ static void go_dpas_blockread_vnni(
}
}

template<int tM, int tN, int tK, int MM, int NN>
static void go_dpas_blockread_vnni_tiled(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_vnni_tiled";
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
kernelName += "_" + std::to_string(MM);
kernelName += "x" + std::to_string(NN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel() == nullptr) {
printf("unsupported.\n");
} else if (tM * MM > M) {
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
auto start = test_clock::now();
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event);
queue.finish();
auto end = test_clock::now();
std::chrono::duration<float> sw_time = end - start;
auto elapsed = wallclock ? sw_time.count() : hw_time(event);
best = std::min(best, elapsed);
}
auto gops = 2.0 * M * N * K / best / 1e9;
printf("Best in %f seconds (%f gops)\n", best, gops);

if (validate) {
printf("Checking results... "); fflush(stdout);
std::vector<float> C_check(C_ref.size());
queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data());
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
}
}

int main(int argc, char** argv)
{
int platformIndex = 0;
Expand Down Expand Up @@ -673,11 +789,25 @@ int main(int argc, char** argv)
go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

printf("Done.\n");

return 0;
Expand Down
12 changes: 7 additions & 5 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
#ifdef cl_intel_subgroup_extended_block_read

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
{
const int tM = 8;
const int tN = 16;
const int M = get_global_size(1) * tM;
const int N = get_global_size(0) * NN;
const int m = get_group_id(1) * tM * MM;
const int n = get_group_id(0) * tN * NN;
Expand All @@ -207,7 +208,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f

int8 bData[NN];
for (int nn = 0; nn < NN; nn++) {
bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N);
bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k)));
}

for (int mm = 0; mm < MM; mm++) {
Expand All @@ -225,10 +226,11 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
{
const int tM = 8;
const int tN = 16;
const int M = get_global_size(1) * tM;
const int N = get_global_size(0) * NN;
const int m = get_group_id(1) * tM * MM;
const int n = get_group_id(0) * tN * NN;
Expand All @@ -248,7 +250,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float

int8 bData[NN];
for (int nn = 0; nn < NN; nn++) {
bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k)));
bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2)));
}

for (int mm = 0; mm < MM; mm++) {
Expand All @@ -265,4 +267,4 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
}
}

#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)
#endif // cl_intel_subgroup_extended_block_read

0 comments on commit a433769

Please sign in to comment.