add support for larger work-groups in both dimensions
bashbaug committed Mar 6, 2024
1 parent 2ea4d56 commit f60930a
Showing 2 changed files with 34 additions and 24 deletions.
13 changes: 10 additions & 3 deletions samples/99_matrixexperiments/matrix_helpers.cl
@@ -55,13 +55,20 @@ float8 activation(float8 f)

typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4)));

-inline int compute_m(const int num_sgs, const int tM, const int MM)
+inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM)
{
-const int m_start = get_group_id(1) * num_sgs;
-const int m_index = num_sgs > 1 ? m_start + get_sub_group_id() : m_start;
+const int m_start = get_group_id(1) * num_sgs_y;
+const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start;
return m_index * tM * MM;
}

+inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN)
+{
+const int n_start = get_group_id(0) * num_sgs_x;
+const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start;
+return n_index * tN * NN;
+}
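
The decomposition above arranges the sub-groups of a work-group row-major: sub-group s handles tile column s % num_sgs_x and tile row s / num_sgs_x. As a hypothetical aside (not part of the sample), the stand-alone C model below replays the same index math on the host for an assumed configuration of SGS_PER_WG_X = 2, SGS_PER_WG_Y = 4, tM = 8, tN = 16, and MM = NN = 2.

/*
 * Hypothetical stand-alone check, not part of the sample: replays the
 * compute_m()/compute_n() index math on the host.  Assumes SGS_PER_WG_X = 2,
 * SGS_PER_WG_Y = 4 (eight sub-groups per work-group), tM = 8, tN = 16, and
 * MM = NN = 2, and prints the C tile origin assigned to each sub-group of
 * work-group (0, 0).
 */
#include <stdio.h>

static int model_m(int group_id_y, int sg_id, int num_sgs_x, int num_sgs_y,
                   int tM, int MM)
{
    const int m_start = group_id_y * num_sgs_y;
    const int m_index = num_sgs_y > 1 ? m_start + sg_id / num_sgs_x : m_start;
    return m_index * tM * MM;
}

static int model_n(int group_id_x, int sg_id, int num_sgs_x, int tN, int NN)
{
    const int n_start = group_id_x * num_sgs_x;
    const int n_index = num_sgs_x > 1 ? n_start + sg_id % num_sgs_x : n_start;
    return n_index * tN * NN;
}

int main(void)
{
    const int num_sgs_x = 2, num_sgs_y = 4;
    for (int sg = 0; sg < num_sgs_x * num_sgs_y; sg++) {
        /* Sub-group sg sits at column sg % num_sgs_x, row sg / num_sgs_x. */
        printf("sub-group %d -> m = %3d, n = %3d\n", sg,
               model_m(0, sg, num_sgs_x, num_sgs_y, 8, 2),
               model_n(0, sg, num_sgs_x, 16, 2));
    }
    return 0;
}

For work-group (0, 0) this prints row origins 0, 16, 32, 48 and column origins 0 and 32, i.e. a 2-wide by 4-tall arrangement of sub-group tiles.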

// Emulated SIMD8 dpas:
__attribute__((overloadable))
float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc)
45 changes: 24 additions & 21 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
@@ -31,9 +31,12 @@
#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN
#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN)

-#if !defined(SGS_PER_WG)
-// Launch four subgroups per work-group, to maximize cache reuse.
-#define SGS_PER_WG 4
+#if !defined(SGS_PER_WG_X)
+#define SGS_PER_WG_X 1
+#endif
+
+#if !defined(SGS_PER_WG_Y)
+#define SGS_PER_WG_Y 4
#endif
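
A note on launching: SGS_PER_WG_X and SGS_PER_WG_Y are compile-time values, so the host has to keep three things in agreement: the -D build options, the reqd_work_group_size attributes below, and the local work size passed at enqueue time. The helper below is a hypothetical host-side sketch, not the sample's actual host code; the function name, parameters, and the SIMD16 / tM = 8 / tN = 16 configuration are assumptions, and the global size follows from N = get_global_size(0) * NN and M = get_global_size(1) * tM * MM as used in the kernels below.

/* Hypothetical host-side helper, not taken from the sample: builds the
 * program with matching SGS_PER_WG_X / SGS_PER_WG_Y options and enqueues a
 * SIMD16 tiled kernel with the corresponding work-group shape.  MM and NN
 * are assumed to be baked into the program (or left at their defaults). */
#include <CL/cl.h>

cl_int launch_tiled_kernel(cl_program program, cl_device_id device,
                           cl_command_queue queue, const char* kernel_name,
                           cl_mem C, cl_mem A, cl_mem B, cl_int K,
                           size_t M, size_t N, size_t MM, size_t NN)
{
    /* The literals here must match the local size computed below. */
    cl_int err = clBuildProgram(program, 1, &device,
        "-DSGS_PER_WG_X=2 -DSGS_PER_WG_Y=4", NULL, NULL);
    if (err != CL_SUCCESS) return err;

    cl_kernel kernel = clCreateKernel(program, kernel_name, &err);
    if (err != CL_SUCCESS) return err;

    clSetKernelArg(kernel, 0, sizeof(C), &C);
    clSetKernelArg(kernel, 1, sizeof(A), &A);
    clSetKernelArg(kernel, 2, sizeof(B), &B);
    clSetKernelArg(kernel, 3, sizeof(K), &K);

    /* Must match reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1);
     * each work-item column covers NN output columns and each sub-group row
     * covers tM * MM output rows, so the global size is N / NN by
     * M / (tM * MM). */
    const size_t tM = 8;
    const size_t global[2] = { N / NN, M / (tM * MM) };
    const size_t local[2]  = { 16 * 2, 4 };  /* (16 * SGS_PER_WG_X, SGS_PER_WG_Y) */

    err = clEnqueueNDRangeKernel(queue, kernel, 2, NULL,
                                 global, local, 0, NULL, NULL);
    clReleaseKernel(kernel);
    return err;
}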

#if !defined(PREFETCH_DISTANCE)
@@ -106,14 +109,14 @@ void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int
}
}

-__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
+__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 8;
const int N = get_global_size(0) * NN;
-const int m = compute_m(SGS_PER_WG, tM, MM);
-const int n = get_group_id(0) * tN * NN;
+const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);

// Initial prefetch:
int prefetch_k = 0;
@@ -167,14 +170,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
}
}

-__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
+__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 8;
const int N = get_global_size(0) * NN;
-const int m = compute_m(SGS_PER_WG, tM, MM);
-const int n = get_group_id(0) * tN * NN;
+const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);

// Initial prefetch:
int prefetch_k = 0;
@@ -276,14 +279,14 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i
}
}

-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 16;
const int N = get_global_size(0) * NN;
-const int m = compute_m(SGS_PER_WG, tM, MM);
-const int n = get_group_id(0) * tN * NN;
+const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);

// Initial prefetch:
int prefetch_k = 0;
@@ -337,14 +340,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
}
}

-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
{
const int tM = 8;
const int tN = 16;
const int N = get_global_size(0) * NN;
-const int m = compute_m(SGS_PER_WG, tM, MM);
-const int n = get_group_id(0) * tN * NN;
+const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);

// Initial prefetch:
int prefetch_k = 0;
@@ -592,16 +595,16 @@ void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, in
}
}

-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)

{
const int tM = 8;
const int tN = 16;
const int M = get_global_size(1) * tM * MM;
const int N = get_global_size(0) * NN;
-const int m = compute_m(SGS_PER_WG, tM, MM);
-const int n = get_group_id(0) * tN * NN;
+const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);

int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
@@ -652,15 +655,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
}
}

-__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
{
const int tM = 8;
const int tN = 16;
const int M = get_global_size(1) * tM * MM;
const int N = get_global_size(0) * NN;
-const int m = compute_m(SGS_PER_WG, tM, MM);
-const int n = get_group_id(0) * tN * NN;
+const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);

int prefetch_k = 0;
for (int p = 0; p < PREFETCH_DISTANCE; p++) {
