diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index b77ede3..e4ad091 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -487,7 +487,8 @@ ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); @@ -519,6 +520,10 @@ uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int { return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 47a5c95..98ddfa9 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -398,9 +398,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[kk + 0][nn] = bTemp.lo; + bData[kk + 1][nn] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } } }