From 2d23f76891e51a98dd64e62078c4613f44fa2693 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 2 Mar 2024 17:33:30 -0800 Subject: [PATCH] add support for 2D block prefetches --- .../99_matrixexperiments/matrix_helpers.cl | 76 +++++++++++++++++-- .../matrix_kernel_tiled.cl | 53 ++++++++++--- 2 files changed, 113 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index ee71681..97b2ce5 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -290,7 +290,7 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(A + offset) % 4 == 0); - prefetch(A + offset, 1); + prefetch(A + offset, 2); #endif // defined(PREFETCH_DEFAULT) } @@ -379,7 +379,7 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(A + offset) % 4 == 0); - prefetch(A + offset, 1); + prefetch(A + offset, 2); #endif // defined(PREFETCH_DEFAULT) } @@ -449,9 +449,9 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int co #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 1); offset += 8 * stride; + prefetch(B + offset, 2); offset += 8 * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 1); offset += 8 * stride; + prefetch(B + offset, 2); offset += 8 * stride; #endif // defined(PREFETCH_DEFAULT) } @@ -461,7 +461,7 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 1); + prefetch(B + offset, 2); #endif // defined(PREFETCH_DEFAULT) } @@ -587,10 +587,21 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // - tile width: subgroup size (16) // - number of tiles: 1 +enum LSC_LDCC { + LSC_LDCC_DEFAULT = 0, + LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached + LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached + LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached + LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached + LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached +}; + typedef ushort __attribute__((ext_vector_type(32))) ushort32; typedef ushort __attribute__((ext_vector_type(64))) ushort64; -// Define block reads and writes. These are supported by the hardware but are not in the headers: +// Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers: ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -605,6 +616,19 @@ ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + + void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); @@ -669,6 +693,46 @@ uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, i return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C + +void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} + + void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 4e7db7b..f1bb189 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -481,6 +481,41 @@ void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, in } } +void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) +{ + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } +} + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) @@ -492,11 +527,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; - // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -510,11 +544,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: // TODO: skip prefetch on the last iterations. - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); @@ -522,6 +553,9 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int8 bData[KK][NN]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { @@ -554,11 +588,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; - // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -572,11 +605,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: // TODO: skip prefetch on the last iterations. - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); @@ -584,6 +614,9 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl int8 bData[KK][NN]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) {