add 2D block read variants
bashbaug committed Jan 9, 2024
1 parent 1ca8f73 commit f2b00f3
Showing 3 changed files with 248 additions and 25 deletions.
2 changes: 1 addition & 1 deletion samples/99_matrixexperiments/CMakeLists.txt
@@ -4,7 +4,7 @@

add_opencl_sample(
TEST
NUMBER 05
NUMBER 99
TARGET matrixexperiments
VERSION 120
SOURCES main.cpp
91 changes: 70 additions & 21 deletions samples/99_matrixexperiments/main.cpp
@@ -64,6 +64,7 @@ static void fill_matrix(std::vector<T>& M, size_t numRows, size_t numCols)
if (fixedData) {
for (size_t r = 0; r < numRows; r++) {
for (size_t c = 0; c < numCols; c++) {
//M[r * numCols + c] = 1.0f;
M[r * numCols + c] = static_cast<float>(r + c);
}
}
@@ -254,6 +255,49 @@ static void go_dpas_vnni(
}
}

template<int tM, int tN, int tK>
static void go_dpas_blockread_rowmajor(
cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
cl::Buffer& C, cl::Buffer& A, cl::Buffer& B,
size_t M, size_t N, size_t K,
const std::vector<float>& C_ref)
{
printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout);

std::string kernelName = "bfloat16_dpas_blockread_rowmajor";
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
auto start = test_clock::now();
queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM});
queue.finish();
auto end = test_clock::now();
std::chrono::duration<float> elapsed_seconds = end - start;
best = std::min(best, elapsed_seconds.count());
}
auto gops = 2.0 * M * N * K / best / 1e9;
printf("Best in %f seconds (%f gops)\n", best, gops);

if (validate) {
printf("Checking results... "); fflush(stdout);
std::vector<float> C_check(C_ref.size());
queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data());
check_results(C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

int main(int argc, char** argv)
{
int platformIndex = 0;
@@ -376,27 +420,32 @@ int main(int argc, char** argv)

printf("Running tests...\n");

go_naive(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_naive(context, program, queue, C, A, B, M, N, K, C_ref);
//
//go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//
//go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//
//go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
//
//go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
//go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

printf("Done.\n");

180 changes: 177 additions & 3 deletions samples/99_matrixexperiments/matrix_kernels.cl
@@ -1,3 +1,5 @@
#define OVLD __attribute__((overloadable))

#if EMULATE_tn8 == 0
#define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16
#else
@@ -35,7 +37,8 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B,

#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)

#define OVLD __attribute__((overloadable))
// These are non-block read versions.
// They work on DG2 and PVC, and on other devices when emulated.

// SIMD8 versions:
static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc)
@@ -163,8 +166,6 @@ static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float
return res;
}

#undef OVLD

// M rows x K columns
// This is the SIMD8 version, where each work-item loads two values.
static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride)
@@ -697,4 +698,177 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global
__store_c_row_major_fp32_m8(C, sum, m, n, N);
}

#ifdef cl_intel_subgroup_extended_block_read

// Note for 2D block reads:
// - the tile width and height are encoded into the function name.
// - base_address is the byte address. Must be 64B aligned.
// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned.
// - height is the height of the entire matrix, or equivalently the number of rows.
// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes.
// - coord is the starting element (x coord) and row (y coord) to read from. The x coord must be a multiple of 4 for 1B data and a multiple of 2 for 2B data.

// Built-in functions are:

// #ifdef cl_intel_subgroup_extended_block_read
// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord);
// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord);
// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord);
// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord);
// #endif //defined(cl_intel_subgroup_extended_block_read)
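
// As an illustrative sketch (not part of this commit): a call to one of the block
// read built-ins above, for a row-major bf16 matrix A with M rows and K columns of
// ushort data, reading the tile anchored at element column k and row m (the names
// A, M, K, k, and m are borrowed from the kernels below), might look like:
//
//   // the v2 variant returns two adjacent 8-row x 16-column tiles of 16-bit elements
//   ushort16 tiles = intel_subgroup_block_read_u16_m8k16v2(
//       A,                    // base_address: 64B-aligned byte address
//       K * sizeof(ushort),   // width of the whole matrix, in bytes
//       M,                    // height of the whole matrix, in rows
//       K * sizeof(ushort),   // pitch between rows, in bytes
//       (int2)(k, m));        // x coord in elements, y coord in rows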


// For intrinsics, the pattern is:
// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat
// - operation (optional): _transpose or _transform
// - for no transpose or transform:
// - type / elements size: _u8 or _u16 or _u32 or _u64
// - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1
// - tile width: _k64 or _k32 or _k16 or _k8
// - number of tiles: _v2 or _v1
// - for transpose:
// - type / element size: _u64 or _u32
// - number of tile rows: subgroup size (16)
// - tile width: _k4 (for _u64) or _k8 (for _u32)
// - number of tiles: 1
// - for transform:
// - type / element size: _u16 or _u8
// - number of tile rows: _k32 (for _u8) or _k16 (for _u16)
// - tile width: subgroup size (16)
// - number of tiles: 1
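
// Following this pattern, for example, __builtin_IB_subgroup_block_read_flat_u16_m8k16v1
// (declared below) reads one tile (v1) of 16-bit elements (u16) spanning 8 rows (m8)
// by 16 columns (k16), and __builtin_IB_subgroup_block_write_flat_u32_m8k16v1 writes
// one 8-row by 16-column tile of 32-bit elements.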

// Define additional "non-vector" block reads and writes. These are supported by the hardware but are not declared in the headers:

ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);

void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data);

ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}
ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}
ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}
ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}

void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
{
__builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
}
void intel_subgroup_block_write_u32_m2k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data)
{
__builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
}
void intel_subgroup_block_write_u32_m4k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data)
{
__builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
}
void intel_subgroup_block_write_u32_m8k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data)
{
__builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
}

__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
{
const int M = get_global_size(1);
const int N = get_global_size(0);
int m = get_group_id(1);
int n = get_group_id(0) * get_local_size(0);

float sum = 0;
for (int k = 0; k < K; k += 16) {
short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
sum = mat_mul_x16(aData, bData, sum);
}

intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum));
}

__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K)
{
const int M = get_global_size(1) * 2;
const int N = get_global_size(0);
int m = get_group_id(1) * 2;
int n = get_group_id(0) * get_local_size(0);

float2 sum = 0;
for (int k = 0; k < K; k += 16) {
short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
sum = mat_mul_x16(aData, bData, sum);
}

intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum));
}

__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K)
{
const int M = get_global_size(1) * 4;
const int N = get_global_size(0);
int m = get_group_id(1) * 4;
int n = get_group_id(0) * get_local_size(0);

float4 sum = 0;
for (int k = 0; k < K; k += 16) {
short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
sum = mat_mul_x16(aData, bData, sum);
}

intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum));
}

__attribute__((intel_reqd_sub_group_size(16)))
__attribute__((reqd_work_group_size(16, 1, 1)))
kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K)
{
const int M = get_global_size(1) * 8;
const int N = get_global_size(0);
int m = get_group_id(1) * 8;
int n = get_group_id(0) * get_local_size(0);

float8 sum = 0;
for (int k = 0; k < K; k += 16) {
short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
sum = mat_mul_x16(aData, bData, sum);
}

intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum));
}

#endif // cl_intel_subgroup_extended_block_read

#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size)

#undef OVLD
