Skip to content

Commit

Permalink
use more helper functions for DG2 tiled kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Mar 2, 2024
1 parent b034c1d commit 697754d
Showing 1 changed file with 49 additions and 98 deletions.
147 changes: 49 additions & 98 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@
#define PREFETCH_DISTANCE 1
#endif

void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN])
{
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
}
}
}

void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN])
{
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
}
}
}

#if HAS_SIMD8

void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
Expand Down Expand Up @@ -69,6 +87,25 @@ void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int
}
}

void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM])
{
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
}
}
}
}

__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
Expand All @@ -81,16 +118,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
// Initial prefetch:
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=4) {
prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;
}

Expand All @@ -106,41 +135,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
// TODO: skip prefetch on the last iterations.
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=4) {
prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;

int8 aData[KK][MM];
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
}
}
}
HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData);

int8 bData[KK][NN];
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
}
}
HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData);

for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
Expand Down Expand Up @@ -176,16 +179,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
// Initial prefetch:
int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;
}

Expand All @@ -201,41 +196,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
for (int k = 0; k < K; k += tK * KK) {
// Next prefetch:
// TODO: skip prefetch on the last iterations.
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
}
}
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn+=2) {
prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
}
}
HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n);
prefetch_k += tK * KK;

int8 aData[KK][MM];
if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm++) {
int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K);
aData[kk + 0][mm] = aTemp.lo;
aData[kk + 1][mm] = aTemp.hi;
}
}
} else {
for (int kk = 0; kk < KK; kk++) {
for (int mm = 0; mm < MM; mm++) {
aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
}
}
}
HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData);

int8 bData[KK][NN];
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
}
}
HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData);

for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
Expand Down Expand Up @@ -307,24 +276,6 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i
}
}

void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN])
{
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
}
}
}

void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN])
{
for (int kk = 0; kk < KK; kk++) {
for (int nn = 0; nn < NN; nn++) {
bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
}
}
}

__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
Expand Down

0 comments on commit 697754d

Please sign in to comment.