
Tensor Pack and Unpack

Renato Golin edited this page Jun 22, 2023 · 10 revisions

Motivation

The objective of tensor.pack and tensor.unpack is to reorder tensor data so that it is laid out in memory for cache-friendly access. For example, in a (row-major) matrix multiply, the A matrix is read along rows, while the B matrix is read along columns. If we transpose the B matrix, it is also read along rows, so both reads become cache friendly.

These operations are profitable for larger matrices. A copy is O(n^2) while a matmul is O(n^3), so in a large enough matrix multiply the loads and stores inside the matmul far outnumber those of the transpose, and the one-off cost of the copy is repaid by the faster, cache-friendly accesses.
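A back-of-the-envelope count makes the O(n^2) versus O(n^3) argument concrete. The sketch below is a rough model, not a measurement: it assumes a naive, untiled n x n matmul where every multiply-add loads one element of A and one of B, and a transpose that reads and writes each element once, ignoring caches and registers.

```python
# Rough memory-traffic model (an assumption, not a measurement): no caches,
# no registers, square n x n matrices.

def transpose_memory_ops(n: int) -> int:
    # Each of the n*n elements is read once and written once.
    return 2 * n * n


def naive_matmul_memory_ops(n: int) -> int:
    # n*n outputs, each a dot product of length n: one load from A and one
    # load from B per multiply-add, plus one store per output element.
    return 2 * n ** 3 + n * n


for n in (16, 256, 1024):
    ratio = naive_matmul_memory_ops(n) / transpose_memory_ops(n)
    print(f"n={n}: matmul/transpose memory ops = {ratio:.1f}")
# n=16: matmul/transpose memory ops = 16.5
# n=256: matmul/transpose memory ops = 256.5
# n=1024: matmul/transpose memory ops = 1024.5
```

Under this model the ratio is n + 1/2, so the relative cost of the transpose shrinks linearly as the matrices grow.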

Matrix multiplications are also typically tiled to increase cache reuse. When tiles are small enough to fit into registers, more efficient algorithms become possible, where the accumulators stay in registers throughout the reduction dimension of the tiles.

Therefore a "block-transpose", where whole tiles are moved "as-is", is even more profitable. Not only are the tiles then read consecutively by the matmul, but the packing copy itself also benefits from sequential reads within each tile, which an element-wise transpose does not.

Base Semantics

tensor.pack converts a lower-rank tensor into a higher-rank tensor by splitting each tiled dimension according to its tile size. For example, a dimension of 1024 with a tile size of 64 is split into 16 x 64.

tensor.unpack is the inverse operation. In the example above, it takes a tensor with dimensions 16 x 64 and converts it back to 1024.
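At the index level, splitting a dimension is integer division and remainder, and unpacking recombines the two. A minimal Python sketch, assuming the dimension size is an exact multiple of the tile size; split_index and merge_index are hypothetical helpers for illustration, not MLIR APIs.

```python
# Hypothetical helpers for illustration; not part of MLIR.

def split_index(i: int, tile: int) -> tuple[int, int]:
    # Index i in the original dimension becomes (tile number, offset in tile).
    return i // tile, i % tile


def merge_index(outer: int, inner: int, tile: int) -> int:
    # Inverse of split_index: recover the original index.
    return outer * tile + inner


dim, tile = 1024, 64
assert dim % tile == 0          # this sketch assumes no padding
print(dim // tile, tile)        # 16 64
print(split_index(130, tile))   # (2, 2)
print(all(merge_index(*split_index(i, tile), tile) == i for i in range(dim)))  # True
```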

This allows one to pack a 2D tensor into a 4D tensor with a particular 2D tile size:

  %1 = tensor.pack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %empty : tensor<128x256xf32> -> tensor<4x8x32x32xf32>

Note:

  • The inner_tiles is [32, 32], so the two inner dimensions become <...x32x32xf32>
  • The inner_dims_pos is [0, 1], so the original dims that are tiled are 0 (128) and 1 (256) in that order.
  • The shape of the packed tensor becomes <original0/tile0 x original1/tile1 x tile0 x tile1>, here <4x8x32x32xf32>
  • A simple reshape would convert <128x256xf32> / <32x32> into <4x32x8x32xf32>
  • Pack is not just a reshape: the tile0 dimension has moved past the outer original1 dimension to sit alongside tile1.
  • In the end, pack = reshape(shape/<tiles>) + transpose(tile-dims -> inner-dims)
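The reshape-plus-transpose view can be checked on real data. The sketch below packs a 2D nested list in pure Python so that packed[i][j][a][b] == src[i*t0 + a][j*t1 + b], which is the layout the example above produces for inner_dims_pos = [0, 1]. It assumes both dimensions divide evenly (no padding), uses integers as stand-ins for f32 values, and pack_2d/unpack_2d are hypothetical helpers, not MLIR functions.

```python
# Hypothetical pure-Python helpers; not MLIR APIs. Assumes no padding.

def pack_2d(src, t0, t1):
    rows, cols = len(src), len(src[0])
    assert rows % t0 == 0 and cols % t1 == 0, "sketch assumes no padding"
    # Outer dims are the tile counts; inner dims hold one t0 x t1 tile.
    return [[[[src[i * t0 + a][j * t1 + b] for b in range(t1)]
              for a in range(t0)]
             for j in range(cols // t1)]
            for i in range(rows // t0)]


def unpack_2d(packed):
    n0, n1 = len(packed), len(packed[0])
    t0, t1 = len(packed[0][0]), len(packed[0][0][0])
    return [[packed[r // t0][c // t1][r % t0][c % t1] for c in range(n1 * t1)]
            for r in range(n0 * t0)]


def shape(x):
    # Sizes of a rectangular nested list, outermost first.
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


src = [[r * 256 + c for c in range(256)] for r in range(128)]  # stands in for tensor<128x256xf32>
packed = pack_2d(src, 32, 32)
print(shape(packed))             # [4, 8, 32, 32]
print(unpack_2d(packed) == src)  # True
```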

Arguments

The basic arguments are:

  • Input tensor: The original tensor in the original shape
  • Output tensor: The into tensor (empty or not, for buffer reuse) with the packed shape
  • inner_tiles: The N-dimensional tile size
  • inner_dims_pos: Which original dimensions are divided by the tile sizes, and in which order their tiles appear
  • outer_dims_perm: A permutation to perform after the pack reshapes and transposes

Inner Tiles / Dims Pos

inner_dims_pos defines which original dimensions will be divided by the corresponding tile size.

inner_tiles is simply the tile size, with one entry per tiled dimension. Its entries must be in the same order as the inner_dims_pos argument.

Example:

  for tile = [16, 64] and tensor = < 1024 x 512 x f32 >, pack is:
  ---------------------------------------------------------------
            [ 1024    x 512 ]
  reshape      |  \      |  \
            [ 64 x 16 x  8 x 64 ]
               |    \   /     |
  transpose    |      X       |
               |    /   \     |
            [ 64 x  8 x 16 x 64 ]
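The two steps in the diagram can be derived from the arguments alone. This sketch computes the intermediate (reshaped) shape, the transpose permutation that moves each tile dimension to the end in inner_dims_pos order, and the resulting packed shape. It only illustrates the shape arithmetic described here, assumes exact division, and pack_as_reshape_transpose is a made-up name.

```python
# Hypothetical helper for illustrating shapes only; assumes no padding.

def pack_as_reshape_transpose(shape, inner_dims_pos, inner_tiles):
    # Step 1 (reshape): each tiled dim of size d becomes [d // tile, tile].
    expanded, outer_axes, tile_axis = [], [], {}
    for dim, size in enumerate(shape):
        if dim in inner_dims_pos:
            tile = inner_tiles[inner_dims_pos.index(dim)]
            assert size % tile == 0, "sketch assumes no padding"
            outer_axes.append(len(expanded))
            expanded.append(size // tile)
            tile_axis[dim] = len(expanded)
            expanded.append(tile)
        else:
            outer_axes.append(len(expanded))
            expanded.append(size)
    # Step 2 (transpose): outer axes keep their order, tile axes go last
    # in the order given by inner_dims_pos.
    perm = outer_axes + [tile_axis[d] for d in inner_dims_pos]
    packed = [expanded[p] for p in perm]
    return expanded, perm, packed


print(pack_as_reshape_transpose([1024, 512], [0, 1], [16, 64]))
# ([64, 16, 8, 64], [0, 2, 1, 3], [64, 8, 16, 64])
```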

The number of tile dimensions must be equal to or less than the rank of the input tensor. If there are fewer tile dimensions, not all dimensions of the original tensor are tiled.
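The constraints stated so far can be collected into a small check. This is a sketch under the assumptions of this page (exact division, no padding); the upstream verifier may enforce more or different rules, and check_pack_args is a hypothetical name.

```python
# Hypothetical argument check; the rules mirror this page, not the MLIR verifier.

def check_pack_args(shape, inner_dims_pos, inner_tiles):
    rank = len(shape)
    if len(inner_tiles) != len(inner_dims_pos):
        return "inner_tiles and inner_dims_pos must have the same length"
    if len(inner_dims_pos) > rank:
        return "more tile dimensions than the input rank"
    if len(set(inner_dims_pos)) != len(inner_dims_pos):
        return "inner_dims_pos has duplicate entries"
    if any(not 0 <= d < rank for d in inner_dims_pos):
        return "inner_dims_pos entry out of range"
    for d, t in zip(inner_dims_pos, inner_tiles):
        if shape[d] % t != 0:
            return f"dim {d} ({shape[d]}) is not a multiple of tile {t}"
    return None  # valid under the assumptions of this sketch


print(check_pack_args([1024, 512], [0, 1], [16, 64]))        # None
print(check_pack_args([128, 256, 512], [1, 2], [16, 8]))     # None
print(check_pack_args([1024, 512], [0, 1, 2], [16, 64, 4]))  # more tile dimensions than the input rank
```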

Example:

  <128x256x512xf32> / <16x8>[1, 2] = <128x16x64x16x8xf32>
    * 128 is left alone because the positions are [1, 2], i.e. 256 and 512.
    * 256 is split into 16x16, with the tile dim (second 16) moved to the inner dimensions
    * 512 is split into 64x8, with the tile dim (8) moved as the inner-most dimension
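The same rule can be stated per element: a source index maps to a packed index whose outer part holds the tile numbers and whose inner part holds the offsets inside each tile, ordered by inner_dims_pos. A minimal sketch for the <128x256x512xf32> example, again assuming exact division; packed_index is a hypothetical helper.

```python
# Hypothetical helper mapping a source index to its packed index; no padding.

def packed_index(index, inner_dims_pos, inner_tiles):
    outer = list(index)
    inner = []
    for dim, tile in zip(inner_dims_pos, inner_tiles):
        outer[dim] = index[dim] // tile  # which tile along this dimension
        inner.append(index[dim] % tile)  # offset inside that tile
    return tuple(outer + inner)


# <128x256x512xf32> tiled by <16x8> at positions [1, 2]
print(packed_index((5, 37, 100), [1, 2], [16, 8]))    # (5, 2, 12, 5, 4)
print(packed_index((127, 255, 511), [1, 2], [16, 8]))  # (127, 15, 63, 15, 7)
```

The last element lands on the last position of <128x16x64x16x8xf32>, as expected.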

Outer Permutations

outer_dims_perm defines the order of the packed shape's dimensions in the final shape.

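In the description on this page, the permutation has as many entries as the rank of the packed shape. Note that the upstream MLIR tensor.pack op is more restrictive: its outer_dims_perm has as many entries as the source rank and permutes only the outer (tile-count) dimensions, while the inner tile dimensions stay in place.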

Example:

  for packed shape <128x16x64x16x8xf32> and outer_dims_perm = [ 0, 4, 1, 3, 2 ]:
  ------------------------------------------------------------------------------
               [ 128 x 16 x 64 x 16 x  8 ]
               [   0    1    2    3    4 ]
                   |     \    \   |   /
                   |      \    .\-|--'
                   |      .\--/  '|-.
                   |     /  \     |  \
               [ 128 x  8 x 16 x 16 x 64 ]
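The permutation in the diagram follows the rule result[i] = packed[outer_dims_perm[i]]; undoing it, as unpack must, uses the inverse permutation. A small sketch of both, with hypothetical helper names; it only manipulates lists of dimension sizes.

```python
# Hypothetical helpers; they permute lists of sizes, not tensor data.

def apply_perm(shape, perm):
    # result[i] takes the size found at position perm[i] of the input.
    assert sorted(perm) == list(range(len(shape))), "not a permutation"
    return [shape[p] for p in perm]


def inverse_perm(perm):
    # inv[perm[i]] = i, so applying perm and then inv restores the order.
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


packed = [128, 16, 64, 16, 8]
perm = [0, 4, 1, 3, 2]
permuted = apply_perm(packed, perm)
print(permuted)                                  # [128, 8, 16, 16, 64]
print(apply_perm(permuted, inverse_perm(perm)))  # [128, 16, 64, 16, 8]
```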