add support for PVC, which does not support SIMD8
bashbaug committed Jan 6, 2024
1 parent 52b9550 commit 074e0a5
Showing 2 changed files with 13 additions and 2 deletions.
7 changes: 5 additions & 2 deletions samples/99_matrixexperiments/main.cpp
@@ -304,18 +304,21 @@ int main(int argc, char** argv)
     printf("Running on device: %s\n",
         device.getInfo<CL_DEVICE_NAME>().c_str() );
 
+    auto minSubGroupSize = findMinSubGroupSize(device);
+
+    bool has_simd8 = minSubGroupSize == 8;
     bool emulate_tN8 = true;
     bool emulate_tN16 = true;
     if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) {
-        auto minSubGroupSize = findMinSubGroupSize(device);
         printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize);
         switch(minSubGroupSize) {
-        case 8: emulate_tN8 = false; break;
+        case 8: emulate_tN8 = false; break;
         case 16: emulate_tN16 = false; break;
         default: break;
         }
     }
 
+    buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8);
     buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8);
     buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16);
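For context, findMinSubGroupSize is referenced above but its body is not part of this diff. A minimal sketch of what such a helper might look like, assuming the device reports its supported sub-group sizes via CL_DEVICE_SUB_GROUP_SIZES_INTEL from the cl_intel_required_subgroup_size extension (the helper actually used by this sample may differ):

// Sketch only: query the sub-group sizes the device supports and return
// the smallest one, or 0 if the query is unavailable.
#include <CL/opencl.hpp>
#include <algorithm>
#include <vector>

#ifndef CL_DEVICE_SUB_GROUP_SIZES_INTEL
#define CL_DEVICE_SUB_GROUP_SIZES_INTEL 0x4108
#endif

static size_t findMinSubGroupSize(cl::Device& device)
{
    size_t sizeInBytes = 0;
    if (clGetDeviceInfo(device(), CL_DEVICE_SUB_GROUP_SIZES_INTEL,
            0, nullptr, &sizeInBytes) != CL_SUCCESS || sizeInBytes == 0) {
        return 0;
    }

    std::vector<size_t> sizes(sizeInBytes / sizeof(size_t));
    clGetDeviceInfo(device(), CL_DEVICE_SUB_GROUP_SIZES_INTEL,
        sizeInBytes, sizes.data(), nullptr);

    // PVC-class devices report a minimum sub-group size of 16;
    // earlier GPUs report 8.
    return *std::min_element(sizes.begin(), sizes.end());
}

On a device whose minimum sub-group size is 16, has_simd8 is false, so the program is built with -DHAS_SIMD8=0 and the SIMD8 kernels guarded below are compiled out.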
8 changes: 8 additions & 0 deletions samples/99_matrixexperiments/matrix_kernels.cl
@@ -401,6 +401,8 @@ static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart,
     intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride;
 }
 
+#if HAS_SIMD8
+
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
 kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K)
@@ -473,6 +475,8 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob
     __store_c_row_major_fp32_m8(C, sum, m, n, N);
 }
 
+#endif // HAS_SIMD8
+
 __attribute__((intel_reqd_sub_group_size(16)))
 __attribute__((reqd_work_group_size(16, 1, 1)))
 kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
@@ -545,6 +549,8 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo
     __store_c_row_major_fp32_m8(C, sum, m, n, N);
 }
 
+#if HAS_SIMD8
+
 __attribute__((intel_reqd_sub_group_size(8)))
 __attribute__((reqd_work_group_size(8, 1, 1)))
 kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K)
@@ -617,6 +623,8 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u
     __store_c_row_major_fp32_m8(C, sum, m, n, N);
 }
 
+#endif // HAS_SIMD8
+
 __attribute__((intel_reqd_sub_group_size(16)))
 __attribute__((reqd_work_group_size(16, 1, 1)))
 kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K)
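Since std::to_string(has_simd8) yields "0" or "1", a device such as PVC builds this program with -DHAS_SIMD8=0 and the guarded SIMD8 kernels above are removed by the preprocessor, so the host must not try to create or launch them. A hedged sketch of the host-side consequence (the sample's actual dispatch code is not part of this diff and may be organized differently):

// Sketch only: kernels inside '#if HAS_SIMD8' exist in the program object
// only when it was built with -DHAS_SIMD8=1, so guard the lookup.
if (has_simd8) {
    cl::Kernel k8{program, "bfloat16_dpas_rowmajor_m1_n8"};   // SIMD8 variant
    // ... enqueue and time the n8 kernels ...
}

cl::Kernel k16{program, "bfloat16_dpas_rowmajor_m1_n16"};     // always present
// ... enqueue and time the n16 kernels ...

Looking up one of the n8 kernels after a -DHAS_SIMD8=0 build would fail with CL_INVALID_KERNEL_NAME, which is exactly what this guard avoids on PVC.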
