Skip to content

Commit

Permalink
add support for more K tiles for the blockread kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jan 19, 2024
1 parent 16b7cda commit 83185fd
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -274,28 +274,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
short8 aData[MM];
//if (MM % 2 == 0) {
// for (int mm = 0; mm < MM; mm += 2) {
// short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
// aData[mm + 0] = aTemp.lo;
// aData[mm + 1] = aTemp.hi;
// }
//} else {
for (int k = 0; k < K; k += tK * KK) {
short8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
}
//}
}

int8 bData[NN];
for (int nn = 0; nn < NN; nn++) {
bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k)));
int8 bData[KK][NN];
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)));
}
}

for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
}
}
}

Expand Down Expand Up @@ -331,28 +329,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
short8 aData[MM];
//if (MM % 2 == 0) {
// for (int mm = 0; mm < MM; mm += 2) {
// short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
// aData[mm + 0] = aTemp.lo;
// aData[mm + 1] = aTemp.hi;
// }
//} else {
for (int k = 0; k < K; k += tK * KK) {
short8 aData[KK][MM];
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
}
//}
}

int8 bData[NN];
for (int nn = 0; nn < NN; nn++) {
bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2)));
int8 bData[KK][NN];
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
}
}

for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
}
}
}

Expand Down

0 comments on commit 83185fd

Please sign in to comment.