Skip to content

Commit

Permalink
Enable Joint Matrix Get Coord for SG=32
Browse files Browse the repository at this point in the history
Re-implement get coord built-ins to make them universal
for different sub group sizes. Support get coord with
SG=32
  • Loading branch information
YuriPlyakhin authored and igcbot committed Aug 31, 2023
1 parent 691e0ef commit 0649bad
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl
Original file line number Diff line number Diff line change
Expand Up @@ -595,12 +595,14 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, ROW_MAJOR, , 1, 16, true)
// DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16, true) same as for subgroup 16

// get_coord()
#define MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, shape) \
__builtin_spirv_OpJointMatrixGetCoordINTEL_##layout##sg##_##shape##_i##elem_bitwidth
/* get_coord() support: */

// Builds the identifier of a get_coord() built-in by token pasting:
//   __builtin_spirv_OpJointMatrixGetCoordINTEL_<layout><sg>_<R>x<C>_i<elem_bitwidth>
// An empty sg argument pastes nothing, so the name goes straight from
// <layout> to "_<R>x<C>".
#define MANGLE_GETCOORD_NAME(mat_layout, sg_suffix, elem_bits, rows, cols) \
  __builtin_spirv_OpJointMatrixGetCoordINTEL_##mat_layout##sg_suffix##_##rows##x##cols##_i##elem_bits

/* Explanation of calculation for int8 and bf16 types
Let's say we are considering a JM of use::A, 8x32, of type i8, on platform PVC
with sub-group size 16.
<--------- 32----------------------------->
0 0 x x x x ..........................x x ^
Expand All @@ -621,12 +623,9 @@ small o item in work_item_0. The index here is 3. (Please note that index is
the argument of get_coord() call. And each WI has index running 0-15 in this
case, as they hold 16 elements (8x2))
So the calculation becomes
int div_factor = 32 / 16 * 1; // --> 2
int row = index / 2; // 1
int col = (index % 2) + (wi_num * 2); // 1
So the calculation becomes:
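row: (wi_id*pack_factor)/K + index/pack_factor*skip_factor --> (0*2)/32 + 3/2*1 = 0 + 1 = 1
col: (wi_id*pack_factor)%K + index%pack_factor --> (0*2)%32 + 3%2 = 0 + 1 = 1
(In this example K = 32 is the number of columns, pack_factor = 2 and skip_factor = 1; all divisions are integer divisions.)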
Now, why the index for this particular item is 3 and not 9? That is because
the slice is stored in row-major fashion. So if we have the slice like
Expand All @@ -642,42 +641,48 @@ the following for a WI:
7 7
The storage in memory will be: 0 0 1 1 2 2 ... 7 7
Please note the index of the starred item is 3, not 9.
*/

#define DEFINE_GET_COORD(layout, sg, elem_bitwidth, M, K, shape, sg_size, VF) \
INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, shape) (int index) { \
int wi_num = get_sub_group_local_id(); \
int div_factor = (K/sg_size)*VF; \
int row = index/div_factor; \
int col = (index%div_factor) + (wi_num*div_factor); \
// Defines the get_coord() built-in for one layout / shape / element width.
// The generated function takes `index`, the position of an element inside the
// calling work-item's slice, and returns its (row, col) coordinate as an int2.
//
// layout, sg       - pasted into the built-in's name by MANGLE_GETCOORD_NAME
// elem_bitwidth    - bit width of one matrix element
// contrib_bitwidth - bit width of the unit that elements are packed into;
//                    contrib_bitwidth / elem_bitwidth elements share one unit
// R - number of rows
// C - number of columns
// VF - VNNI Factor
//
// The sub-group size is queried at run time with get_sub_group_size(), so it is
// not a macro parameter.
// NOTE(review): skip_factor is an integer division and becomes 0 when the
// sub-group size is smaller than sg_cols - confirm no instantiation hits this.
#define DEFINE_GET_COORD(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF) \
  INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
    const int sg_size = get_sub_group_size(); \
    const int wi_id = get_sub_group_local_id(); \
    /* number of elements held in one contrib_bitwidth-wide unit */ \
    const int pack_factor = (contrib_bitwidth) / (elem_bitwidth); \
    /* width of one row, in elements, as used by the index math below */ \
    const int row_width = (C) * (VF); \
    /* work-items needed to cover one row, pack_factor elements each */ \
    const int sg_cols = row_width / pack_factor; \
    /* rows advanced each time index moves on by pack_factor elements */ \
    const int skip_factor = sg_size / sg_cols; \
    const int row = (wi_id * pack_factor) / row_width + (index / pack_factor) * skip_factor; \
    const int col = (wi_id * pack_factor) % row_width + index % pack_factor; \
    return (int2)(row, col); \
  }

// ------ PVC -------
// layout, sg, elem_bitwidth, M, K, shape, sg_size, VF
// ------ PVC -------
// layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF
// NOTE(review): the 8-argument invocations in this section follow the
// (layout, sg, elem_bitwidth, M, K, shape, sg_size, VF) parameter list, while
// the 7-argument ones follow the list above. This looks like both sides of a
// diff rendered together - confirm which lines are live before building.
// int8 elements (elem_bitwidth = 8)
DEFINE_GET_COORD(PackedA, _SG16, 8, 8, 32, 8x32, 16, 1)
DEFINE_GET_COORD(PackedB, _SG16, 8, 32, 16, 32x16, 16, 4)
DEFINE_GET_COORD(Accumulator, _SG16, 32, 8, 16, 8x16, 16, 1)
DEFINE_GET_COORD(PackedA, _SG16, 8, 16, 8, 32, 1)
DEFINE_GET_COORD(PackedB, _SG16, 8, 32, 32, 16, 4)

// bfloat16 elements (elem_bitwidth = 16)
DEFINE_GET_COORD(PackedA, _SG16, 16, 8, 16, 8x16, 16, 1)
DEFINE_GET_COORD(PackedB, _SG16, 16, 16, 16, 32x16, 16, 2)

DEFINE_GET_COORD(PackedA, _SG16, 16, 16, 8, 16, 1)
DEFINE_GET_COORD(PackedB, _SG16, 16, 32, 16, 16, 2)

// Accumulator (elem_bitwidth = 32)
DEFINE_GET_COORD(Accumulator, _SG16, 32, 32, 8, 16, 1)

// --------- XMX8 ------------
// The sg argument is left empty here, so no sub-group suffix is pasted into
// the generated built-in names.
// int8 elements (elem_bitwidth = 8)
DEFINE_GET_COORD(PackedA, , 8, 8, 32, 8x32, 8, 1)
DEFINE_GET_COORD(PackedB, , 8, 32, 8, 32x8, 8, 4)
DEFINE_GET_COORD(Accumulator, , 32, 8, 8, 8x8, 8, 1)
DEFINE_GET_COORD(PackedA, , 8, 32, 8, 32, 1)
DEFINE_GET_COORD(PackedB, , 8, 32, 32, 8, 4)

// bfloat16 elements (elem_bitwidth = 16)
DEFINE_GET_COORD(PackedA, , 16, 8, 16, 8x16, 8, 1)
DEFINE_GET_COORD(PackedB, , 16, 16, 8, 16x8, 8, 2)
DEFINE_GET_COORD(PackedA, , 16, 32, 8, 16, 1)
DEFINE_GET_COORD(PackedB, , 16, 32, 16, 8, 2)

// Accumulator (elem_bitwidth = 32)
DEFINE_GET_COORD(Accumulator, , 32, 32, 8, 8, 1)

/* experimental large slice support: */

Expand Down

0 comments on commit 0649bad

Please sign in to comment.