From ae055899259a9dfb7229c7559fc086aaa31952b7 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 16 Dec 2024 11:56:32 +0100 Subject: [PATCH] TOSA: concat: fix canonicalization that would result in concat with no operands The fold() of the concat had two issues: a) It would fold `concat (tensor<0x2>, tensor<0x4>), axis=1 -> tensor<0x6>` into `concat ()` (no operands) based on the observation that both operands have zero elements. This is invalid as concat needs to have at least one operand. b) After that fix, it would fold `concat (tensor<0x2>, tensor<0x4>), axis=1 -> tensor<0x6>` int `concat (tensor<0x4>), axis=1 -> tensor<0x6>`. This is also invalid; even though the concat still produces the same number of elements (= none), shape relations are not corret (0x4 input, but 0x6 output). I fixed that by only removing operands from the concat when - the operand has zero dim on the concatenation axis (i.e. it doesn't contribute to the result shape) - there is at least one operand left after removing --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 12 +++------ mlir/test/Dialect/Tosa/canonicalize.mlir | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 5d31b18ed7525c..1af7ba4bbd5134 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { return {}; } -static bool hasZeroSize(Type ty) { - auto ranked = dyn_cast(ty); - if (!ranked) - return false; - return any_of(ranked.getShape(), [](auto d) { return d == 0; }); -} - OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { /// Remove operands that have zero elements. bool changed = false; for (size_t i = 0; i < getInput1().size(); ) { - auto input = getInput1()[i]; - if (hasZeroSize(input.getType())) { + auto input = cast(getInput1()[i].getType()); + // Ensure that we have at least one operand left. + if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) { getInput1Mutable().erase(i); changed = true; } else { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index f35df639cca523..4e10833a775c39 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor, %arg1: tensor, %arg %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor, tensor, tensor) -> tensor return %0 : tensor } +// ----- + +// CHECK-LABEL: @concat_fold_zero +func.func @concat_fold_zero_all(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: return %arg1 + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @concat_fold_zero +func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> { + // CHECK: tosa.concat %arg0, %arg1 + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32> + return %0 : tensor<0x6xf32> +} + +// ----- + +// CHECK-LABEL: @concat_fold_zero_size +func.func @concat_fold_zero_size(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32} + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor, tensor, tensor) -> tensor + return %0 : tensor +} // -----