From 756d2e9decaf63edff9ed23044d5ea27d60878ba Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 17 Jan 2024 14:35:03 -0800 Subject: [PATCH] switch the tiled dpas order We want to prioritize reuse of the A matrix to make best use of read suppression buffers. --- .../matrix_kernel_tiled.cl | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index a3b5640..87fa658 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -49,8 +49,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -83,8 +83,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); } } @@ -126,8 +126,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } @@ -167,8 +167,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } @@ -219,8 +219,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } @@ -269,8 +269,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2))); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } }