Skip to content

Commit

Permalink
add support for more block prefetches
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Mar 5, 2024
1 parent 873c6ab commit ef205e3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
15 changes: 15 additions & 0 deletions samples/99_matrixexperiments/matrix_helpers.cl
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,9 @@ void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int
void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);

void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);


void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
Expand Down Expand Up @@ -780,6 +783,18 @@ void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_addres
__builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
#endif // defined(PREFETCH_DEFAULT)
}
void intel_subgroup_block_prefetch_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
#if defined(PREFETCH_DEFAULT)
__builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
#endif // defined(PREFETCH_DEFAULT)
}
void intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
#if defined(PREFETCH_DEFAULT)
__builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
#endif // defined(PREFETCH_DEFAULT)
}


void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
Expand Down
54 changes: 50 additions & 4 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,52 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM
}
}

void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n)
{
if (KK % 2 == 0 & NN % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn += 2) {
intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
}
}
} else if (NN % 2 == 0) {
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
}
}
} else if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_prefetch_u16_m32k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_prefetch_u16_m16k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
}
}
}
}

void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n)
{
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_prefetch_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
}
}
}
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)

Expand All @@ -560,7 +606,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n);
prefetch_k += tK * KK;
}

Expand Down Expand Up @@ -590,7 +636,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
}
}
if (kk == 0) {
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n);
prefetch_k += tK * KK;
}
}
Expand Down Expand Up @@ -621,8 +667,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl

int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n);
prefetch_k += tK * KK;
}

Expand Down Expand Up @@ -652,7 +698,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
}
}
if (kk == 0) {
HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n);
prefetch_k += tK * KK;
}
}
Expand Down

0 comments on commit ef205e3

Please sign in to comment.