Skip to content

Commit

Permalink
start to add support for loading two K tiles at once
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jan 22, 2024
1 parent 83185fd commit a24a5b0
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions samples/99_matrixexperiments/matrix_helpers.cl
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,30 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart
return ret;
}

// M rows x K columns x V tiles (in the K dimension)
// This is the SIMD8 version, where each work-item loads two values.
// The first tile is returned in the first components of the return value, then the next tile, etc.
// Stride is in units of 16-bit elements; it is halved below because the data is read as 32-bit words.
int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride)
{
    // The block reads return unsigned vectors, and OpenCL C has no implicit
    // conversions between vector types, so accumulate as unsigned and
    // reinterpret the bits on return.
    uint16 ret;

    global uint* A_ui = (global uint*)A;
    int offset_ui = rowStart * stride / 2 + colStart / 2;

    // One two-wide block read per row: the first result component belongs to
    // the first K tile (components 0-7), the second to the next K tile
    // (components 8-f).
    ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
    ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;

    return as_int16(ret);
}

// M rows x K columns
// This is the SIMD16 version, where each work-item loads one values.
// This is the SIMD16 version, where each work-item loads one value.
short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colStart, int stride)
{
ushort ret;
Expand All @@ -229,7 +251,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta
}

// M rows x K columns
// This is the SIMD16 version, where each work-item loads one values.
// This is the SIMD16 version, where each work-item loads one value.
short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colStart, int stride)
{
ushort2 ret;
Expand All @@ -242,7 +264,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt
}

// M rows x K columns
// This is the SIMD16 version, where each work-item loads one values.
// This is the SIMD16 version, where each work-item loads one value.
short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colStart, int stride)
{
ushort4 ret;
Expand All @@ -257,7 +279,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt
}

// M rows x K columns
// This is the SIMD16 version, where each work-item loads one values.
// This is the SIMD16 version, where each work-item loads one value.
short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colStart, int stride)
{
ushort8 ret;
Expand All @@ -275,6 +297,26 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt
return as_short8(ret);
}

// M rows x K columns x V tiles (in the K dimension)
// This is the SIMD16 version, where each work-item loads one value.
// The first tile is returned in the first components of the return value, then the next tile, etc.
// Stride is in units of 16-bit elements.
short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride)
{
    // The block reads return unsigned vectors, and OpenCL C has no implicit
    // conversions between vector types, so accumulate as unsigned and
    // reinterpret the bits on return.
    ushort16 ret;

    // The offset is in 16-bit elements here (A is not reinterpreted as 32-bit
    // words), so advancing to the next row adds the full stride.
    int offset = rowStart * stride + colStart;

    // One two-wide block read per row: the first result component belongs to
    // the first K tile (components 0-7), the second to the next K tile
    // (components 8-f).
    ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride;
    ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride;

    return as_short16(ret);
}

// K rows x N columns:
// Each work-item loads K values and converts to VNNI.
// Stride is in units of elements.
Expand Down

0 comments on commit a24a5b0

Please sign in to comment.