Skip to content

Commit

Permalink
add support for 2D block prefetches
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Mar 3, 2024
1 parent cbbdcb6 commit 2d23f76
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 16 deletions.
76 changes: 70 additions & 6 deletions samples/99_matrixexperiments/matrix_helpers.cl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co
#if defined(PREFETCH_DEFAULT)
uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
__builtin_assume((ulong)(A + offset) % 4 == 0);
prefetch(A + offset, 1);
prefetch(A + offset, 2);
#endif // defined(PREFETCH_DEFAULT)
}

Expand Down Expand Up @@ -379,7 +379,7 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int
#if defined(PREFETCH_DEFAULT)
uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
__builtin_assume((ulong)(A + offset) % 4 == 0);
prefetch(A + offset, 1);
prefetch(A + offset, 2);
#endif // defined(PREFETCH_DEFAULT)
}

Expand Down Expand Up @@ -449,9 +449,9 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int co
#if defined(PREFETCH_DEFAULT)
uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
__builtin_assume((ulong)(B + offset) % 4 == 0);
prefetch(B + offset, 1); offset += 8 * stride;
prefetch(B + offset, 2); offset += 8 * stride;
__builtin_assume((ulong)(B + offset) % 4 == 0);
prefetch(B + offset, 1); offset += 8 * stride;
prefetch(B + offset, 2); offset += 8 * stride;
#endif // defined(PREFETCH_DEFAULT)
}

Expand All @@ -461,7 +461,7 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int
#if defined(PREFETCH_DEFAULT)
uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
__builtin_assume((ulong)(B + offset) % 4 == 0);
prefetch(B + offset, 1);
prefetch(B + offset, 2);
#endif // defined(PREFETCH_DEFAULT)
}

Expand Down Expand Up @@ -587,10 +587,21 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co
// - tile width: subgroup size (16)
// - number of tiles: 1

enum LSC_LDCC {
LSC_LDCC_DEFAULT = 0,
LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached
LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached
LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached
LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached
LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached
LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached
LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached
};

typedef ushort __attribute__((ext_vector_type(32))) ushort32;
typedef ushort __attribute__((ext_vector_type(64))) ushort64;

// Define block reads and writes. These are supported by the hardware but are not in the headers:
// Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers:

ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
Expand All @@ -605,6 +616,19 @@ ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int
uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);


void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);

void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);


void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
Expand Down Expand Up @@ -669,6 +693,46 @@ uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, i
return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
}

#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C

void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}
void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord)
{
__builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
}


void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
{
__builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
Expand Down
53 changes: 43 additions & 10 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,41 @@ void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, in
}
}

void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k)
{
if (KK % 2 == 0 & MM % 4 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=4) {
intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
}
}
} else if (KK % 2 == 0 & MM % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=2) {
intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
}
}
} else if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
}
}
} else if (MM % 4 == 0) {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm+=4) {
intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
intel_subgroup_block_prefetch_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
}
}
}
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)

Expand All @@ -492,11 +527,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
const int m = compute_m(SGS_PER_WG, tM, MM);
const int n = get_group_id(0) * tN * NN;

// Initial prefetch:
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;
}

Expand All @@ -510,18 +544,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
split_barrier_arrive();

for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
// TODO: skip prefetch on the last iterations.
HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;

short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);

int8 bData[KK][NN];
HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData);

HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
Expand Down Expand Up @@ -554,11 +588,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
const int m = compute_m(SGS_PER_WG, tM, MM);
const int n = get_group_id(0) * tN * NN;

// Initial prefetch:
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;
}

Expand All @@ -572,18 +605,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
split_barrier_arrive();

for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
// TODO: skip prefetch on the last iterations.
HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;

short8 aData[KK][MM];
HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);

int8 bData[KK][NN];
HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData);

HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k);
prefetch_k += tK * KK;

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
Expand Down

0 comments on commit 2d23f76

Please sign in to comment.