Commit

add tiled block read kernels for PVC
bashbaug committed Jan 15, 2024
1 parent 4e89026 commit c7edcd6
Showing 1 changed file with 86 additions and 0 deletions.
samples/99_matrixexperiments/matrix_kernel_tiled.cl
@@ -180,3 +180,89 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
        }
    }
}

#ifdef cl_intel_subgroup_extended_block_read

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
{
    const int tM = 8;
    const int tN = 16;
    const int N = get_global_size(0) * NN;
    const int m = get_group_id(1) * tM * MM;
    const int n = get_group_id(0) * tN * NN;

    // Zero-initialize the MM x NN grid of accumulator tiles.
    float8 sum[MM][NN];
    for (int mm = 0; mm < MM; mm++) {
        for (int nn = 0; nn < NN; nn++) {
            sum[mm][nn] = 0;
        }
    }

    for (int k = 0; k < K; k += tK) {
        // Load MM tiles of A with 2D block reads.
        short8 aData[MM];
        for (int mm = 0; mm < MM; mm++) {
            aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
        }

        // Load NN tiles of row-major B.
        int8 bData[NN];
        for (int nn = 0; nn < NN; nn++) {
            bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N);
        }

        // Accumulate every output tile with the DPAS matrix multiply helper.
        for (int mm = 0; mm < MM; mm++) {
            for (int nn = 0; nn < NN; nn++) {
                sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
            }
        }
    }

    // Write the MM x NN grid of result tiles with 2D block writes.
    for (int mm = 0; mm < MM; mm++) {
        for (int nn = 0; nn < NN; nn++) {
            intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
        }
    }
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
{
    const int tM = 8;
    const int tN = 16;
    const int N = get_global_size(0) * NN;
    const int m = get_group_id(1) * tM * MM;
    const int n = get_group_id(0) * tN * NN;

    // Zero-initialize the MM x NN grid of accumulator tiles.
    float8 sum[MM][NN];
    for (int mm = 0; mm < MM; mm++) {
        for (int nn = 0; nn < NN; nn++) {
            sum[mm][nn] = 0;
        }
    }

    for (int k = 0; k < K; k += tK) {
        // Load MM tiles of A with 2D block reads.
        short8 aData[MM];
        for (int mm = 0; mm < MM; mm++) {
            aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM)));
        }

        // Load NN tiles of B with the transforming (VNNI) 2D block read.
        int8 bData[NN];
        for (int nn = 0; nn < NN; nn++) {
            bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k)));
        }

        // Accumulate every output tile with the DPAS matrix multiply helper.
        for (int mm = 0; mm < MM; mm++) {
            for (int nn = 0; nn < NN; nn++) {
                sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
            }
        }
    }

    // Write the MM x NN grid of result tiles with 2D block writes.
    for (int mm = 0; mm < MM; mm++) {
        for (int nn = 0; nn < NN; nn++) {
            intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
        }
    }
}

#endif // cl_intel_subgroup_extended_block_read
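
A minimal host-side launch sketch for the kernels above (not part of the commit). It only illustrates the NDRange geometry implied by the kernels' index math: each work-group of 16 work-items computes a (tM*MM) x (tN*NN) tile of C, so the global size is (N/NN, M/(tM*MM)) with a local size of (16, 1). The kernel handle, the MM/NN values, and passing them via -D build options are assumptions; MM_KERNEL_NAME and the actual build flags are defined elsewhere in the sample.

#include <CL/cl.h>

// Sketch only: assumes the program was built with "-DMM=2 -DNN=4" and that
// "kernel" was created from the macro-expanded tiled kernel name (both are
// assumptions, since MM_KERNEL_NAME and the build options live elsewhere).
static void launch_tiled_gemm(cl_command_queue queue, cl_kernel kernel,
                              cl_mem C, cl_mem A, cl_mem B,
                              cl_int M, cl_int N, cl_int K)
{
    const int tM = 8, tN = 16;   // fixed inside the kernel
    const int MM = 2, NN = 4;    // must match the -D build options (assumed)

    clSetKernelArg(kernel, 0, sizeof(cl_mem), &C);
    clSetKernelArg(kernel, 1, sizeof(cl_mem), &A);
    clSetKernelArg(kernel, 2, sizeof(cl_mem), &B);
    clSetKernelArg(kernel, 3, sizeof(cl_int), &K);

    // From the kernel: N == get_global_size(0) * NN and
    // m == get_group_id(1) * tM * MM, with a required work-group size of
    // (16, 1, 1).  M and N are assumed to be multiples of tM*MM and tN*NN.
    size_t global[2] = { (size_t)(N / NN), (size_t)(M / (tM * MM)) };
    size_t local[2]  = { 16, 1 };

    clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global, local,
                           0, NULL, NULL);
}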
