Skip to content

Commit

Permalink
performance improvements and bugfixes for DG2
Browse files Browse the repository at this point in the history
- Fixed an error in the stride computation.
- Added changes to improve stateless-to-stateful compilation.
- Added a wider A matrix block read to load two K tiles at a time.
  • Loading branch information
bashbaug committed Jan 23, 2024
1 parent 38d03c0 commit ab84bbe
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 40 deletions.
6 changes: 6 additions & 0 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ int main(int argc, char** argv)
bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
Expand All @@ -753,6 +754,7 @@ int main(int argc, char** argv)
bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
Expand All @@ -765,6 +767,7 @@ int main(int argc, char** argv)
bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
Expand All @@ -777,6 +780,7 @@ int main(int argc, char** argv)
bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
Expand All @@ -789,6 +793,7 @@ int main(int argc, char** argv)
bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
Expand All @@ -801,6 +806,7 @@ int main(int argc, char** argv)
bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
Expand Down
50 changes: 26 additions & 24 deletions samples/99_matrixexperiments/matrix_helpers.cl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ float bf16_to_fp32(ushort u)

#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short)

typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4)));

inline int compute_m(const int num_sgs, const int tM, const int MM)
{
const int m_start = get_group_id(1) * num_sgs;
Expand Down Expand Up @@ -157,7 +159,7 @@ int load_a_rowmajor_d16_m1_k16_sg8(global ushort* A, int rowStart, int colStart
int ret;

global uint* A_ui = (global uint*)A;
int offset_ui = rowStart * stride / 2 + colStart / 2;
uint offset_ui = rowStart * stride / 2 + colStart / 2;
ret = intel_sub_group_block_read(A_ui + offset_ui);

return ret;
Expand All @@ -170,7 +172,7 @@ int2 load_a_rowmajor_d16_m2_k16_sg8(global ushort* A, int rowStart, int colStart
int2 ret;

global uint* A_ui = (global uint*)A;
int offset_ui = rowStart * stride / 2 + colStart / 2;
uint offset_ui = rowStart * stride / 2 + colStart / 2;

ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
Expand All @@ -185,7 +187,7 @@ int4 load_a_rowmajor_d16_m4_k16_sg8(global ushort* A, int rowStart, int colStart
int4 ret;

global uint* A_ui = (global uint*)A;
int offset_ui = rowStart * stride / 2 + colStart / 2;
uint offset_ui = rowStart * stride / 2 + colStart / 2;

ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
Expand All @@ -202,7 +204,7 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart
int8 ret;

global uint* A_ui = (global uint*)A;
int offset_ui = rowStart * stride / 2 + colStart / 2;
uint offset_ui = rowStart * stride / 2 + colStart / 2;

ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2;
Expand All @@ -224,7 +226,7 @@ int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colSt
uint16 ret;

global uint* A_ui = (global uint*)A;
int offset_ui = rowStart * stride / 2 + colStart / 2;
uint offset_ui = rowStart * stride / 2 + colStart / 2;

ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
Expand All @@ -244,7 +246,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta
{
ushort ret;

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;
ret = intel_sub_group_block_read_us(A + offset);

return as_short(ret);
Expand All @@ -256,7 +258,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt
{
ushort2 ret;

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;
ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride;
ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride;

Expand All @@ -269,7 +271,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt
{
ushort4 ret;

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;
ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride;
ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride;
ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride;
Expand All @@ -284,7 +286,7 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt
{
ushort8 ret;

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;
ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride;
ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride;
ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride;
Expand All @@ -304,15 +306,15 @@ short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int co
{
ushort16 ret;

int offset = rowStart * stride + colStart;
ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride / 2;
uint offset = rowStart * stride + colStart;
ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride;
ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride;

return as_short16(ret);
}
Expand All @@ -324,7 +326,7 @@ int8 load_b_rowmajor_d16_k16_nx(global ushort* B, int rowStart, int colStart, in
{
int8 ret;

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;

ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride;
ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride;
Expand Down Expand Up @@ -363,7 +365,7 @@ int8 load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int st
int8 ret;

global uint* B_ui = (global uint*)B;
int offset_ui = rowStart / 2 * stride + colStart;
uint offset_ui = rowStart / 2 * stride + colStart;

ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride;
Expand All @@ -382,7 +384,7 @@ void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int col
global uint* C_ui = (global uint*)C;
uint v_ui = as_uint(v);

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;

intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride;
}
Expand All @@ -392,7 +394,7 @@ void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int co
global uint* C_ui = (global uint*)C;
uint2 v_ui = as_uint2(v);

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;

intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
Expand All @@ -403,7 +405,7 @@ void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int co
global uint* C_ui = (global uint*)C;
uint4 v_ui = as_uint4(v);

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;

intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
Expand All @@ -416,7 +418,7 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co
global uint* C_ui = (global uint*)C;
uint8 v_ui = as_uint8(v);

int offset = rowStart * stride + colStart;
uint offset = rowStart * stride + colStart;

intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
Expand Down
72 changes: 56 additions & 16 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#if HAS_SIMD8

__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 8;
Expand All @@ -55,9 +55,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl

for (int k = 0; k < K; k += tK * KK) {
int8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
}
}
}

Expand Down Expand Up @@ -90,7 +100,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
}

__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 8;
Expand All @@ -109,9 +119,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*

for (int k = 0; k < K; k += tK * KK) {
int8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
}
}
}

Expand Down Expand Up @@ -146,7 +166,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
#endif // HAS_SIMD8

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 16;
Expand All @@ -165,9 +185,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f

for (int k = 0; k < K; k += tK * KK) {
short8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
}
}
}

Expand Down Expand Up @@ -200,7 +230,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 16;
Expand All @@ -219,9 +249,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float

for (int k = 0; k < K; k += tK * KK) {
short8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions samples/99_matrixexperiments/matrix_kernels.cl
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,12 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort*

// Tiled matrix multiplication kernels, generated from a template:

#define MM 1
#define NN 1
#include "matrix_kernel_tiled.cl"
#undef MM
#undef NN

#define MM 2
#define NN 1
#include "matrix_kernel_tiled.cl"
Expand Down

0 comments on commit ab84bbe

Please sign in to comment.