diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index ce36217..6398e12 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -216,8 +216,30 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart return ret; } +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD8 version, where each work-item loads two values. +// The first tile is returned the first components of the return value, the the next tile, etc. +int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int16 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort ret; @@ -229,7 +251,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta } // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort2 ret; @@ -242,7 +264,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt } // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort4 ret; @@ -257,7 +279,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt } // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort8 ret; @@ -275,6 +297,26 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt return as_short8(ret); } +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD16 version, where each work-item loads one value. +// The first tile is returned the first components of the return value, the the next tile, etc. +short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + short16 ret; + + int offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + + return ret; +} + // K rows x N columns: // Each work-item loads K values and converts to VNNI. // Stride is in units of elements.