Skip to content

Commit

Permalink
add vnni block read variants
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jan 9, 2024
1 parent 0bb5529 commit 7b89cfe
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 2 deletions.
55 changes: 53 additions & 2 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

using test_clock = std::chrono::high_resolution_clock;

bool identityData = false;
bool fixedData = false;
bool validate = false;
bool emulate = false;
Expand Down Expand Up @@ -61,10 +62,11 @@ static size_t findMinSubGroupSize(cl::Device& device)
template <typename T>
static void fill_matrix(std::vector<T>& M, size_t numRows, size_t numCols)
{
if (fixedData) {
if (identityData) {
std::generate(std::begin(M), std::end(M), [&]{ return 1.0f; });
} else if (fixedData) {
for (size_t r = 0; r < numRows; r++) {
for (size_t c = 0; c < numCols; c++) {
//M[r * numCols + c] = 1.0f;
M[r * numCols + c] = static_cast<float>(r + c);
}
}
Expand Down Expand Up @@ -298,6 +300,49 @@ static void go_dpas_blockread_rowmajor(
}
}

// Benchmarks one bfloat16 DPAS kernel variant that loads the B matrix with
// VNNI-layout 2D block reads.  Picks the kernel matching the tile shape
// (tM x tN), times testIterations launches keeping the best, reports the
// achieved rate, and optionally validates C against C_ref.
template<int tM, int tN, int tK>
static void go_dpas_blockread_vnni(
    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
    cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
    size_t M, size_t N, size_t K,
    const std::vector<float>& C_ref)
{
    printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);

    // Kernel names encode the tile shape, e.g. bfloat16_dpas_blockread_vnni_m8_n16.
    const std::string name =
        "bfloat16_dpas_blockread_vnni" +
        std::string("_m") + std::to_string(tM) +
        "_n" + std::to_string(tN);
    cl::Kernel kernel{program, name.c_str()};
    if (!kernel()) {
        // The kernel is missing when this variant could not be built for the device.
        printf("unsupported.\n");
        return;
    }

    kernel.setArg(0, C);
    kernel.setArg(1, A);
    kernel.setArg(2, B);
    kernel.setArg(3, static_cast<cl_int>(K));

    float best = 999.0f;
    for (int iteration = 0; iteration < testIterations; iteration++) {
        const auto tStart = test_clock::now();
        // One work-item per column of C; each work-group row covers tM rows.
        queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM});
        queue.finish();
        const std::chrono::duration<float> sw = test_clock::now() - tStart;
        best = std::min(best, sw.count());
    }
    // 2 ops (mul + add) per element of the K-deep dot product, per C element.
    const auto gops = 2.0 * M * N * K / best / 1e9;
    printf("Best in %f seconds (%f gops)\n", best, gops);

    if (validate) {
        printf("Checking results... "); fflush(stdout);
        std::vector<float> C_check(C_ref.size());
        queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data());
        check_results(C_check, C_ref);
        printf(" done!\n");
    }
}

int main(int argc, char** argv)
{
int platformIndex = 0;
Expand All @@ -316,6 +361,7 @@ int main(int argc, char** argv)
op.add<popl::Value<size_t>>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize);
op.add<popl::Value<int>>("i", "iterations", "Test Iterations", testIterations, &testIterations);
op.add<popl::Switch>("", "validate", "Validate Results", &validate);
op.add<popl::Switch>("", "identity", "Use Identity Data", &identityData);
op.add<popl::Switch>("", "fixed", "Use Fixed Data", &fixedData);
op.add<popl::Switch>("", "emulate", "Unconditionally Emulate dpas", &emulate);
op.add<popl::Value<float>>("", "threshold", "Local Error Threshold", threshold, &threshold);
Expand Down Expand Up @@ -447,6 +493,11 @@ int main(int argc, char** argv)
go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

printf("Done.\n");

return 0;
Expand Down
83 changes: 83 additions & 0 deletions samples/99_matrixexperiments/matrix_kernels.cl
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,8 @@ ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int w
ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);

uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);

void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
Expand All @@ -774,6 +776,11 @@ ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, i
return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}

// Convenience wrapper over the Intel flat 2D block-read builtin: reads an
// 8-row x 16-column block of uints for the sub-group.  The builtin wants
// width/height/pitch biased by -1, so this wrapper subtracts 1 and lets
// callers pass the natural values.  As used by the kernels in this file,
// width and pitch are in bytes, height is in rows, and coord is
// (x in uint elements, y in rows) — assumed from call sites; confirm
// against the cl_intel_subgroup extended block read spec.
uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord)
{
return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}

void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
{
__builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
Expand Down Expand Up @@ -867,6 +874,82 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho
intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum));
}

// bfloat16 GEMM tile: each sub-group produces a 1x16 block of C via DPAS.
// A is read row-major; B is read as packed dwords (VNNI layout per the
// kernel name), so the B row coordinate advances by kk / 2 dword rows.
__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
{
    const int M = get_global_size(1);
    const int N = get_global_size(0);
    const int rowC = get_group_id(1);
    const int colC = get_group_id(0) * get_local_size(0);

    float acc = 0;
    for (int kk = 0; kk < K; kk += 16) {
        // One row of A: 16 bf16 values spread across the sub-group.
        short a = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(kk, rowC)));
        // A 16x16 bf16 block of B as 8 dword rows (two bf16 per dword).
        int8 b = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(colC, kk / 2)));
        acc = mat_mul_x16(a, b, acc);
    }

    intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(colC, rowC), as_uint(acc));
}

// bfloat16 GEMM tile: each sub-group produces a 2x16 block of C via DPAS,
// with B loaded as packed dwords (VNNI layout per the kernel name).
__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K)
{
    const int M = get_global_size(1) * 2;   // each group row covers 2 rows of C
    const int N = get_global_size(0);
    const int rowC = get_group_id(1) * 2;
    const int colC = get_group_id(0) * get_local_size(0);

    float2 acc = 0;
    for (int kk = 0; kk < K; kk += 16) {
        // Two rows of A (16 bf16 columns each).
        short2 a = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(kk, rowC)));
        // A 16x16 bf16 block of B as 8 dword rows (two bf16 per dword).
        int8 b = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(colC, kk / 2)));
        acc = mat_mul_x16(a, b, acc);
    }

    intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(colC, rowC), as_uint2(acc));
}

// bfloat16 GEMM tile: each sub-group produces a 4x16 block of C via DPAS,
// with B loaded as packed dwords (VNNI layout per the kernel name).
__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K)
{
    const int M = get_global_size(1) * 4;   // each group row covers 4 rows of C
    const int N = get_global_size(0);
    const int rowC = get_group_id(1) * 4;
    const int colC = get_group_id(0) * get_local_size(0);

    float4 acc = 0;
    for (int kk = 0; kk < K; kk += 16) {
        // Four rows of A (16 bf16 columns each).
        short4 a = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(kk, rowC)));
        // A 16x16 bf16 block of B as 8 dword rows (two bf16 per dword).
        int8 b = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(colC, kk / 2)));
        acc = mat_mul_x16(a, b, acc);
    }

    intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(colC, rowC), as_uint4(acc));
}

// bfloat16 GEMM tile: each sub-group produces an 8x16 block of C via DPAS,
// with B loaded as packed dwords (VNNI layout per the kernel name).
__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K)
{
    const int M = get_global_size(1) * 8;   // each group row covers 8 rows of C
    const int N = get_global_size(0);
    const int rowC = get_group_id(1) * 8;
    const int colC = get_group_id(0) * get_local_size(0);

    float8 acc = 0;
    for (int kk = 0; kk < K; kk += 16) {
        // Eight rows of A (16 bf16 columns each).
        short8 a = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(kk, rowC)));
        // A 16x16 bf16 block of B as 8 dword rows (two bf16 per dword).
        int8 b = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(colC, kk / 2)));
        acc = mat_mul_x16(a, b, acc);
    }

    intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(colC, rowC), as_uint8(acc));
}

#endif // cl_intel_subgroup_extended_block_read

#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)
Expand Down

0 comments on commit 7b89cfe

Please sign in to comment.