diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 0d78b42..32deef1 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -304,18 +304,21 @@ int main(int argc, char** argv) printf("Running on device: %s\n", device.getInfo().c_str() ); + auto minSubGroupSize = findMinSubGroupSize(device); + + bool has_simd8 = minSubGroupSize == 8; bool emulate_tN8 = true; bool emulate_tN16 = true; if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { - auto minSubGroupSize = findMinSubGroupSize(device); printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); switch(minSubGroupSize) { - case 8: emulate_tN8 = false; break; + case 8: emulate_tN8 = false; break; case 16: emulate_tN16 = false; break; default: break; } } + buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8); buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 7121021..009225a 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -401,6 +401,8 @@ static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; } +#if HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) @@ -473,6 +475,8 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#endif // HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) @@ -545,6 +549,8 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#if HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) @@ -617,6 +623,8 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#endif // HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K)