
Tensor Pack and Unpack

lorenzo chelini edited this page Oct 6, 2023 · 10 revisions

Motivation

The objective of tensor.pack and tensor.unpack is to reorder tensor data so that it is laid out in memory in the order it will be accessed, which makes the accesses cache friendly. For example, in a (row-major) matrix multiply, the A matrix is read along rows while the B matrix is read along columns. If we transpose the B matrix, its reads are also along rows, and both access patterns become cache friendly.

These operations pay off for larger matrices. A copy is O(n^2) while a matmul is O(n^3), so for a large enough matrix multiply, the loads and stores saved inside the matmul outweigh the cost of the transpose.

Matrix multiplications are also typically tiled to increase cache reuse. When tiles are small enough, they fit into registers, which allows more efficient kernels where the accumulator stays in registers throughout the reduction dimension of the tiles.

Therefore a "block transpose", where whole tiles are moved as-is, is even more profitable. Not only are the tiles read consecutively, but the packing copy itself also benefits from sequential reads, which an element-wise transpose does not.

Base Semantics

tensor.pack converts a lower-rank tensor into a higher-rank tensor by splitting dimensions according to a tile size. For example, a dimension of 1024 with a tile size of 64 is split into 16 x 64.

tensor.unpack is the inverse operation. In the example above, it takes a tensor with dimensions 16 x 64 and converts it back to 1024.
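The split is plain integer division: element `i` of the original dimension lands at outer index `i // 64` and inner index `i % 64`, and unpack multiplies back. A minimal Python sketch of that index mapping (the helper names are illustrative, not MLIR APIs):

```python
TILE = 64  # tile size from the example above


def pack_index(i, tile=TILE):
    """Map an index of the original dimension to (outer, inner) indices."""
    return divmod(i, tile)  # (i // tile, i % tile)


def unpack_index(outer, inner, tile=TILE):
    """Inverse mapping used by unpack: recover the original index."""
    return outer * tile + inner


# A dimension of 1024 split by 64: outer indices 0..15, inner indices 0..63.
print(pack_index(1023))    # (15, 63)
print(unpack_index(2, 2))  # 130
assert all(unpack_index(*pack_index(i)) == i for i in range(1024))
```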

This allows one to pack a 2D tensor into a 4D tensor with a particular 2D tile size:

  %1 = tensor.pack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %empty : tensor<128x256xf32> -> tensor<4x8x32x32xf32>

Note:

  • inner_tiles is [32, 32], so the two inner dimensions become <...x32x32xf32>
  • inner_dims_pos is [0, 1], so the original dims that are tiled are 0 (128) and 1 (256), in that order.
  • The shape of the packed tensor becomes <(original0 / tile0) x (original1 / tile1) x tile0 x tile1>, i.e. <4x8x32x32xf32>
  • A simple reshape would convert <128x256xf32> / <32x32> into <4x32x8x32xf32>
  • Pack is not just a reshape: the tile0 dimension has also moved past the outer original1 dimension, so that it sits next to tile1.
  • In the end, pack = reshape(shape/<tiles>) + transpose(tile-dims -> inner-dims)
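As a sketch of these semantics, here is a pure-Python pack of a 2D nested list with inner_dims_pos = [0, 1], assuming both dimensions divide evenly by their tile sizes (no padding). Indexing the result as `[i][j][ii][jj]` makes the transpose visible: a plain reshape would index it as `[i][ii][j][jj]`. The function names are hypothetical helpers, not part of MLIR.

```python
def pack_2d(src, tile0, tile1):
    """Pack a row-major 2D list into [rows // tile0][cols // tile1][tile0][tile1].

    Assumes both dimensions divide evenly by their tile sizes (no padding),
    i.e. inner_dims_pos = [0, 1] and inner_tiles = [tile0, tile1].
    """
    rows, cols = len(src), len(src[0])
    assert rows % tile0 == 0 and cols % tile1 == 0
    return [
        [
            [
                [src[i * tile0 + ii][j * tile1 + jj] for jj in range(tile1)]
                for ii in range(tile0)
            ]
            for j in range(cols // tile1)
        ]
        for i in range(rows // tile0)
    ]


def shape(nested):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


# 128x256 matrix whose entries are their own (row, col) coordinates.
a = [[(r, c) for c in range(256)] for r in range(128)]
packed = pack_2d(a, 32, 32)
print(shape(packed))       # [4, 8, 32, 32]
print(packed[1][2][3][4])  # (35, 68): row 1*32+3, column 2*32+4
```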

Arguments

The basic arguments are:

  • Input tensor: The original tensor in the original shape
  • Output tensor: The into tensor (empty or not, for buffer reuse) with the packed shape
  • inner_tiles: The N-dimensional tile size
  • inner_dims_pos: The source dimensions that are divided by the tile sizes, in the order their tile dimensions appear in the result
  • outer_dims_perm: An optional permutation of the outer (tiled) dimensions, applied after the pack reshapes and transposes

Inner Tiles / Dims Pos

inner_dims_pos defines which original dimensions will be divided by the corresponding tile size.

inner_tiles is simply the tile size, in as many dimensions as needed. It must have the same length as inner_dims_pos, and inner_tiles[i] is the tile size for dimension inner_dims_pos[i].

Example:

  for tile = [16, 64] and tensor = < 1024 x 512 x f32 >, pack is:
  ---------------------------------------------------------------
            [ 1024    x 512 ]
  reshape      |  \      |  \
            [ 64 x 16 x  8 x 64 ]
               |    \   /     |
  transpose    |      X       |
               |    /   \     |
            [ 64 x  8 x 16 x 64 ]
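The shape arithmetic of the diagram can be written down directly. A minimal Python sketch (hypothetical helper, no padding and no outer_dims_perm): each tiled dimension is replaced by its tile count, and the tile sizes are appended innermost in inner_dims_pos order.

```python
def packed_shape(src_shape, inner_dims_pos, inner_tiles):
    """Result shape of a pack without outer_dims_perm, assuming every tiled
    dimension divides evenly by its tile size (no padding)."""
    if len(inner_dims_pos) != len(inner_tiles):
        raise ValueError("inner_dims_pos and inner_tiles must have the same length")
    if len(inner_tiles) > len(src_shape):
        raise ValueError("more tile dimensions than source dimensions")
    outer = list(src_shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        if outer[pos] % tile != 0:
            raise ValueError(f"dim {pos} ({outer[pos]}) is not divisible by {tile}")
        outer[pos] //= tile           # "reshape": the dimension becomes the tile count
    return outer + list(inner_tiles)  # "transpose": tile dims move innermost


print(packed_shape([1024, 512], [0, 1], [16, 64]))  # [64, 8, 16, 64]
print(packed_shape([128, 256], [0, 1], [32, 32]))   # [4, 8, 32, 32]
```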

The number of tile dimensions must be less than or equal to the rank of the input tensor. If there are fewer tile dimensions than input dimensions, the dimensions not listed in inner_dims_pos are left untiled.

Example:

  <128x256x512xf32> / <16x8>[1, 2] = <128x16x64x16x8xf32>
    * 128 is left alone because the positions are [1, 2], i.e. 256 and 512.
    * 256 is split into 16x16, with the tile dim (second 16) moved to the inner dimensions
    * 512 is split into 64x8, with the tile dim (8) moved to the innermost dimension
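At the element level, partial tiling only touches the listed dimensions. A small Python sketch of the source-to-packed index mapping for this example (hypothetical helper; no outer_dims_perm, no padding):

```python
def pack_coords(idx, inner_dims_pos, inner_tiles):
    """Map a source index to its packed index (no outer_dims_perm, no padding)."""
    outer = list(idx)
    inner = []
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = idx[pos] // tile  # which tile along this dimension
        inner.append(idx[pos] % tile)  # position inside that tile
    return tuple(outer + inner)


# <128x256x512xf32> with tiles <16x8> on dims [1, 2]: dim 0 is untouched.
print(pack_coords((5, 37, 100), [1, 2], [16, 8]))  # (5, 2, 12, 5, 4)
```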

Outer Permutations

outer_dims_perm defines the order of the outer (tiled) dimensions in the final shape. The inner tile dimensions are not affected: they always stay innermost, in inner_dims_pos order.

The number of entries must be equal to the rank of the source tensor, i.e. the number of outer dimensions of the packed shape.

Example:

  for packed shape <128x16x64x16x8xf32> (outer dims 128x16x64) and outer_dims_perm = [ 0, 2, 1 ]:
  ------------------------------------------------------------------------------------------------
               [ 128 x 16 x 64 x 16 x  8 ]
               [   0    1    2    -    - ]
                   |     \  /     |    |
                   |      \/      |    |
                   |      /\      |    |
                   |     /  \     |    |
               [ 128 x 64 x 16 x 16 x  8 ]
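The following Python sketch computes the packed shape including outer_dims_perm, following the upstream MLIR semantics in which the permutation has one entry per source dimension and reorders only the outer dimensions. The helper name is hypothetical, and padding (padding_value) is not modeled.

```python
def packed_shape_with_perm(src_shape, inner_dims_pos, inner_tiles, outer_dims_perm=None):
    """Result shape of a pack (no padding): outer_dims_perm has one entry per
    source dimension and permutes only the outer dimensions; the tile
    dimensions stay innermost."""
    outer = list(src_shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] //= tile
    if outer_dims_perm is not None:
        if sorted(outer_dims_perm) != list(range(len(src_shape))):
            raise ValueError("outer_dims_perm must be a permutation of the source dims")
        outer = [outer[p] for p in outer_dims_perm]
    return outer + list(inner_tiles)


# <128x256x512xf32>, tiles <16x8> on dims [1, 2], outer_dims_perm = [0, 2, 1].
print(packed_shape_with_perm([128, 256, 512], [1, 2], [16, 8], [0, 2, 1]))
# [128, 64, 16, 16, 8]

# NCHW -> NHWC-style layout: tile C by 8 and move the C tile count after H and W.
print(packed_shape_with_perm([1, 32, 56, 56], [1], [8], [0, 2, 3, 1]))
# [1, 56, 56, 4, 8]
```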

Performance numbers:

tpp-opt commit: 70332c24b427086668a820ed7af985f821ce4b6c

The pack and unpack are for a GEMM of size 512x512x1024 (M, N, K). The pack is for the A operand (512x1024); the pack of B is constant-folded. The unpack is for the output (512x512). Performance is compared with a memcpy of 512x512 f32 values (262144 elements, 1 MiB).

In-house decomposition

~/tpp-sandbox-micro-bench# tpp-opt mlir/pack_gemm_operand_a.mlir -tpp-mapping -bufferize -convert-memref-to-xsmm  
module {
  func.func @pack_gemm_operand_a(%arg0: memref<512x1024xf32>, %arg1: memref<16x32x32x32xf32>) -> memref<16x32x32x32xf32> attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    scf.for %arg2 = %c0 to %c16 step %c1 {
      scf.for %arg3 = %c0 to %c32 step %c1 {
        %0 = arith.muli %arg2, %c32 : index
        %1 = arith.muli %arg3, %c32 : index
        %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<512x1024xf32> to memref<32x32xf32, strided<[1024, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<16x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        %2 = xsmm.unary.dispatch identity [32, 32, 1024, 32] flags = (none) data_type = f32
        xsmm.unary identity(data_type = f32, %2, %subview, %subview_0) : (i64, memref<32x32xf32, strided<[1024, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
      }
    }
    return %arg1 : memref<16x32x32x32xf32>
  }
}
~/tpp-sandbox-micro-bench# tpp-opt mlir/unpack_gemm_operand.mlir -tpp-mapping -bufferize -convert-memref-to-xsmm 
module {
  func.func @unpack_gemm_operand(%arg0: memref<16x16x32x32xf32>, %arg1: memref<512x512xf32>) -> memref<512x512xf32> attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    scf.for %arg2 = %c0 to %c16 step %c1 {
      scf.for %arg3 = %c0 to %c16 step %c1 {
        %0 = arith.muli %arg2, %c32 : index
        %1 = arith.muli %arg3, %c32 : index
        %subview = memref.subview %arg0[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<16x16x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%0, %1] [32, 32] [1, 1] : memref<512x512xf32> to memref<32x32xf32, strided<[512, 1], offset: ?>>
        %2 = xsmm.unary.dispatch identity [32, 32, 32, 512] flags = (none) data_type = f32
        xsmm.unary identity(data_type = f32, %2, %subview, %subview_0) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[512, 1], offset: ?>>) -> ()
      }
    }
    return %arg1 : memref<512x512xf32>
  }
}
CPU Caches:
  L1 Data 32 KiB (x6)
  L1 Instruction 32 KiB (x6)
  L2 Unified 256 KiB (x6)
  L3 Unified 12288 KiB (x1)
Load Average: 0.72, 0.77, 0.55
----------------------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------
BM_memcpy/262144_mean              13967 ns        13965 ns           50 bytes_per_second=69.9351Gi/s
BM_memcpy/262144_median            13912 ns        13910 ns           50 bytes_per_second=70.206Gi/s
BM_memcpy/262144_stddev              151 ns          151 ns           50 bytes_per_second=760.714Mi/s
BM_memcpy/262144_cv                 1.08 %          1.08 %            50 bytes_per_second=1.06%
BM_pack_gemm_operand_a_mean       162372 ns       162358 ns           50 bytes_per_second=12.047Gi/s
BM_pack_gemm_operand_a_median     165627 ns       165600 ns           50 bytes_per_second=11.7942Gi/s
BM_pack_gemm_operand_a_stddev       6157 ns         6155 ns           50 bytes_per_second=475.283Mi/s
BM_pack_gemm_operand_a_cv           3.79 %          3.79 %            50 bytes_per_second=3.85%
BM_unpack_gemm_operand_mean        95195 ns        95188 ns           50 bytes_per_second=10.3174Gi/s
BM_unpack_gemm_operand_median      98148 ns        98139 ns           50 bytes_per_second=9.95084Gi/s
BM_unpack_gemm_operand_stddev       6701 ns         6701 ns           50 bytes_per_second=862.63Mi/s
BM_unpack_gemm_operand_cv           7.04 %          7.04 %            50 bytes_per_second=8.16%
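The loop nest of the in-house pack above can be emulated in plain Python on flat row-major buffers. Each (%arg2, %arg3) iteration copies one 32x32 tile: the source subview starts at (32 * %arg2, 32 * %arg3) with row stride 1024, and the destination subview is a contiguous tile with row stride 32. In this sketch `identity_copy` is a hypothetical stand-in for the XSMM identity kernel, and reading the dispatch operands [32, 32, 1024, 32] as (rows, cols, input leading dimension, output leading dimension) is our inference from the subview strides.

```python
def identity_copy(src, src_off, ld_in, dst, dst_off, ld_out, rows, cols):
    """Copy a rows x cols tile between two strided row-major buffers.

    Hypothetical stand-in for the identity (copy) kernel; ld_in / ld_out are
    the row strides (leading dimensions) of the source and destination views.
    """
    for r in range(rows):
        for c in range(cols):
            dst[dst_off + r * ld_out + c] = src[src_off + r * ld_in + c]


def pack_a(src, m, k, tile=32):
    """Pack a flat row-major m x k buffer into [m/tile][k/tile][tile][tile]."""
    dst = [0.0] * (m * k)
    for i in range(m // tile):      # %arg2 loop
        for j in range(k // tile):  # %arg3 loop
            src_off = (i * tile) * k + j * tile            # subview %arg0[%0, %1]
            dst_off = (i * (k // tile) + j) * tile * tile  # subview %arg1[%arg2, %arg3, 0, 0]
            identity_copy(src, src_off, k, dst, dst_off, tile, tile, tile)
    return dst


M, K, T = 512, 1024, 32
a = [float(x) for x in range(M * K)]
packed = pack_a(a, M, K, T)
# Element (i, j, ii, jj) of the packed buffer equals A[i*32 + ii][j*32 + jj].
i, j, ii, jj = 3, 5, 7, 9
flat = ((i * (K // T) + j) * T + ii) * T + jj
assert packed[flat] == a[(i * T + ii) * K + (j * T + jj)]
```

Each tile copy reads 32 rows of 32 contiguous elements from the source and writes one contiguous 4 KiB block, which is the sequential-write pattern the Motivation section argues for.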

Upstream linalg lowering. NOTE: Linalg is lowered to loops

~/tpp-sandbox-micro-bench# tpp-opt mlir/pack_gemm_operand_a.mlir -tpp-mapping -bufferize -convert-memref-to-xsmm 
module {
  func.func @pack_gemm_operand_a(%arg0: memref<512x1024xf32>, %arg1: memref<16x32x32x32xf32>) attributes {llvm.emit_c_interface} {
    %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<512x1024xf32> into memref<16x32x32x32xf32>
    linalg.transpose ins(%expand_shape : memref<16x32x32x32xf32>) outs(%arg1 : memref<16x32x32x32xf32>) permutation = [0, 2, 1, 3] 
    return
  }
}
module {
  func.func @unpack_gemm_operand(%arg0: memref<16x16x32x32xf32>, %arg1: memref<512x512xf32>) attributes {llvm.emit_c_interface} {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<16x32x16x32xf32>
    linalg.transpose ins(%arg0 : memref<16x16x32x32xf32>) outs(%alloc : memref<16x32x16x32xf32>) permutation = [0, 2, 1, 3] 
    %collapse_shape = memref.collapse_shape %alloc [[0, 1], [2, 3]] : memref<16x32x16x32xf32> into memref<512x512xf32>
    linalg.copy ins(%collapse_shape : memref<512x512xf32>) outs(%arg1 : memref<512x512xf32>)
    memref.dealloc %alloc : memref<16x32x16x32xf32>
    return
  }
}
CPU Caches:
  L1 Data 32 KiB (x6)
  L1 Instruction 32 KiB (x6)
  L2 Unified 256 KiB (x6)
  L3 Unified 12288 KiB (x1)
Load Average: 0.83, 0.65, 0.42
----------------------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------
BM_pack_gemm_operand_a_mean       352527 ns       352477 ns           50 bytes_per_second=5.54246Gi/s
BM_pack_gemm_operand_a_median     354444 ns       354424 ns           50 bytes_per_second=5.5107Gi/s
BM_pack_gemm_operand_a_stddev       5449 ns         5477 ns           50 bytes_per_second=89.1645Mi/s
BM_pack_gemm_operand_a_cv           1.55 %          1.55 %            50 bytes_per_second=1.57%
BM_unpack_gemm_operand_mean       407236 ns       407182 ns           50 bytes_per_second=2.39837Gi/s
BM_unpack_gemm_operand_median     406341 ns       406297 ns           50 bytes_per_second=2.40357Gi/s
BM_unpack_gemm_operand_stddev       1438 ns         1448 ns           50 bytes_per_second=8.70749Mi/s
BM_unpack_gemm_operand_cv           0.35 %          0.36 %            50 bytes_per_second=0.35%
BM_memcpy/262144_mean              14219 ns        14218 ns           50 bytes_per_second=68.7323Gi/s
BM_memcpy/262144_median            14039 ns        14038 ns           50 bytes_per_second=69.5664Gi/s
BM_memcpy/262144_stddev              384 ns          384 ns           50 bytes_per_second=1.81746Gi/s
BM_memcpy/262144_cv                 2.70 %          2.70 %            50 bytes_per_second=2.64%
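As a sanity check on the two tables, the reported bytes_per_second values are consistent with each benchmark moving one f32 buffer of the benchmarked size per iteration: 512x512 (1 MiB) for memcpy and unpack, and 512x1024 (2 MiB) for the pack of A. Those byte counts are our inference, not something the benchmark output states. The sketch below recomputes the throughput from the mean times above and the speedup of the in-house decomposition over the upstream lowering.

```python
GIB = 2**30
MIB = 2**20

# name: (mean time in ns, bytes moved per iteration (assumed), reported GiB/s)
results = {
    "memcpy (in-house run)": (13967, 1 * MIB, 69.9351),
    "pack A, in-house": (162372, 2 * MIB, 12.047),
    "unpack output, in-house": (95195, 1 * MIB, 10.3174),
    "pack A, upstream": (352527, 2 * MIB, 5.54246),
    "unpack output, upstream": (407236, 1 * MIB, 2.39837),
}


def gib_per_s(time_ns, nbytes):
    """Throughput in GiB/s for nbytes moved in time_ns nanoseconds."""
    return nbytes / (time_ns * 1e-9) / GIB


for name, (t_ns, nbytes, reported) in results.items():
    print(f"{name:24s} {gib_per_s(t_ns, nbytes):6.2f} GiB/s (reported {reported} GiB/s)")

pack_speedup = results["pack A, upstream"][0] / results["pack A, in-house"][0]
unpack_speedup = results["unpack output, upstream"][0] / results["unpack output, in-house"][0]
print(f"in-house vs upstream: pack {pack_speedup:.2f}x, unpack {unpack_speedup:.2f}x")
```

Small deviations from the reported values are expected, since Google Benchmark reports the mean of per-repetition rates rather than the rate at the mean time.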

Some thinking while looking at IR (Oct 6 2023)

func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32x32x32xbf16>, tensor<16x32x32x32xbf16>) outs(%arg2 : tensor<16x16x32x32xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %1 = arith.mulf %in, %in_0 : bf16
      %2 = arith.addf %out, %1 : bf16
      linalg.yield %2 : bf16
    } -> tensor<16x16x32x32xbf16>
  %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<16x16x32x32xbf16> -> tensor<512x512xbf16>
  return %unpack : tensor<512x512xbf16>
}
#map = affine_map<(d0) -> (d0 floordiv 32)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
module {
  func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
    // We tile the unpack using the inner tile sizes and fuse the packed matmul inside.
    // The unpack is now replaced by the BRGEMM writing directly into %arg3.
    %0 = scf.forall (%arg4, %arg5) = (0, 0) to (512, 512) step (32, 32) shared_outs(%arg6 = %arg3) -> (tensor<512x512xbf16>) {
      %1 = affine.apply #map(%arg4)
      %2 = affine.apply #map(%arg5)
      %3 = affine.apply #map(%arg4)
      %4 = affine.apply #map(%arg5)
      %extracted_slice = tensor.extract_slice %arg0[%1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_0 = tensor.extract_slice %arg1[%2, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg2[%3, %4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<16x16x32x32xbf16> to tensor<32x32xbf16>
      %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %extracted_slice_0 : tensor<32x32x32xbf16>, tensor<32x32x32xbf16>) outs(%extracted_slice_1 : tensor<32x32xbf16>) {
      ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
        %6 = arith.mulf %in, %in_2 : bf16
        %7 = arith.addf %out, %6 : bf16
        linalg.yield %7 : bf16
      } -> tensor<32x32xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<512x512xbf16>
      }
    }
    return %0 : tensor<512x512xbf16>
  }
}
module {
  func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
    // In the main fusion pass, however, we tile along the outer parallel dimensions of the contraction operations.
    // Can we fuse the unpack into these loops? Yes, see the IR below.
    %0 = scf.forall (%arg4, %arg5) in (16, 16) shared_outs(%arg6 = %arg2) -> (tensor<16x16x32x32xbf16>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg5, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<16x16x32x32xbf16> to tensor<32x32xbf16>
      %1 = linalg.batch_reduce_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32x32xbf16>, tensor<32x32x32xbf16>) outs(%extracted_slice_1 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xbf16> into tensor<16x16x32x32xbf16>
      }
    }
    %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<16x16x32x32xbf16> -> tensor<512x512xbf16>
    return %unpack : tensor<512x512xbf16>
  }
}
module {
  func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
    // Basically, we iterate over the tiled loops and insert each 32x32 BRGEMM tile directly into the unpacked
    // output at offset (32 * %arg4, 32 * %arg5): the tile offsets advance by 32, the element strides stay 1.
    %0 = scf.forall (%arg4, %arg5) in (16, 16) shared_outs(%arg6 = %arg3) -> (tensor<512x512xbf16>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg5, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg2[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<16x16x32x32xbf16> to tensor<32x32xbf16>
      %1 = linalg.batch_reduce_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32x32xbf16>, tensor<32x32x32xbf16>) outs(%extracted_slice_1 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
      %2 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg4)
      %3 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg5)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %arg6[%2, %3] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<512x512xbf16>
      }
    }
    return %0 : tensor<512x512xbf16>
  }
}
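The fused variants above (tiling the unpack, or inserting each BRGEMM tile straight into the unpacked output) should compute exactly what the packed matmul followed by tensor.unpack computes. A pure-Python sketch at toy sizes can check that equivalence. Assumptions: A is packed as [M/T][K/T][T][T] and B as [N/T][K/T][T][T] with each B tile indexed [k][n] (consistent with #map1 and #map2 in the tiled generic), the accumulator starts from zero instead of from %arg2, and all helper names are hypothetical.

```python
import random

T = 4                 # tile size (32 in the IR above; small to keep pure Python fast)
MB, NB, KB = 3, 2, 5  # tile counts along M, N, K (16, 16, 32 in the IR above)


def rand_tiles(outer, kb):
    """Random packed operand of shape [outer][kb][T][T] with small integers."""
    return [[[[random.randint(-3, 3) for _ in range(T)] for _ in range(T)]
             for _ in range(kb)] for _ in range(outer)]


def brgemm_tile(a_slice, b_slice):
    """Batch-reduce GEMM for one output tile: sum over b of a_slice[b] @ b_slice[b]."""
    c = [[0] * T for _ in range(T)]
    for a, b in zip(a_slice, b_slice):
        for m in range(T):
            for n in range(T):
                c[m][n] += sum(a[m][k] * b[k][n] for k in range(T))
    return c


def matmul_then_unpack(a_packed, b_packed):
    """Unfused: packed matmul into C[MB][NB][T][T], then a separate unpack."""
    c_packed = [[brgemm_tile(a_packed[i], b_packed[j]) for j in range(NB)]
                for i in range(MB)]
    out = [[0] * (NB * T) for _ in range(MB * T)]
    for i in range(MB):
        for j in range(NB):
            for ii in range(T):
                for jj in range(T):
                    out[i * T + ii][j * T + jj] = c_packed[i][j][ii][jj]
    return out


def fused(a_packed, b_packed):
    """Fused: each BRGEMM tile is written into the unpacked output at offset
    (i * T, j * T) with unit strides, so no packed C buffer is materialized."""
    out = [[0] * (NB * T) for _ in range(MB * T)]
    for i in range(MB):      # %arg4
        for j in range(NB):  # %arg5
            tile = brgemm_tile(a_packed[i], b_packed[j])
            for ii in range(T):
                out[i * T + ii][j * T:(j + 1) * T] = tile[ii]
    return out


random.seed(0)
a_packed = rand_tiles(MB, KB)  # [MB][KB][T][T], tiles indexed [m][k]
b_packed = rand_tiles(NB, KB)  # [NB][KB][T][T], tiles indexed [k][n]
assert fused(a_packed, b_packed) == matmul_then_unpack(a_packed, b_packed)
```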