From d76df7e82f8a2cee9878340891b4d0886a4ec9f9 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 17 Jan 2024 10:58:36 -0800 Subject: [PATCH] add support for a larger A matrix block read --- .../99_matrixexperiments/matrix_helpers.cl | 21 +++++++++------- .../matrix_kernel_tiled.cl | 24 +++++++++++++++---- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 04a63c5..5e24f68 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -430,10 +430,11 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // Define additional "non-vector" block read and writes. These are supported by the hardware but are not in the headers: -ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -442,22 +443,26 @@ void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int wid void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); -ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) { diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index a01e0f2..a3b5640 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -202,8 +202,16 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK) { short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + if (MM % 2 == 0) { + for (int mm = 0; mm < MM; mm += 2) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[mm + 0] = aTemp.lo; + aData[mm + 1] = aTemp.hi; + } + } else { + for (int mm = 0; mm < MM; mm++) { + aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + } } int8 bData[NN]; @@ -244,8 +252,16 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK) { short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + if (MM % 2 == 0) { + for (int mm = 0; mm < MM; mm += 2) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[mm + 0] = aTemp.lo; + aData[mm + 1] = aTemp.hi; + } + } else { + for (int mm = 0; mm < MM; mm++) { + aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + } } int8 bData[NN];