Skip to content

Commit

Permalink
try a smaller cooperative prefetch for the B matrix for the rowmajor case
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Mar 14, 2024
1 parent afca843 commit 7176d53
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
12 changes: 6 additions & 6 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,12 +881,12 @@ int main(int argc, char** argv)
}

if (mask & 0x400) {
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
//bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
//bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
//bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
//bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
//bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
//bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
}

Expand Down
26 changes: 16 additions & 10 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in
if (KK % 2 == 0 & NN % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn+=2) {
//if (get_sub_group_local_id() == 0) {
// printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK);
//}
int8 tmp[2][2];
intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp);
for (int tnn = 0; tnn < 2; tnn++) {
Expand Down Expand Up @@ -555,11 +558,14 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM

void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n)
{
if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 2) {
const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 2 * 2;
for (int kk = 0; kk < KK; kk+=2) {
intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
}
if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) {
const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y)
const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ...
const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ...
//if (get_sub_group_local_id() == 0) {
// printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK);
//}
intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
} else if (KK % 2 == 0 & NN % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn += 2) {
Expand Down Expand Up @@ -589,11 +595,11 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN

void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n)
{
if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 4) {
const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 4;
for (int kk = 0; kk < KK; kk+=2) {
intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
}
if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) {
const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y)
const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3
const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0
intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
} else if (KK % 2 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int nn = 0; nn < NN; nn++) {
Expand Down

0 comments on commit 7176d53

Please sign in to comment.