union amx_reg { // 64 byte vector...
// ...of unsigned integers:
uint8_t u8 [64];
uint16_t u16[32];
uint32_t u32[16];
// ...of signed integers:
int8_t i8 [64];
int16_t i16[32];
int32_t i32[16];
// ...of IEEE 754 floating point:
_Float16 f16[32]; // NB: IEEE half-precision, _not_ BF16
float f32[16];
double f64[ 8];
};
struct amx_state {
amx_reg x[ 8]; // 512 bytes, of which 64 bytes extracted / inserted by operations
amx_reg y[ 8]; // 512 bytes, of which 64 bytes extracted / inserted by operations
amx_reg z[64]; // 64 by 64 matrix of bytes
}; // 5KB total
Each register is 64 bytes, viewed as a vector of u8/u16/u32/i8/i16/i32/f16/f32/f64 elements. The architectural state contains 80 such registers: 8 in the X pool, 8 in the Y pool, and the remaining 64 forming a 64x64 grid of bytes called Z.
The entire X register pool can be concatenated to form a circular buffer of 512 bytes. Most instructions can operate on any contiguous 64 byte range from this circular buffer. The same is true for Y: the entire Y pool can be concatenated to form a circular buffer of 512 bytes, and most instructions can operate on any contiguous 64 byte range from this circular buffer.
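The circular-buffer view can be sketched in plain C. This is only an illustrative model, not how the hardware is accessed: the function name `amx_pool_read_span` and the two constants are made up for this example, and the pool is modelled as a flat array of 512 bytes.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

enum { AMX_POOL_BYTES = 512, AMX_REG_BYTES = 64 };

/* Copy the 64-byte span that starts at byte `offset` of a 512-byte pool
   (X or Y), wrapping around to byte 0 when the span runs past the end. */
void amx_pool_read_span(const uint8_t pool[AMX_POOL_BYTES], size_t offset,
                        uint8_t out[AMX_REG_BYTES]) {
    assert(offset < AMX_POOL_BYTES);
    for (size_t i = 0; i < AMX_REG_BYTES; i++) {
        out[i] = pool[(offset + i) % AMX_POOL_BYTES];
    }
}
```

With an offset of 480, for instance, the span covers pool bytes 480 to 511 followed by bytes 0 to 31.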
Once 64 bytes of X and 64 bytes of Y have been selected, operations between X and Y and Z can be performed. Said operations fall into two main categories:
- Vector: Select one register from Z, and combine X/Y/Z in a standard SIMD manner:
Z[i] += X[i] * Y[i]
- Matrix: Select a number of registers from Z equal to the number of lanes in X and Y, and combine X/Y/Z in an outer-product manner:
Z[j][i] += X[i] * Y[j]
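As a rough scalar model of the two categories, here is a sketch for f32 data (16 lanes per register). The function names and the plain `float` arrays are assumptions of this example; which Z registers are actually selected is ignored here.

```c
#include <assert.h>
#include <stddef.h>

enum { LANES = 16 }; /* f32 lanes in one 64-byte register */

/* Vector mode: a single Z row, combined lane by lane. */
void vector_fma_f32(float z[LANES], const float x[LANES], const float y[LANES]) {
    assert(z && x && y);
    for (size_t i = 0; i < LANES; i++) {
        z[i] += x[i] * y[i];
    }
}

/* Matrix mode: one Z row per lane of Y, combined as an outer product. */
void matrix_fma_f32(float z[LANES][LANES], const float x[LANES], const float y[LANES]) {
    assert(z && x && y);
    for (size_t j = 0; j < LANES; j++) {
        for (size_t i = 0; i < LANES; i++) {
            z[j][i] += x[i] * y[j];
        }
    }
}
```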
Load/store instructions move data between memory and AMX registers.
Computation instructions can be used to synthesise various constants in the AMX registers: `0` is easy, as is floating-point `-0`. The latter can be used with integer shift instructions to synthesise (positive or negative) integer powers of two.
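The reason `-0` is a useful starting point is that its bit pattern has only the sign bit set. The following sketch shows the arithmetic on a single f32 lane in ordinary C; it assumes IEEE 754 `float` and two's complement `int32_t`, the helper names are invented for this example, and it says nothing about which AMX shift instructions would be used.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Bit pattern of the f32 value -0.0: only the sign bit is set. */
uint32_t f32_neg_zero_bits(void) {
    float neg_zero = -0.0f;
    uint32_t bits;
    memcpy(&bits, &neg_zero, sizeof bits);
    return bits;
}

/* Logical right shift by k: the positive power of two 2^(31-k). */
uint32_t pow2_from_neg_zero(unsigned k) {
    assert(k <= 31);
    return f32_neg_zero_bits() >> k;
}

/* Arithmetic right shift by k, written portably: the i32 value -(2^(31-k)). */
int32_t neg_pow2_from_neg_zero(unsigned k) {
    assert(k <= 31);
    uint32_t bits = f32_neg_zero_bits();
    uint32_t shifted = ~(~bits >> k); /* copies the sign bit into the top k bits */
    int32_t out;
    memcpy(&out, &shifted, sizeof out);
    return out;
}
```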
There is no direct movement between AMX registers and A64 general purpose registers or SIMD registers; data has to go via memory.
By default, instructions operate on a 64-byte span from X or Y. Some operations support indexed loads rather than 64-byte span loads. Said loads are parameterised by two things: the element size and the index size. The element size (`ES`) is 8/16/32/64 bits, and the index size (`IS`) is 2/4/5 bits. The element count (`EC`) is then 512 divided by the element size. A regular load would load an `ES * EC` (i.e. 512) bit span from X or Y. An indexed load instead loads an `IS * EC` bit span from X or Y, and then treats every group of `IS` bits as a lane index into a different register with element size `ES`. For example, taking `ES` of 16 for f16 data and `IS` of 2, a 64-bit span is loaded from X or Y, which can be viewed as a `u2[32]` vector, and is expanded to form an `f16[32]` vector by looking up into lanes 0/1/2/3 of some other `f16[32]` vector.
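The expansion step of an indexed load can be sketched for `ES` of 16 as follows. The function names are hypothetical, f16 values are carried as raw `uint16_t` bit patterns, and the packing order of the indices (least significant bits of the first byte hold lane 0) is an assumption of this sketch rather than something stated above.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

enum { LANES_16 = 32 }; /* EC for ES = 16: 512 / 16 */

/* Read the `is_bits`-wide index for `lane` from a packed bit span.
   Assumed packing: bit 0 of byte 0 is the lowest bit of lane 0. */
unsigned read_packed_index(const uint8_t *span, unsigned is_bits, size_t lane) {
    unsigned index = 0;
    for (unsigned b = 0; b < is_bits; b++) {
        size_t bit = lane * is_bits + b;
        index |= (unsigned)((span[bit / 8] >> (bit % 8)) & 1u) << b;
    }
    return index;
}

/* Expand IS * EC bits of indices into EC 16-bit elements by looking each
   index up in `table`, which stands in for the other 32-lane register. */
void indexed_load_16(const uint8_t *span, unsigned is_bits,
                     const uint16_t table[LANES_16], uint16_t out[LANES_16]) {
    assert(is_bits == 2 || is_bits == 4 || is_bits == 5);
    for (size_t lane = 0; lane < LANES_16; lane++) {
        out[lane] = table[read_packed_index(span, is_bits, lane)];
    }
}
```

With `IS` of 2 the span is 8 bytes long and only lanes 0 to 3 of `table` are ever referenced, matching the example in the text.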
Once a 64 byte X (or Y) vector has been obtained (either by a regular load or an indexed load), some instructions support shuffling the 64 bytes before use.
For vectors of 8 elements (i.e. `f64[8]`), the four (albeit only three distinct) available shuffles are:
Shuffle | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
---|---|---|---|---|---|---|---|---|
S0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
S1 | 0 | 4 | 1 | 5 | 2 | 6 | 3 | 7 |
S2 | 0 | 2 | 4 | 6 | 1 | 3 | 5 | 7 |
S3 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
For vectors of 16 elements (i.e. `f32[16]` or `i32[16]` or `u32[16]`), the four available shuffles are:
Shuffle | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
S0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
S1 | 0 | 8 | 1 | 9 | 2 | 10 | 3 | 11 | 4 | 12 | 5 | 13 | 6 | 14 | 7 | 15 |
S2 | 0 | 4 | 8 | 12 | 1 | 5 | 9 | 13 | 2 | 6 | 10 | 14 | 3 | 7 | 11 | 15 |
S3 | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 |
For vectors of 32 elements (i.e. `f16[32]` or `i16[32]` or `u16[32]`), the four available shuffles are:
Shuffle | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
S0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
S1 | 0 | 16 | 1 | 17 | 2 | 18 | 3 | 19 | 4 | 20 | 5 | 21 | 6 | 22 | 7 | 23 | 8 | 24 | 9 | 25 | 10 | 26 | 11 | 27 | 12 | 28 | 13 | 29 | 14 | 30 | 15 | 31 |
S2 | 0 | 8 | 16 | 24 | 1 | 9 | 17 | 25 | 2 | 10 | 18 | 26 | 3 | 11 | 19 | 27 | 4 | 12 | 20 | 28 | 5 | 13 | 21 | 29 | 6 | 14 | 22 | 30 | 7 | 15 | 23 | 31 |
S3 | 0 | 4 | 8 | 12 | 16 | 20 | 24 | 28 | 1 | 5 | 9 | 13 | 17 | 21 | 25 | 29 | 2 | 6 | 10 | 14 | 18 | 22 | 26 | 30 | 3 | 7 | 11 | 15 | 19 | 23 | 27 | 31 |
For vectors of 64 elements (i.e. `i8[64]` or `u8[64]`), the four available shuffles are:
Shuffle | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
S0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
S1 | 0 | 32 | 1 | 33 | 2 | 34 | 3 | 35 | 4 | 36 | 5 | 37 | 6 | 38 | 7 | 39 | 8 | 40 | 9 | 41 | 10 | 42 | 11 | 43 | 12 | 44 | 13 | 45 | 14 | 46 | 15 | 47 | 16 | 48 | 17 | 49 | 18 | 50 | 19 | 51 | 20 | 52 | 21 | 53 | 22 | 54 | 23 | 55 | 24 | 56 | 25 | 57 | 26 | 58 | 27 | 59 | 28 | 60 | 29 | 61 | 30 | 62 | 31 | 63 |
S2 | 0 | 16 | 32 | 48 | 1 | 17 | 33 | 49 | 2 | 18 | 34 | 50 | 3 | 19 | 35 | 51 | 4 | 20 | 36 | 52 | 5 | 21 | 37 | 53 | 6 | 22 | 38 | 54 | 7 | 23 | 39 | 55 | 8 | 24 | 40 | 56 | 9 | 25 | 41 | 57 | 10 | 26 | 42 | 58 | 11 | 27 | 43 | 59 | 12 | 28 | 44 | 60 | 13 | 29 | 45 | 61 | 14 | 30 | 46 | 62 | 15 | 31 | 47 | 63 |
S3 | 0 | 8 | 16 | 24 | 32 | 40 | 48 | 56 | 1 | 9 | 17 | 25 | 33 | 41 | 49 | 57 | 2 | 10 | 18 | 26 | 34 | 42 | 50 | 58 | 3 | 11 | 19 | 27 | 35 | 43 | 51 | 59 | 4 | 12 | 20 | 28 | 36 | 44 | 52 | 60 | 5 | 13 | 21 | 29 | 37 | 45 | 53 | 61 | 6 | 14 | 22 | 30 | 38 | 46 | 54 | 62 | 7 | 15 | 23 | 31 | 39 | 47 | 55 | 63 |
In all cases, S0 is the identity, S1 moves lane 1 to lane 2, S2 moves lane 1 to lane 4, and S3 moves lane 1 to lane 8 (for vectors of 8 elements there is no lane 8, so S3 coincides with S0).
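All sixteen rows of the shuffle tables appear to follow one closed form, which the sketch below encodes. The formula is inferred from the tables rather than taken from any specification, and the function name is made up for this example. Each table entry is read as the source lane that feeds the output position given in the header row.

```c
#include <assert.h>

/* Source lane feeding output position `i` under shuffle S<k> (k = 0..3) of
   an `n`-lane vector (n = 8, 16, 32, or 64). Inferred from the tables. */
unsigned shuffle_source_lane(unsigned n, unsigned k, unsigned i) {
    assert(n == 8 || n == 16 || n == 32 || n == 64);
    assert(k <= 3 && i < n);
    unsigned groups = 1u << k; /* 1, 2, 4, or 8 */
    return (i % groups) * (n / groups) + (i / groups);
}
```

For `n` of 8 and `k` of 3 the expression reduces to `i`, which is why only three of the four 8-element shuffles are distinct.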
Most instructions support writing to only a subset of the output lanes, leaving the other lanes unchanged. This is controlled by a combination of a mode field and a value field. Said fields typically combine along the lines of:
Mode | Meaning of value (N) |
---|---|
0 | Write to all lanes (`0`), or to odd lanes only (`1`), or to even lanes only (`2`), or to no lanes |
1 | Only write lane #N (or for certain vector operations, write all lanes, but broadcast Y lane #N to all lanes of Y) |
2 | Only write first N lanes, or to all lanes when N is zero |
3 | Only write last N lanes, or to all lanes when N is zero |
4 | Only write first N lanes (no lanes when N is zero) |
5 | Only write last N lanes (no lanes when N is zero) |
6 | Write to no lanes |
7 | Write to no lanes |
Matrix operations have separate write-enable for the X axis and the Y axis, with the enabled Z elements being the outer product of the two write-enables.
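The mode/value combinations can be modelled as a per-lane predicate. This is a sketch of the typical behaviour only, since individual instructions may deviate: the function names are invented, lanes are assumed to be numbered from 0 (so lane 1 is the first odd lane), and the broadcast variant of mode 1 is not modelled.

```c
#include <assert.h>
#include <stdbool.h>

/* Whether `lane` (out of `n` lanes) is written for a given mode and value N. */
bool lane_enabled(unsigned mode, unsigned value, unsigned lane, unsigned n) {
    assert(lane < n);
    switch (mode) {
    case 0:
        if (value == 0) return true;            /* all lanes */
        if (value == 1) return (lane % 2) == 1; /* odd lanes */
        if (value == 2) return (lane % 2) == 0; /* even lanes */
        return false;                           /* no lanes */
    case 1: return lane == value;                   /* only lane #N */
    case 2: return value == 0 || lane < value;      /* first N, or all if N == 0 */
    case 3: return value == 0 || lane + value >= n; /* last N, or all if N == 0 */
    case 4: return lane < value;                    /* first N */
    case 5: return lane + value >= n;               /* last N */
    default: return false;                          /* modes 6 and 7 */
    }
}

/* Matrix operations: Z element (j, i) is written only when the X-axis enable
   holds for lane i and the Y-axis enable holds for lane j. */
bool z_element_enabled(unsigned x_mode, unsigned x_value,
                       unsigned y_mode, unsigned y_value,
                       unsigned i, unsigned j, unsigned n) {
    return lane_enabled(x_mode, x_value, i, n) &&
           lane_enabled(y_mode, y_value, j, n);
}
```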
When the element size is identical between X and Y and Z, indexing is simple. Assume an element size in bits (ES) of 8, 16, 32, or 64 for all three, then X and Y have N elements, where N = 512 / ES. In vector mode, a single Z register also has N elements. In matrix mode, a 2D grid of N² values is used from Z: N distinct registers from Z, each containing N elements. The N distinct registers are equally spaced in the Y dimension, with spacing 64 / N (the user can choose the starting row, subject to 0 ≤ starting row < 64 / N).
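The row arithmetic for the same-size case is small enough to write down directly. The function name `z_row_index` is hypothetical; the computation follows the spacing rule described above.

```c
#include <assert.h>

/* Index of the Z register that holds output row `j` in matrix mode when X, Y
   and Z all use `es_bits`-bit elements. `start` is the chosen starting row. */
unsigned z_row_index(unsigned es_bits, unsigned start, unsigned j) {
    assert(es_bits == 8 || es_bits == 16 || es_bits == 32 || es_bits == 64);
    unsigned n = 512 / es_bits; /* lanes per register */
    unsigned spacing = 64 / n;  /* distance between the Z rows in use */
    assert(start < spacing && j < n);
    return start + j * spacing;
}
```

For f32 (N = 16) the spacing is 4, so starting row 3 uses Z rows 3, 7, 11, and so on up to 63; for 8-bit elements the spacing is 1 and the only valid starting row is 0.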
When the element sizes are mixed (for example f16 × f16 ↦ f32 or i8 × i16 ↦ i32), then things are more complex. Either more Z registers need to be used (to make space for all the outputs), or some lanes from X and/or Y need to be dropped (because otherwise there is not space for all the outputs), or a combination of both. When lanes are dropped, it is typical to keep just the even lanes, or keep just one lane from every four (i.e. keep lanes 0, 4, 8, etc). Shuffles can be used to select different lanes; for example after applying shuffle S1 and then keeping just the even lanes, the result is lanes 0, 1, 2, etc; and after applying shuffle S2 and then keeping just one lane from every four then the result is lanes 0, 1, 2, etc. Alternatively, byte offsets on the input operands can be used to select different lanes: adding a byte offset equal to one lane turns even lanes into odd lanes, and turns lanes 0, 4, 8, etc into 1, 5, 9, etc.
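The interaction between shuffles and lane dropping can be checked with a short calculation. The sketch below reuses the closed form that the shuffle tables appear to follow (an inference from those tables, not a documented formula), and `kept_lane` is a name made up for this example. Byte offsets on the operands are not modelled.

```c
#include <assert.h>

/* Original lane that lands in kept slot `slot` when shuffle S<k> is applied
   to an `n`-lane vector and then one lane in every `stride` is kept. */
unsigned kept_lane(unsigned n, unsigned k, unsigned stride, unsigned slot) {
    unsigned pos = slot * stride; /* shuffled position that survives the drop */
    unsigned groups = 1u << k;
    assert(k <= 3 && pos < n);
    return (pos % groups) * (n / groups) + (pos / groups);
}
```

For a 32-lane vector, `kept_lane(32, 0, 2, slot)` gives lanes 0, 2, 4, etc., whereas `kept_lane(32, 1, 2, slot)` and `kept_lane(32, 2, 4, slot)` both give lanes 0, 1, 2, etc.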
One particularly common mixed-width combination is X and Y having element size of 16 bits (i.e. i16 or u16 or f16) and Z having element size 32 bits (i.e. i32 or u32 or f32). In this case, both X and Y have 32 elements, and every Z register has 16 elements. The complete outer product of X and Y would need 32² Z values, which there is just space for: use all 64 Z registers, with 16 elements in each. Each 4 by 4 block of bytes ends up looking like:
Y bytes | X bytes | Update to Z (row, bytes) |
---|---|---|
Y0:1 | X0:1 | Z0,0:3 += X0:1 × Y0:1 |
Y0:1 | X2:3 | Z1,0:3 += X2:3 × Y0:1 |
Y2:3 | X0:1 | Z2,0:3 += X0:1 × Y2:3 |
Y2:3 | X2:3 | Z3,0:3 += X2:3 × Y2:3 |
An alternative way of viewing this combination is that every pair of Z registers contains 32 lanes (corresponding to the lanes of X), and there are 32 such pairs (corresponding to the lanes of Y), with each pair arranged as:
Z0 | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | 16 | 18 | 20 | 22 | 24 | 26 | 28 | 30 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Z1 | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 31 |
This arrangement is called an interleaved pair of Z registers, and for (16,16,32) has support instructions in the form of `ldzi` and `stzi`.
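The (16,16,32) layout can be summarised as a mapping from an X lane and a Y lane to a Z row and element. The sketch below models it for i16 × i16 ↦ i32; the type and function names are invented for this example, and details such as overflow behaviour are not modelled.

```c
#include <assert.h>
#include <stdint.h>

/* Position in Z of the 32-bit result for X lane i and Y lane j (both 0..31).
   Each Y lane owns an interleaved pair of rows: even X lanes go to the first
   row of the pair, odd X lanes to the second. */
typedef struct { unsigned row; unsigned elem; } z_pos;

z_pos z_pos_16_16_32(unsigned i, unsigned j) {
    assert(i < 32 && j < 32);
    z_pos p = { 2 * j + (i & 1), i >> 1 };
    return p;
}

/* Full outer-product accumulate using that layout: all 64 Z rows, 16
   elements each. */
void matrix_mac_i16_i32(int32_t z[64][16], const int16_t x[32], const int16_t y[32]) {
    for (unsigned j = 0; j < 32; j++) {
        for (unsigned i = 0; i < 32; i++) {
            z_pos p = z_pos_16_16_32(i, j);
            z[p.row][p.elem] += (int32_t)x[i] * (int32_t)y[j];
        }
    }
}
```

X lane 1 with Y lane 0 maps to row 1, element 0, which is the second row of the 4 by 4 byte block shown earlier.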