diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index e4ce38d..e8a08ff 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -55,13 +55,20 @@ float8 activation(float8 f) typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); -inline int compute_m(const int num_sgs, const int tM, const int MM) +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) { - const int m_start = get_group_id(1) * num_sgs; - const int m_index = num_sgs > 1 ? m_start + get_sub_group_id() : m_start; + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; return m_index * tM * MM; } +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + // Emulated SIMD8 dpas: __attribute__((overloadable)) float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 3421ee4..beea5ab 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -31,9 +31,12 @@ #define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN #define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) -#if !defined(SGS_PER_WG) -// Launch four subgroups per work-group, to maximize cache reuse. -#define SGS_PER_WG 4 +#if !defined(SGS_PER_WG_X) +#define SGS_PER_WG_X 1 +#endif + +#if !defined(SGS_PER_WG_Y) +#define SGS_PER_WG_Y 4 #endif #if !defined(PREFETCH_DISTANCE) @@ -106,14 +109,14 @@ void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int } } -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -167,14 +170,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -276,14 +279,14 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -337,14 +340,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -592,7 +595,7 @@ void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, in } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { @@ -600,8 +603,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { @@ -652,15 +655,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) {