Tensor Pack and Unpack
The objective of `tensor.pack` and `tensor.unpack` is to reorder tensor data so that it is laid out in memory for cache-friendly access.
For example, in a (row-major) matrix multiply, the A matrix is read in rows, while the B matrix is read in columns. If we transpose the B matrix, its reads would also be in rows, and therefore both reads would be cache friendly.
These operations are profitable for larger matrices. Since a copy is O(n^2) while a matmul is O(n^3), a large enough matrix multiply will have more loads and stores inside the matmul than in the transpose, saving time overall.
Matrix multiplications are also typically tiled to increase cache reuse. When the tile sizes are small enough, the tiles fit into registers, allowing more efficient algorithms where the accumulation stays in registers throughout the reduction dimension of the tiles.
Therefore a "block-transpose", where whole tiles are moved "as-is" is even more profitable. Not only the tiles are read consecutively, but the packing copies also benefit from sequential reads, where an element-wise transpose does not.
`tensor.pack` converts a lower-rank tensor into a higher-rank tensor by dividing its dimensions by a given tile size. For example, a dimension of 1024 with a tile size of 64 would be split into 16 x 64.
`tensor.unpack` is the reciprocal operation. In the example above, it takes a tensor with dimensions 16 x 64 and converts it back to 1024.
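As an illustrative sketch of this 1-D case (the value names here are placeholders, not taken from any example on this page):

```mlir
// Split a 1024-element dimension into 16 tiles of 64 elements.
%packed = tensor.pack %src inner_dims_pos = [0] inner_tiles = [64]
    into %pack_dest : tensor<1024xf32> -> tensor<16x64xf32>

// The reciprocal unpack restores the original shape.
%restored = tensor.unpack %packed inner_dims_pos = [0] inner_tiles = [64]
    into %unpack_dest : tensor<16x64xf32> -> tensor<1024xf32>
```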
This allows one to `pack` a 2D tensor into a 4D tensor with a particular 2D tile size:

```mlir
%1 = tensor.pack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %empty : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
```
Note:
- The `inner_tiles` is `[32, 32]`, so the two inner dimensions become `<...x32x32xf32>`.
- The `inner_dims_pos` is `[0, 1]`, so the original dims that are tiled are `0` (128) and `1` (256), in that order.
- The dimensionality of the tensor becomes `<original0 x original1 x tile0 x tile1>`.
- A simple reshape would convert `<128x256xf32> / <32x32>` into `<4x32x8x32xf32>`.
- Pack is not just a reshape, given that the `tile0` dimension has moved over `original1`, alongside `tile1`.
- In the end, `pack = reshape(shape/<tiles>) + transpose(tile-dims -> inner-dims)`, as sketched below.
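As a rough sketch of that equivalence for the pack above, reusing the operands from the example (`%0`, `%empty`) and the tensor-level analogues of the reshape/transpose that appear in the lowered IR further down (older `expand_shape` syntax, matching the dumps below):

```mlir
// reshape: split each tiled dimension into an (outer, tile) pair.
%expanded = tensor.expand_shape %0 [[0, 1], [2, 3]]
    : tensor<128x256xf32> into tensor<4x32x8x32xf32>
// transpose: move the tile dims to the innermost positions.
%1 = linalg.transpose ins(%expanded : tensor<4x32x8x32xf32>)
    outs(%empty : tensor<4x8x32x32xf32>) permutation = [0, 2, 1, 3]
```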
The basic arguments are:
- Input tensor: the original tensor in the original shape.
- Output tensor: the `into` tensor (`empty` or not, for buffer reuse) with the packed shape.
- `inner_tiles`: the N-dimensional tile size.
- `inner_dims_pos`: the order in which the dimensions will be divided by the tile sizes.
- `outer_dims_perm`: a permutation to perform after the pack reshapes and transposes.
`inner_dims_pos` defines which original dimensions will be divided by the corresponding tile size. `inner_tiles` is simply the tile size, in as many dimensions as needed; it must be in the same order as the `inner_dims_pos` argument.
Example (the corresponding op is sketched after the diagram):

```
for tile = [16, 64] and tensor = < 1024 x 512 x f32 >, pack is:
---------------------------------------------------------------
            [ 1024 x 512 ]
reshape       |  \     |  \
            [ 64 x 16 x 8 x 64 ]
              |     \   /    |
transpose     |      X       |
              |     /   \    |
            [ 64 x  8 x 16 x 64 ]
```
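The corresponding op for this diagram would look roughly like this (operand names are placeholders):

```mlir
%packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [16, 64]
    into %dest : tensor<1024x512xf32> -> tensor<64x8x16x64xf32>
```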
The number of tile dimensions needs to be equal to or less than the rank of the input tensor. If it has fewer dimensions, not all dimensions of the original tensor will be tiled.
Example (shown as an op sketch below):
`<128x256x512xf32> / <16x8>[1, 2] = <128x16x64x16x8xf32>`
* 128 is left alone because the positions are [1, 2], i.e. 256 and 512.
* 256 is split into 16x16, with the tile dim (the second 16) moved inside.
* 512 is split into 64x8, with the tile dim (8) moved to be the inner-most dimension.
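The same example as an op sketch (operand names are placeholders):

```mlir
// Only dims 1 and 2 are tiled; dim 0 (128) is left untouched.
%packed = tensor.pack %src inner_dims_pos = [1, 2] inner_tiles = [16, 8]
    into %dest : tensor<128x256x512xf32> -> tensor<128x16x64x16x8xf32>
```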
`outer_dims_perm` defines the order of the packed shape's dimensions in the final shape. The number of dimensions must be the same as the rank of the packed shape.
Example:

```
for packed shape <128x16x64x16x8xf32> and outer_dims_perm = [ 0, 4, 1, 3, 2 ]:
------------------------------------------------------------------------------
[ 128 x 16 x 64 x 16 x 8 ]
[  0     1    2    3   4 ]
   (take the dims in the order 0, 4, 1, 3, 2)
[ 128 x  8 x 16 x 16 x 64 ]
```
tpp-opt commit: 70332c24b427086668a820ed7af985f821ce4b6c

The pack and unpack are for a GEMM of size 512x512x1024 (M, N, K). The pack is for the A operand (B is constant-folded); the unpack is for the output. The performance is compared with a memcpy of size 512x512.
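For reference, a sketch of the tensor-level ops being benchmarked, inferred from the lowered shapes below (operand names are placeholders):

```mlir
// Pack of the A operand of the 512x512x1024 GEMM into 32x32 tiles.
%a_packed = tensor.pack %A inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %a_dest : tensor<512x1024xf32> -> tensor<16x32x32x32xf32>
// Unpack of the packed GEMM output back to 512x512.
%c = tensor.unpack %C_packed inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %c_dest : tensor<16x16x32x32xf32> -> tensor<512x512xf32>
```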
In-house decomposition
```
~/tpp-sandbox-micro-bench# tpp-opt mlir/pack_gemm_operand_a.mlir -tpp-mapping -bufferize -convert-memref-to-xsmm
module {
  func.func @pack_gemm_operand_a(%arg0: memref<512x1024xf32>, %arg1: memref<16x32x32x32xf32>) -> memref<16x32x32x32xf32> attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    scf.for %arg2 = %c0 to %c16 step %c1 {
      scf.for %arg3 = %c0 to %c32 step %c1 {
        %0 = arith.muli %arg2, %c32 : index
        %1 = arith.muli %arg3, %c32 : index
        %subview = memref.subview %arg0[%0, %1] [32, 32] [1, 1] : memref<512x1024xf32> to memref<32x32xf32, strided<[1024, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<16x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        %2 = xsmm.unary.dispatch identity [32, 32, 1024, 32] flags = (none) data_type = f32
        xsmm.unary identity(data_type = f32, %2, %subview, %subview_0) : (i64, memref<32x32xf32, strided<[1024, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
      }
    }
    return %arg1 : memref<16x32x32x32xf32>
  }
}
```
```
~/tpp-sandbox-micro-bench# tpp-opt mlir/unpack_gemm_operand.mlir -tpp-mapping -bufferize -convert-memref-to-xsmm
module {
  func.func @unpack_gemm_operand(%arg0: memref<16x16x32x32xf32>, %arg1: memref<512x512xf32>) -> memref<512x512xf32> attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    scf.for %arg2 = %c0 to %c16 step %c1 {
      scf.for %arg3 = %c0 to %c16 step %c1 {
        %0 = arith.muli %arg2, %c32 : index
        %1 = arith.muli %arg3, %c32 : index
        %subview = memref.subview %arg0[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<16x16x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%0, %1] [32, 32] [1, 1] : memref<512x512xf32> to memref<32x32xf32, strided<[512, 1], offset: ?>>
        %2 = xsmm.unary.dispatch identity [32, 32, 32, 512] flags = (none) data_type = f32
        xsmm.unary identity(data_type = f32, %2, %subview, %subview_0) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[512, 1], offset: ?>>) -> ()
      }
    }
    return %arg1 : memref<512x512xf32>
  }
}
```
```
CPU Caches:
  L1 Data 32 KiB (x6)
  L1 Instruction 32 KiB (x6)
  L2 Unified 256 KiB (x6)
  L3 Unified 12288 KiB (x1)
Load Average: 0.72, 0.77, 0.55
----------------------------------------------------------------------------------------
Benchmark                            Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------
BM_memcpy/262144_mean            13967 ns        13965 ns           50 bytes_per_second=69.9351Gi/s
BM_memcpy/262144_median          13912 ns        13910 ns           50 bytes_per_second=70.206Gi/s
BM_memcpy/262144_stddev            151 ns          151 ns           50 bytes_per_second=760.714Mi/s
BM_memcpy/262144_cv               1.08 %          1.08 %            50 bytes_per_second=1.06%
BM_pack_gemm_operand_a_mean     162372 ns       162358 ns           50 bytes_per_second=12.047Gi/s
BM_pack_gemm_operand_a_median   165627 ns       165600 ns           50 bytes_per_second=11.7942Gi/s
BM_pack_gemm_operand_a_stddev     6157 ns         6155 ns           50 bytes_per_second=475.283Mi/s
BM_pack_gemm_operand_a_cv         3.79 %          3.79 %            50 bytes_per_second=3.85%
BM_unpack_gemm_operand_mean      95195 ns        95188 ns           50 bytes_per_second=10.3174Gi/s
BM_unpack_gemm_operand_median    98148 ns        98139 ns           50 bytes_per_second=9.95084Gi/s
BM_unpack_gemm_operand_stddev     6701 ns         6701 ns           50 bytes_per_second=862.63Mi/s
BM_unpack_gemm_operand_cv         7.04 %          7.04 %            50 bytes_per_second=8.16%
```
Upstream linalg lowering. NOTE: Linalg is lowered to loops
```
~/tpp-sandbox-micro-bench# tpp-opt mlir/pack_gemm_operand_a.mlir -tpp-mapping -bufferize -convert-memref-to-xsmm
module {
  func.func @pack_gemm_operand_a(%arg0: memref<512x1024xf32>, %arg1: memref<16x32x32x32xf32>) attributes {llvm.emit_c_interface} {
    %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<512x1024xf32> into memref<16x32x32x32xf32>
    linalg.transpose ins(%expand_shape : memref<16x32x32x32xf32>) outs(%arg1 : memref<16x32x32x32xf32>) permutation = [0, 2, 1, 3]
    return
  }
}
module {
  func.func @unpack_gemm_operand(%arg0: memref<16x16x32x32xf32>, %arg1: memref<512x512xf32>) attributes {llvm.emit_c_interface} {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<16x32x16x32xf32>
    linalg.transpose ins(%arg0 : memref<16x16x32x32xf32>) outs(%alloc : memref<16x32x16x32xf32>) permutation = [0, 2, 1, 3]
    %collapse_shape = memref.collapse_shape %alloc [[0, 1], [2, 3]] : memref<16x32x16x32xf32> into memref<512x512xf32>
    linalg.copy ins(%collapse_shape : memref<512x512xf32>) outs(%arg1 : memref<512x512xf32>)
    memref.dealloc %alloc : memref<16x32x16x32xf32>
    return
  }
}
```
```
CPU Caches:
  L1 Data 32 KiB (x6)
  L1 Instruction 32 KiB (x6)
  L2 Unified 256 KiB (x6)
  L3 Unified 12288 KiB (x1)
Load Average: 0.83, 0.65, 0.42
----------------------------------------------------------------------------------------
Benchmark                            Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------
BM_pack_gemm_operand_a_mean     352527 ns       352477 ns           50 bytes_per_second=5.54246Gi/s
BM_pack_gemm_operand_a_median   354444 ns       354424 ns           50 bytes_per_second=5.5107Gi/s
BM_pack_gemm_operand_a_stddev     5449 ns         5477 ns           50 bytes_per_second=89.1645Mi/s
BM_pack_gemm_operand_a_cv         1.55 %          1.55 %            50 bytes_per_second=1.57%
BM_unpack_gemm_operand_mean     407236 ns       407182 ns           50 bytes_per_second=2.39837Gi/s
BM_unpack_gemm_operand_median   406341 ns       406297 ns           50 bytes_per_second=2.40357Gi/s
BM_unpack_gemm_operand_stddev     1438 ns         1448 ns           50 bytes_per_second=8.70749Mi/s
BM_unpack_gemm_operand_cv         0.35 %          0.36 %            50 bytes_per_second=0.35%
BM_memcpy/262144_mean            14219 ns        14218 ns           50 bytes_per_second=68.7323Gi/s
BM_memcpy/262144_median          14039 ns        14038 ns           50 bytes_per_second=69.5664Gi/s
BM_memcpy/262144_stddev            384 ns          384 ns           50 bytes_per_second=1.81746Gi/s
BM_memcpy/262144_cv               2.70 %          2.70 %            50 bytes_per_second=2.64%
```
```mlir
func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32x32x32xbf16>, tensor<16x32x32x32xbf16>) outs(%arg2 : tensor<16x16x32x32xbf16>) {
  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
    %1 = arith.mulf %in, %in_0 : bf16
    %2 = arith.addf %out, %1 : bf16
    linalg.yield %2 : bf16
  } -> tensor<16x16x32x32xbf16>
  %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<16x16x32x32xbf16> -> tensor<512x512xbf16>
  return %unpack : tensor<512x512xbf16>
}
```
```mlir
#map = affine_map<(d0) -> (d0 floordiv 32)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
module {
  func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
    // We tile the unpack using the inner tile sizes and fuse the packed matmul inside.
    // The unpack is now replaced by the BRGEMM writing directly into %arg3.
    %0 = scf.forall (%arg4, %arg5) = (0, 0) to (512, 512) step (32, 32) shared_outs(%arg6 = %arg3) -> (tensor<512x512xbf16>) {
      %1 = affine.apply #map(%arg4)
      %2 = affine.apply #map(%arg5)
      %3 = affine.apply #map(%arg4)
      %4 = affine.apply #map(%arg5)
      %extracted_slice = tensor.extract_slice %arg0[%1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_0 = tensor.extract_slice %arg1[%2, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg2[%3, %4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<16x16x32x32xbf16> to tensor<32x32xbf16>
      %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %extracted_slice_0 : tensor<32x32x32xbf16>, tensor<32x32x32xbf16>) outs(%extracted_slice_1 : tensor<32x32xbf16>) {
      ^bb0(%in: bf16, %in_2: bf16, %out: bf16):
        %6 = arith.mulf %in, %in_2 : bf16
        %7 = arith.addf %out, %6 : bf16
        linalg.yield %7 : bf16
      } -> tensor<32x32xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<512x512xbf16>
      }
    }
    return %0 : tensor<512x512xbf16>
  }
}
```
```mlir
module {
  func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
    // In the main fusion pass, however, we tile along the outer parallel dimensions of the given contraction operations.
    // Can we fuse the unpack into these loops? Yes, as the IR below shows.
    %0 = scf.forall (%arg4, %arg5) in (16, 16) shared_outs(%arg6 = %arg2) -> (tensor<16x16x32x32xbf16>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg5, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<16x16x32x32xbf16> to tensor<32x32xbf16>
      %1 = linalg.batch_reduce_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32x32xbf16>, tensor<32x32x32xbf16>) outs(%extracted_slice_1 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %arg6[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xbf16> into tensor<16x16x32x32xbf16>
      }
    }
    %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg3 : tensor<16x16x32x32xbf16> -> tensor<512x512xbf16>
    return %unpack : tensor<512x512xbf16>
  }
}
```
```mlir
module {
  func.func @unpack_fusion(%arg0: tensor<16x32x32x32xbf16>, %arg1: tensor<16x32x32x32xbf16>, %arg2: tensor<16x16x32x32xbf16>, %arg3: tensor<512x512xbf16>) -> tensor<512x512xbf16> {
    // Basically, we run over the tiled loops and insert the BRGEMM tile with a stride of 32.
    %0 = scf.forall (%arg4, %arg5) in (16, 16) shared_outs(%arg6 = %arg3) -> (tensor<512x512xbf16>) {
      %extracted_slice = tensor.extract_slice %arg0[%arg4, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_0 = tensor.extract_slice %arg1[%arg5, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<16x32x32x32xbf16> to tensor<32x32x32xbf16>
      %extracted_slice_1 = tensor.extract_slice %arg2[%arg4, %arg5, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<16x16x32x32xbf16> to tensor<32x32xbf16>
      %1 = linalg.batch_reduce_matmul ins(%extracted_slice, %extracted_slice_0 : tensor<32x32x32xbf16>, tensor<32x32x32xbf16>) outs(%extracted_slice_1 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %1 into %arg6[%arg4, %arg5] [32, 32] [32, 32] : tensor<32x32xbf16> into tensor<512x512xbf16>
      }
    }
    return %0 : tensor<512x512xbf16>
  }
}
```