From 04a08d4f7b55a6e15ee02cf59aeae89df551c589 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 6 Jun 2024 17:22:34 -0700 Subject: [PATCH] document int4 functions and functions with other return types --- ...bgroup_matrix_multiply_accumulate.asciidoc | 72 ++++++++++++++++--- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/extensions/cl_intel_subgroup_matrix_multiply_accumulate.asciidoc b/extensions/cl_intel_subgroup_matrix_multiply_accumulate.asciidoc index d6f492ba..64d4b348 100644 --- a/extensions/cl_intel_subgroup_matrix_multiply_accumulate.asciidoc +++ b/extensions/cl_intel_subgroup_matrix_multiply_accumulate.asciidoc @@ -95,6 +95,27 @@ int2 intel_sub_group_u8_u8_matrix_mad_k32(uint2 a, uint8 b, int2 acc); int4 intel_sub_group_u8_u8_matrix_mad_k32(uint4 a, uint8 b, int4 acc); int8 intel_sub_group_u8_u8_matrix_mad_k32(uint8 a, uint8 b, int8 acc); +// 4-bit matrices: +int intel_sub_group_i4_i4_matrix_mad_k64(int a, int8 b, int acc); +int2 intel_sub_group_i4_i4_matrix_mad_k64(int2 a, int8 b, int2 acc); +int4 intel_sub_group_i4_i4_matrix_mad_k64(int4 a, int8 b, int4 acc); +int8 intel_sub_group_i4_i4_matrix_mad_k64(int8 a, int8 b, int8 acc); + +int intel_sub_group_i4_u4_matrix_mad_k64(int a, uint8 b, int acc); +int2 intel_sub_group_i4_u4_matrix_mad_k64(int2 a, uint8 b, int2 acc); +int4 intel_sub_group_i4_u4_matrix_mad_k64(int4 a, uint8 b, int4 acc); +int8 intel_sub_group_i4_u4_matrix_mad_k64(int8 a, uint8 b, int8 acc); + +int intel_sub_group_u4_i4_matrix_mad_k64(uint a, int8 b, int acc); +int2 intel_sub_group_u4_i4_matrix_mad_k64(uint2 a, int8 b, int2 acc); +int4 intel_sub_group_u4_i4_matrix_mad_k64(uint4 a, int8 b, int4 acc); +int8 intel_sub_group_u4_i4_matrix_mad_k64(uint8 a, int8 b, int8 acc); + +int intel_sub_group_u4_u4_matrix_mad_k64(uint a, uint8 b, int acc); +int2 intel_sub_group_u4_u4_matrix_mad_k64(uint2 a, uint8 b, int2 acc); +int4 intel_sub_group_u4_u4_matrix_mad_k64(uint4 a, uint8 b, int4 acc); +int8 intel_sub_group_u4_u4_matrix_mad_k64(uint8 a, uint8 b, int8 acc); + // bfloat16 matrices: float intel_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc); float2 intel_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc); @@ -134,17 +155,50 @@ int2 intel_sub_group_u8_u8_matrix_mad_k32(ushort2 a, uint8 b, int2 acc); int4 intel_sub_group_u8_u8_matrix_mad_k32(ushort4 a, uint8 b, int4 acc); int8 intel_sub_group_u8_u8_matrix_mad_k32(ushort8 a, uint8 b, int8 acc); -// bfloat16 matrices: +// 4-bit matrices: +int intel_sub_group_i4_i4_matrix_mad_k64(short a, int8 b, int acc); +int2 intel_sub_group_i4_i4_matrix_mad_k64(short2 a, int8 b, int2 acc); +int4 intel_sub_group_i4_i4_matrix_mad_k64(short4 a, int8 b, int4 acc); +int8 intel_sub_group_i4_i4_matrix_mad_k64(short8 a, int8 b, int8 acc); + +int intel_sub_group_i4_u4_matrix_mad_k64(short a, uint8 b, int acc); +int2 intel_sub_group_i4_u4_matrix_mad_k64(short2 a, uint8 b, int2 acc); +int4 intel_sub_group_i4_u4_matrix_mad_k64(short4 a, uint8 b, int4 acc); +int8 intel_sub_group_i4_u4_matrix_mad_k64(short8 a, uint8 b, int8 acc); + +int intel_sub_group_u4_i4_matrix_mad_k64(ushort a, int8 b, int acc); +int2 intel_sub_group_u4_i4_matrix_mad_k64(ushort2 a, int8 b, int2 acc); +int4 intel_sub_group_u4_i4_matrix_mad_k64(ushort4 a, int8 b, int4 acc); +int8 intel_sub_group_u4_i4_matrix_mad_k64(ushort8 a, int8 b, int8 acc); + +int intel_sub_group_u4_u4_matrix_mad_k64(ushort a, uint8 b, int acc); +int2 intel_sub_group_u4_u4_matrix_mad_k64(ushort2 a, uint8 b, int2 acc); +int4 intel_sub_group_u4_u4_matrix_mad_k64(ushort4 a, uint8 b, int4 acc); +int8 intel_sub_group_u4_u4_matrix_mad_k64(ushort8 a, uint8 b, int8 acc); + +// bfloat16 matrices with float accumulator: float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc); float2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc); float4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc); float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); -// fp16 matrices: +// fp16 matrices with float accumulator: float intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, float acc); float2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, float2 acc); float4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, float4 acc); float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, float8 acc); + +// bfloat16 with bfloat16 accumulator: +short intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, short acc); +short2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, short2 acc); +short4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, short4 acc); +short8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, short8 acc); + +// fp16 matrices with fp16 accumulator: +half intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, half acc); +half2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, half2 acc); +half4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, half4 acc); +half8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, half8 acc); ---- == Modifications to the OpenCL C Specification @@ -213,10 +267,13 @@ For this list of functions: * `M` may be equal to 1, 2, 4, or 8. * `N` must be equal to 8 for some devices or 16 for other devices. In other words, the only supported subgroup sizes are 8 and 16. -* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers. -For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers, and `K` must be equal to 32. +* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers, or any combination of signed or unsigned 4-bit integers. +For 8-bit matrices, `K` must be equal to 32. For 4-bit matrices, `K` must be equal to 64. +For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers. * The supported floating-point matrix types for `a` and `b` are fp16 (half) or bfloat16. -For these floating-point matrix type, the accumulation value `acc` and result value are 32-bit floating-point values, and `K` must be equal to 16. +For these floating-point matrices, `K` must be equal to 16. +The accumulation value `acc` and result value are 32-bit floating-point values. +For devices with `N` equal to 16, the accumulation value `acc` and result value may also be fp16 for fp16 matrices, or bfloat16 for bfloat16 matrices. == Coding Sample @@ -288,12 +345,10 @@ int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc) == Issues -None. - . Should this extension use signed or unsigned types to represent fp16 and bf16 data? + -- -`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions such as cl_intel_bfloat16 conversions. +`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions, such as the `cl_intel_bfloat16_conversions` extension. This inconsistency may be addressed in a future extension or in a future version of this extension. Applications are encouraged to use `as_type` to reinterpret unsigned data as signed data as needed to use the functions added by this extension. -- @@ -306,6 +361,7 @@ Applications are encouraged to use `as_type` to reinterpret unsigned data as sig |======================================== |Rev|Date|Author|Changes |1.0.0|2022-05-18|Ben Ashbaugh|*Initial public revision* +|1.0.0|2024-06-06|Ben Ashbaugh|Document additional functions. |======================================== //************************************************************************