diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 05c2ace..7874d7b 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -274,28 +274,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - //if (MM % 2 == 0) { - // for (int mm = 0; mm < MM; mm += 2) { - // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - // aData[mm + 0] = aTemp.lo; - // aData[mm + 1] = aTemp.hi; - // } - //} else { + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); } - //} + } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK))); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -331,28 +329,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - //if (MM % 2 == 0) { - // for (int mm = 0; mm < MM; mm += 2) { - // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - // aData[mm + 0] = aTemp.lo; - // aData[mm + 1] = aTemp.hi; - // } - //} else { + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); } - //} + } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2))); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } }