diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl
index 9bec204..cd04b29 100644
--- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl
+++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl
@@ -180,3 +180,124 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
         }
     }
 }
+
+#ifdef cl_intel_subgroup_extended_block_read
+
+// Tiled matrix multiply, row-major B variant: each work-group (required size
+// 16x1x1, matching the required sub-group size of 16) accumulates an MM x NN
+// grid of tiles of C, one float8 accumulator per tile per work-item, stepping
+// over K in chunks of tK.  A is fetched with 2D block reads, B through the
+// load_b_rowmajor_d16_k16_nx helper, and C is stored with 2D block writes.
+// NOTE(review): MM, NN, tK, MM_KERNEL_NAME, mat_mul_sg16 and
+// load_b_rowmajor_d16_k16_nx are not defined in this hunk; presumably they
+// come from earlier in the file or from the build options -- confirm.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int tM = 8;
+    const int tN = 16;
+    // Matrix extents derived from the launch geometry.  M is needed as the
+    // height argument of the block reads and writes below.
+    // NOTE(review): assumes one work-group per tM * MM rows in dimension 1,
+    // mirroring how m is computed -- confirm against the host launch code.
+    const int M = get_global_size(1) * tM * MM;
+    const int N = get_global_size(0) * NN;
+    // Row (m) and column (n) origin of this work-group's block of C.
+    const int m = get_group_id(1) * tM * MM;
+    const int n = get_group_id(0) * tN * NN;
+
+    float8 sum[MM][NN];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[mm][nn] = 0;
+        }
+    }
+
+    for (int k = 0; k < K; k += tK) {
+        // Block-read arguments appear to be (base, row width in bytes, row count,
+        // row pitch in bytes, (x, y) coordinate) -- confirm against the extension spec.
+        short8 aData[MM];
+        for (int mm = 0; mm < MM; mm++) {
+            aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
+        }
+
+        int8 bData[NN];
+        for (int nn = 0; nn < NN; nn++) {
+            bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N);
+        }
+
+        for (int mm = 0; mm < MM; mm++) {
+            for (int nn = 0; nn < NN; nn++) {
+                sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
+            }
+        }
+    }
+
+    // Write each accumulated tile to its position in C.
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
+        }
+    }
+}
+
+// Tiled matrix multiply that fetches both A and B with 2D block reads (B via
+// the transform variant) and stores C with 2D block writes.  Each work-group
+// (required size 16x1x1, sub-group size 16) accumulates an MM x NN grid of
+// float8 accumulators per work-item, stepping over K in chunks of tK.
+// NOTE(review): the hunk header shows an existing kernel declared with this same
+// MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN) name; if both are
+// compiled this is a redefinition -- confirm the intended name.
+// NOTE(review): B is addressed with a pitch of N ushorts, which looks like
+// row-major input despite the "vnni" name -- confirm the expected layout of B.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
+{
+    const int tM = 8;
+    const int tN = 16;
+    // Matrix extents derived from the launch geometry.  M is needed as the
+    // height argument of the block reads and writes below.
+    // NOTE(review): assumes one work-group per tM * MM rows in dimension 1,
+    // mirroring how m is computed -- confirm against the host launch code.
+    const int M = get_global_size(1) * tM * MM;
+    const int N = get_global_size(0) * NN;
+    // Row (m) and column (n) origin of this work-group's block of C.
+    const int m = get_group_id(1) * tM * MM;
+    const int n = get_group_id(0) * tN * NN;
+
+    float8 sum[MM][NN];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[mm][nn] = 0;
+        }
+    }
+
+    for (int k = 0; k < K; k += tK) {
+        // Block-read arguments appear to be (base, row width in bytes, row count,
+        // row pitch in bytes, (x, y) coordinate) -- confirm against the extension spec.
+        short8 aData[MM];
+        for (int mm = 0; mm < MM; mm++) {
+            aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
+        }
+
+        int8 bData[NN];
+        for (int nn = 0; nn < NN; nn++) {
+            bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k)));
+        }
+
+        for (int mm = 0; mm < MM; mm++) {
+            for (int nn = 0; nn < NN; nn++) {
+                sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
+            }
+        }
+    }
+
+    // Write each accumulated tile to its position in C.
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
+        }
+    }
+}
+
+#endif // cl_intel_subgroup_extended_block_read