From b7b6d54004ef8a89dc3bad411f11a1ef93319a13 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Wed, 15 Nov 2023 14:14:33 +0000 Subject: [PATCH] [mlir][vector] Add vector.transpose with unit-dim to vector.shape_cast pattern (#72105) This patch extends the vector.transpose lowering to replace: vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>> with: vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>> Source with leading unit-dim (inverse) is also replaced. Unit dim must be fixed. Non-unit dim can be scalable. A check is also added to bail out for scalable vectors before unrolling. --- .../Transforms/LowerVectorTranspose.cpp | 21 ++++++ .../Vector/vector-transpose-lowering.mlir | 71 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 7d804ddcfa42ff..dee786007c8063 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -336,6 +336,27 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { return rewriter.notifyMatchFailure( op, "Options specifies lowering to shuffle"); + // Replace: + // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to + // vector<1xnx<eltty>> + // with: + // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>> + // + // Source with leading unit dim (inverse) is also replaced. Unit dim must + // be fixed. Non-unit can be scalable. + if (resType.getRank() == 2 && + ((resType.getShape().front() == 1 && + !resType.getScalableDims().front()) || + (resType.getShape().back() == 1 && + !resType.getScalableDims().back())) && + transp == ArrayRef<int64_t>({1, 0})) { + rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input); + return success(); + } + + if (inputType.isScalable()) + return failure(); + // Handle a true 2-D matrix transpose differently when requested. 
if (vectorTransformOptions.vectorTransposeLowering == vector::VectorTransposeLowering::Flat && diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 22d9224838c49c..c0b44428d5bcf3 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8 return %0 : vector<1x1x8x8xf32> } +/// Scalable dim should not be unrolled. + +// CHECK-LABEL: func @transpose23_scalable +// CHECK-NOT: vector.extract +// CHECK-NOT: vector.insert +// CHECK: vector.transpose +func.func @transpose23_scalable(%arg0: vector<2x[3]xf32>) -> vector<[3]x2xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<2x[3]xf32> to vector<[3]x2xf32> + return %0 : vector<[3]x2xf32> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) { transform.apply_patterns to %func_op { @@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast. 
+ +// CHECK-LABEL: func @transpose10_4x1xf32 +func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// CHECK-LABEL: func @transpose10_nx4x1xf32 +func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> + return %0 : vector<1x[4]xf32> +} + +// CHECK-LABEL: func @transpose10_1x4xf32 +func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32> + return %0 : vector<4x1xf32> +} + +// CHECK-LABEL: func @transpose10_1xnx4xf32 +func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32> + return %0 : vector<[4]x1xf32> +} + +/// Scalable unit dim should not be lowered to shape_cast. 
+ +// CHECK-LABEL: func @transpose10_4xnx1xf32 +func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> { + // CHECK-NOT: vector.shape_cast + // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32> + return %0 : vector<[1]x4xf32> +} + +// CHECK-LABEL: func @transpose10_nx4xnx1xf32 +func.func @transpose10_nx4xnx1xf32(%arg0: vector<[4]x[1]xf32>) -> vector<[1]x[4]xf32> { + // CHECK-NOT: vector.shape_cast + // CHECK: vector.transpose %{{.*}} : vector<[4]x[1]xf32> to vector<[1]x[4]xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[1]xf32> to vector<[1]x[4]xf32> + + return %0 : vector<[1]x[4]xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) { + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.lower_transpose + } : !transform.op<"func.func"> + transform.yield + } +}