cl_intel_subgroup_matrix_multiply_accumulate update #1181

Merged
72 changes: 64 additions & 8 deletions extensions/cl_intel_subgroup_matrix_multiply_accumulate.asciidoc
@@ -95,6 +95,27 @@ int2 intel_sub_group_u8_u8_matrix_mad_k32(uint2 a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_matrix_mad_k32(uint4 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_matrix_mad_k32(uint8 a, uint8 b, int8 acc);

// 4-bit matrices:
int intel_sub_group_i4_i4_matrix_mad_k64(int a, int8 b, int acc);
int2 intel_sub_group_i4_i4_matrix_mad_k64(int2 a, int8 b, int2 acc);
int4 intel_sub_group_i4_i4_matrix_mad_k64(int4 a, int8 b, int4 acc);
int8 intel_sub_group_i4_i4_matrix_mad_k64(int8 a, int8 b, int8 acc);

int intel_sub_group_i4_u4_matrix_mad_k64(int a, uint8 b, int acc);
int2 intel_sub_group_i4_u4_matrix_mad_k64(int2 a, uint8 b, int2 acc);
int4 intel_sub_group_i4_u4_matrix_mad_k64(int4 a, uint8 b, int4 acc);
int8 intel_sub_group_i4_u4_matrix_mad_k64(int8 a, uint8 b, int8 acc);

int intel_sub_group_u4_i4_matrix_mad_k64(uint a, int8 b, int acc);
int2 intel_sub_group_u4_i4_matrix_mad_k64(uint2 a, int8 b, int2 acc);
int4 intel_sub_group_u4_i4_matrix_mad_k64(uint4 a, int8 b, int4 acc);
int8 intel_sub_group_u4_i4_matrix_mad_k64(uint8 a, int8 b, int8 acc);

int intel_sub_group_u4_u4_matrix_mad_k64(uint a, uint8 b, int acc);
int2 intel_sub_group_u4_u4_matrix_mad_k64(uint2 a, uint8 b, int2 acc);
int4 intel_sub_group_u4_u4_matrix_mad_k64(uint4 a, uint8 b, int4 acc);
int8 intel_sub_group_u4_u4_matrix_mad_k64(uint8 a, uint8 b, int8 acc);

// bfloat16 matrices:
float intel_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc);
float2 intel_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc);
@@ -134,17 +155,50 @@ int2 intel_sub_group_u8_u8_matrix_mad_k32(ushort2 a, uint8 b, int2 acc);
int4 intel_sub_group_u8_u8_matrix_mad_k32(ushort4 a, uint8 b, int4 acc);
int8 intel_sub_group_u8_u8_matrix_mad_k32(ushort8 a, uint8 b, int8 acc);

-// bfloat16 matrices:
// 4-bit matrices:
int intel_sub_group_i4_i4_matrix_mad_k64(short a, int8 b, int acc);
int2 intel_sub_group_i4_i4_matrix_mad_k64(short2 a, int8 b, int2 acc);
int4 intel_sub_group_i4_i4_matrix_mad_k64(short4 a, int8 b, int4 acc);
int8 intel_sub_group_i4_i4_matrix_mad_k64(short8 a, int8 b, int8 acc);

int intel_sub_group_i4_u4_matrix_mad_k64(short a, uint8 b, int acc);
int2 intel_sub_group_i4_u4_matrix_mad_k64(short2 a, uint8 b, int2 acc);
int4 intel_sub_group_i4_u4_matrix_mad_k64(short4 a, uint8 b, int4 acc);
int8 intel_sub_group_i4_u4_matrix_mad_k64(short8 a, uint8 b, int8 acc);

int intel_sub_group_u4_i4_matrix_mad_k64(ushort a, int8 b, int acc);
int2 intel_sub_group_u4_i4_matrix_mad_k64(ushort2 a, int8 b, int2 acc);
int4 intel_sub_group_u4_i4_matrix_mad_k64(ushort4 a, int8 b, int4 acc);
int8 intel_sub_group_u4_i4_matrix_mad_k64(ushort8 a, int8 b, int8 acc);

int intel_sub_group_u4_u4_matrix_mad_k64(ushort a, uint8 b, int acc);
int2 intel_sub_group_u4_u4_matrix_mad_k64(ushort2 a, uint8 b, int2 acc);
int4 intel_sub_group_u4_u4_matrix_mad_k64(ushort4 a, uint8 b, int4 acc);
int8 intel_sub_group_u4_u4_matrix_mad_k64(ushort8 a, uint8 b, int8 acc);

// bfloat16 matrices with float accumulator:
float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc);
float2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc);

-// fp16 matrices:
// fp16 matrices with float accumulator:
float intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, float acc);
float2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, float8 acc);

// bfloat16 with bfloat16 accumulator:
short intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, short acc);
short2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, short2 acc);
short4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, short4 acc);
short8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, short8 acc);

// fp16 matrices with fp16 accumulator:
half intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, half acc);
half2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, half2 acc);
half4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, half4 acc);
half8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, half8 acc);
----
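
The 4-bit overloads listed above appear to pack several elements into each 32-bit vector component. As a rough host-side illustration in plain C (not OpenCL C, and not the extension's definition), the sketch below unpacks 4-bit values and accumulates 64 products into a signed 32-bit integer, mirroring `K` equal to 64. The helper names and the nibble order (element `i` in bits `4*i` to `4*i+3`) are assumptions made for this example; this excerpt does not specify the packing layout.

```c
#include <stdint.h>

/* ASSUMPTION: element i of a packed word occupies bits [4*i, 4*i+3].
   The excerpt does not state the real layout. */

/* Hypothetical helper: extract element i as a signed 4-bit value. */
static int32_t unpack_i4(uint32_t word, int i) {
    int32_t nibble = (int32_t)((word >> (4 * i)) & 0xFu);
    return (nibble ^ 8) - 8; /* sign-extend 4-bit two's complement */
}

/* Hypothetical helper: extract element i as an unsigned 4-bit value. */
static int32_t unpack_u4(uint32_t word, int i) {
    return (int32_t)((word >> (4 * i)) & 0xFu);
}

/* Reference dot product of 64 signed and 64 unsigned 4-bit values
   (8 words of 8 nibbles each), added to a 32-bit accumulator. */
static int32_t dot_i4_u4_k64(const uint32_t a[8], const uint32_t b[8],
                             int32_t acc) {
    for (int w = 0; w < 8; ++w) {
        for (int i = 0; i < 8; ++i) {
            acc += unpack_i4(a[w], i) * unpack_u4(b[w], i);
        }
    }
    return acc;
}
```

This only models the arithmetic of one output element; how the rows of `a` are distributed across the work-items of a subgroup is outside the scope of the sketch.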

== Modifications to the OpenCL C Specification
@@ -213,10 +267,13 @@ For this list of functions:
* `M` may be equal to 1, 2, 4, or 8.
* `N` must be equal to 8 for some devices or 16 for other devices.
In other words, the only supported subgroup sizes are 8 and 16.
-* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers.
-For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers, and `K` must be equal to 32.
* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers, or any combination of signed or unsigned 4-bit integers.
For 8-bit matrices, `K` must be equal to 32. For 4-bit matrices, `K` must be equal to 64.
For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers.
* The supported floating-point matrix types for `a` and `b` are fp16 (half) or bfloat16.
-For these floating-point matrix type, the accumulation value `acc` and result value are 32-bit floating-point values, and `K` must be equal to 16.
For these floating-point matrices, `K` must be equal to 16.
The accumulation value `acc` and result value are 32-bit floating-point values.
For devices with `N` equal to 16, the accumulation value `acc` and result value may also be fp16 for fp16 matrices, or bfloat16 for bfloat16 matrices.
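
One way to read the `K` values in the list above: every overload takes `b` as an 8-component vector of 32-bit integers, so a single `b` operand carries 256 bits, and dividing by the element width gives 32, 64, and 16. The plain C sketch below just checks that arithmetic. The helper name is hypothetical, and the interpretation is an observation about the signatures rather than a statement taken from the specification.

```c
#include <stdint.h>

/* Bits in one work-item's b operand: every overload shown takes an
   8-component vector of 32-bit integers (int8 or uint8). */
enum { B_OPERAND_BITS = 8 * 32 };

/* Hypothetical helper: how many matrix elements of the given bit width
   fit in one b operand. */
static int elements_per_b_operand(int element_bits) {
    return B_OPERAND_BITS / element_bits;
}
```

With element widths of 8, 4, and 16 bits, the helper returns the `K` values that the list requires for 8-bit, 4-bit, and fp16 or bfloat16 matrices.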

== Coding Sample

@@ -288,12 +345,10 @@ int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc)

== Issues

-None.

. Should this extension use signed or unsigned types to represent fp16 and bf16 data?
+
--
-`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions such as cl_intel_bfloat16 conversions.
`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions, such as the `cl_intel_bfloat16_conversions` extension.
This inconsistency may be addressed in a future extension or in a future version of this extension.
Applications are encouraged to use `as_type` to reinterpret unsigned data as signed data as needed to use the functions added by this extension.
--
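
In OpenCL C the reinterpretation mentioned in this resolution would typically be written with a built-in such as `as_short`. As a host-side analogue in plain C, the sketch below copies the bits of an unsigned 16-bit value into a signed one without changing them. The function name is hypothetical and is not part of this extension.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: reinterpret the 16 bits of an unsigned value
   (for example bf16 data held in a ushort) as a signed short.
   memcpy copies the bit pattern unchanged and avoids aliasing issues. */
static int16_t bits_as_short(uint16_t bits) {
    int16_t out;
    memcpy(&out, &bits, sizeof out);
    return out;
}
```

The bit pattern is preserved, so 0x3F80 (bfloat16 1.0) stays 0x3F80, while 0xBF80 (bfloat16 -1.0) reads as a negative integer when viewed as a signed short.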
@@ -306,6 +361,7 @@ Applications are encouraged to use `as_type` to reinterpret unsigned data as sig
|========================================
|Rev|Date|Author|Changes
|1.0.0|2022-05-18|Ben Ashbaugh|*Initial public revision*
|1.0.0|2024-06-06|Ben Ashbaugh|Document additional functions.
|========================================

//************************************************************************