Skip to content

Commit

Permalink
document int4 functions and functions with other return types
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jun 7, 2024
1 parent f37a868 commit 04a08d4
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions extensions/cl_intel_subgroup_matrix_multiply_accumulate.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,27 @@ int2 intel_sub_group_u8_u8_matrix_mad_k32(uint2 a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_matrix_mad_k32(uint4 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_matrix_mad_k32(uint8 a, uint8 b, int8 acc);
// 4-bit matrices:
int intel_sub_group_i4_i4_matrix_mad_k64(int a, int8 b, int acc);
int2 intel_sub_group_i4_i4_matrix_mad_k64(int2 a, int8 b, int2 acc);
int4 intel_sub_group_i4_i4_matrix_mad_k64(int4 a, int8 b, int4 acc);
int8 intel_sub_group_i4_i4_matrix_mad_k64(int8 a, int8 b, int8 acc);
int intel_sub_group_i4_u4_matrix_mad_k64(int a, uint8 b, int acc);
int2 intel_sub_group_i4_u4_matrix_mad_k64(int2 a, uint8 b, int2 acc);
int4 intel_sub_group_i4_u4_matrix_mad_k64(int4 a, uint8 b, int4 acc);
int8 intel_sub_group_i4_u4_matrix_mad_k64(int8 a, uint8 b, int8 acc);
int intel_sub_group_u4_i4_matrix_mad_k64(uint a, int8 b, int acc);
int2 intel_sub_group_u4_i4_matrix_mad_k64(uint2 a, int8 b, int2 acc);
int4 intel_sub_group_u4_i4_matrix_mad_k64(uint4 a, int8 b, int4 acc);
int8 intel_sub_group_u4_i4_matrix_mad_k64(uint8 a, int8 b, int8 acc);
int intel_sub_group_u4_u4_matrix_mad_k64(uint a, uint8 b, int acc);
int2 intel_sub_group_u4_u4_matrix_mad_k64(uint2 a, uint8 b, int2 acc);
int4 intel_sub_group_u4_u4_matrix_mad_k64(uint4 a, uint8 b, int4 acc);
int8 intel_sub_group_u4_u4_matrix_mad_k64(uint8 a, uint8 b, int8 acc);
// bfloat16 matrices:
float intel_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc);
float2 intel_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc);
Expand Down Expand Up @@ -134,17 +155,50 @@ int2 intel_sub_group_u8_u8_matrix_mad_k32(ushort2 a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_matrix_mad_k32(ushort4 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_matrix_mad_k32(ushort8 a, uint8 b, int8 acc);
// bfloat16 matrices:
// 4-bit matrices:
int intel_sub_group_i4_i4_matrix_mad_k64(short a, int8 b, int acc);
int2 intel_sub_group_i4_i4_matrix_mad_k64(short2 a, int8 b, int2 acc);
int4 intel_sub_group_i4_i4_matrix_mad_k64(short4 a, int8 b, int4 acc);
int8 intel_sub_group_i4_i4_matrix_mad_k64(short8 a, int8 b, int8 acc);
int intel_sub_group_i4_u4_matrix_mad_k64(short a, uint8 b, int acc);
int2 intel_sub_group_i4_u4_matrix_mad_k64(short2 a, uint8 b, int2 acc);
int4 intel_sub_group_i4_u4_matrix_mad_k64(short4 a, uint8 b, int4 acc);
int8 intel_sub_group_i4_u4_matrix_mad_k64(short8 a, uint8 b, int8 acc);
int intel_sub_group_u4_i4_matrix_mad_k64(ushort a, int8 b, int acc);
int2 intel_sub_group_u4_i4_matrix_mad_k64(ushort2 a, int8 b, int2 acc);
int4 intel_sub_group_u4_i4_matrix_mad_k64(ushort4 a, int8 b, int4 acc);
int8 intel_sub_group_u4_i4_matrix_mad_k64(ushort8 a, int8 b, int8 acc);
int intel_sub_group_u4_u4_matrix_mad_k64(ushort a, uint8 b, int acc);
int2 intel_sub_group_u4_u4_matrix_mad_k64(ushort2 a, uint8 b, int2 acc);
int4 intel_sub_group_u4_u4_matrix_mad_k64(ushort4 a, uint8 b, int4 acc);
int8 intel_sub_group_u4_u4_matrix_mad_k64(ushort8 a, uint8 b, int8 acc);
// bfloat16 matrices with float accumulator:
float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc);
float2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc);
// fp16 matrices:
// fp16 matrices with float accumulator:
float intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, float acc);
float2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, float8 acc);
// bfloat16 matrices with bfloat16 accumulator:
short intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, short acc);
short2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, short2 acc);
short4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, short4 acc);
short8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, short8 acc);
// fp16 matrices with fp16 accumulator:
half intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, half acc);
half2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, half2 acc);
half4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, half4 acc);
half8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, half8 acc);
----

== Modifications to the OpenCL C Specification
Expand Down Expand Up @@ -213,10 +267,13 @@ For this list of functions:
* `M` may be equal to 1, 2, 4, or 8.
* `N` must be equal to 8 for some devices or 16 for other devices.
In other words, the only supported subgroup sizes are 8 and 16.
* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers.
For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers, and `K` must be equal to 32.
* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers, or any combination of signed or unsigned 4-bit integers.
For 8-bit matrices, `K` must be equal to 32. For 4-bit matrices, `K` must be equal to 64.
For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers.
* The supported floating-point matrix types for `a` and `b` are fp16 (half) or bfloat16.
For these floating-point matrix types, the accumulation value `acc` and result value are 32-bit floating-point values, and `K` must be equal to 16.
For these floating-point matrices, `K` must be equal to 16.
The accumulation value `acc` and result value are 32-bit floating-point values.
For devices with `N` equal to 16, the accumulation value `acc` and result value may also be fp16 for fp16 matrices, or bfloat16 for bfloat16 matrices.

== Coding Sample

Expand Down Expand Up @@ -288,12 +345,10 @@ int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc)

== Issues

None.

. Should this extension use signed or unsigned types to represent fp16 and bf16 data?
+
--
`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions such as cl_intel_bfloat16_conversions.
`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions, such as the `cl_intel_bfloat16_conversions` extension.
This inconsistency may be addressed in a future extension or in a future version of this extension.
Applications are encouraged to use `as_type` to reinterpret unsigned data as signed data as needed to use the functions added by this extension.
--
Expand All @@ -306,6 +361,7 @@ Applications are encouraged to use `as_type` to reinterpret unsigned data as sig
|========================================
|Rev|Date|Author|Changes
|1.0.0|2022-05-18|Ben Ashbaugh|*Initial public revision*
|1.0.0|2024-06-06|Ben Ashbaugh|Document additional functions.
|========================================

//************************************************************************
Expand Down

0 comments on commit 04a08d4

Please sign in to comment.