Skip to content

Commit

Permalink
add support for wide K block reads
Browse files (browse the repository at this point in the history)
  • Loading branch information
bashbaug committed Jan 24, 2024
1 parent ab84bbe commit 4ae2d95
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number | Diff line number | Diff line change
Expand Up @@ -316,9 +316,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN

for (int k = 0; k < K; k += tK * KK) {
short8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
}
}
}

Expand Down Expand Up @@ -371,9 +381,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl

for (int k = 0; k < K; k += tK * KK) {
short8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
}
}
}

Expand Down Expand Up @@ -406,6 +426,3 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
}

#endif // cl_intel_subgroup_extended_block_read

#undef KK
#undef SGS_PER_WG

0 comments on commit 4ae2d95

Please sign in to comment.