diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 87fa658..58aba82 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -202,17 +202,17 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK) { short8 aData[MM]; - if (MM % 2 == 0) { - for (int mm = 0; mm < MM; mm += 2) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - aData[mm + 0] = aTemp.lo; - aData[mm + 1] = aTemp.hi; - } - } else { + //if (MM % 2 == 0) { + // for (int mm = 0; mm < MM; mm += 2) { + // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + // aData[mm + 0] = aTemp.lo; + // aData[mm + 1] = aTemp.hi; + // } + //} else { for (int mm = 0; mm < MM; mm++) { aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); } - } + //} int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { @@ -252,17 +252,17 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK) { short8 aData[MM]; - if (MM % 2 == 0) { - for (int mm = 0; mm < MM; mm += 2) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - aData[mm + 0] = aTemp.lo; - aData[mm + 1] = aTemp.hi; - } - } else { + //if (MM % 2 == 0) { + // for (int mm = 0; mm < MM; mm += 2) { + // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + // aData[mm + 0] = aTemp.lo; + // aData[mm + 1] = aTemp.hi; + // } + //} else { for (int mm = 0; mm < MM; mm++) { aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); } - } + //} int8 bData[NN]; for (int nn = 0; nn < NN; nn++) {