Skip to content

Commit

Permalink
add SIMD16 versions and emulation
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jan 6, 2024
1 parent d637ee6 commit 52b9550
Show file tree
Hide file tree
Showing 3 changed files with 443 additions and 45 deletions.
5 changes: 5 additions & 0 deletions include/CL/opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,11 @@ CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_NUM_THREADS_PER_EU_INTEL,
CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_FEATURE_CAPABILITIES_INTEL, cl_device_feature_capabilities_intel)
#endif // cl_intel_device_attribute_query

#if defined(cl_intel_required_subgroup_size)
CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_SUB_GROUP_SIZES_INTEL, cl::vector<size_type>)
CL_HPP_DECLARE_PARAM_TRAITS_(cl_kernel_work_group_info, CL_KERNEL_SPILL_MEM_SIZE_INTEL, cl_ulong)
#endif // cl_intel_required_subgroup_size

// Convenience functions

template <typename Func, typename T>
Expand Down
88 changes: 66 additions & 22 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <algorithm>
#include <chrono>
#include <sstream>
#include <string>
#include <random>
#include <vector>
Expand All @@ -21,6 +22,7 @@ using test_clock = std::chrono::high_resolution_clock;

bool fixedData = false;
bool validate = false;
bool emulate = false;
int testIterations = 16;
float threshold = 0.01f;

Expand All @@ -46,6 +48,16 @@ std::string makeTestName(
return ret.str();
}

static size_t findMinSubGroupSize(cl::Device& device)
{
auto s = device.getInfo<CL_DEVICE_SUB_GROUP_SIZES_INTEL>();
auto it = std::min_element(std::begin(s), std::end(s));
if (it != std::end(s)) {
return *it;
}
return 0;
}

template <typename T>
static void fill_matrix(std::vector<T>& M, size_t numRows, size_t numCols)
{
Expand Down Expand Up @@ -163,7 +175,9 @@ static void go_dpas_rowmajor(
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_rowmajor_m" + std::to_string(tM);
std::string kernelName = "bfloat16_dpas_rowmajor";
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
kernel.setArg(0, C);
Expand Down Expand Up @@ -204,7 +218,9 @@ static void go_dpas_vnni(
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_vnni_m" + std::to_string(tM);
std::string kernelName = "bfloat16_dpas_vnni";
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
kernel.setArg(0, C);
Expand Down Expand Up @@ -257,6 +273,7 @@ int main(int argc, char** argv)
op.add<popl::Value<int>>("i", "iterations", "Test Iterations", testIterations, &testIterations);
op.add<popl::Switch>("", "validate", "Validate Results", &validate);
op.add<popl::Switch>("", "fixed", "Use Fixed Data", &fixedData);
op.add<popl::Switch>("", "emulate", "Unconditionally Emulate dpas", &emulate);
op.add<popl::Value<float>>("", "threshold", "Local Error Threshold", threshold, &threshold);
bool printUsage = false;
try {
Expand Down Expand Up @@ -287,10 +304,27 @@ int main(int argc, char** argv)
printf("Running on device: %s\n",
device.getInfo<CL_DEVICE_NAME>().c_str() );

bool emulate_tN8 = true;
bool emulate_tN16 = true;
if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) {
auto minSubGroupSize = findMinSubGroupSize(device);
printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize);
switch(minSubGroupSize) {
case 8: emulate_tN8 = false; break;
case 16: emulate_tN16 = false; break;
default: break;
}
}

buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8);
buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16);

printf("Config:\n");
printf("\tTest Iterations: %d\n", testIterations);
printf("\tValidating data?: %s\n", validate ? "true" : "false");
printf("\tFixed data?: %s\n", fixedData ? "true" : "false");
printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false");
printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false");

cl::Context context{device};
cl::CommandQueue queue{context, device};
Expand All @@ -314,42 +348,52 @@ int main(int argc, char** argv)
const auto N = matrixSize;
const auto K = matrixSize;

std::vector<bfloat16> A(M * K);
std::vector<bfloat16> B(K * N);
std::vector<bfloat16> B_vnni(K * N);
std::vector<bfloat16> A_vec(M * K);
std::vector<bfloat16> B_vec(K * N);
std::vector<bfloat16> Bvnni_vec(K * N);

std::vector<float> C_ref(M * N);

printf("Initializing source matrices...\n");
fill_matrix(A, M, K);
fill_matrix(B, K, N);
fill_matrix(A_vec, M, K);
fill_matrix(B_vec, K, N);

vnni_matrix(B_vnni, B, K, N, 2);
vnni_matrix(Bvnni_vec, B_vec, K, N, 2);

if (validate) {
printf("Computing reference...\n");
compute_reference(C_ref, A, B, M, N, K);
compute_reference(C_ref, A_vec, B_vec, M, N, K);
}

printf("Creating source buffers...\n");
cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()};
cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()};
cl::Buffer Bbuf_vnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vnni.size() * sizeof(B_vnni[0]), B_vnni.data()};
cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])};
cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()};
cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()};
cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()};
cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])};

printf("Running tests...\n");

go_naive(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
go_naive(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_rowmajor<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
go_dpas_rowmajor<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
go_dpas_rowmajor<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
go_dpas_rowmajor<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref);
go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
go_dpas_vnni<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
go_dpas_vnni<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
go_dpas_vnni<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref);
go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

printf("Done.\n");

Expand Down
Loading

0 comments on commit 52b9550

Please sign in to comment.