diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 5d31b18ed7525c..1af7ba4bbd5134 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { return {}; } -static bool hasZeroSize(Type ty) { - auto ranked = dyn_cast(ty); - if (!ranked) - return false; - return any_of(ranked.getShape(), [](auto d) { return d == 0; }); -} - OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { /// Remove operands that have zero elements. bool changed = false; for (size_t i = 0; i < getInput1().size(); ) { - auto input = getInput1()[i]; - if (hasZeroSize(input.getType())) { + auto input = cast(getInput1()[i].getType()); + // Ensure that we have at least one operand left. + if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) { getInput1Mutable().erase(i); changed = true; } else { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index f35df639cca523..4e10833a775c39 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor, %arg1: tensor, %arg %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor, tensor, tensor) -> tensor return %0 : tensor } +// ----- + +// CHECK-LABEL: @concat_fold_zero +func.func @concat_fold_zero_all(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: return %arg1 + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @concat_fold_zero +func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> { + // CHECK: tosa.concat %arg0, %arg1 + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32> + return %0 : tensor<0x6xf32> +} + +// ----- + +// CHECK-LABEL: @concat_fold_zero_size +func.func @concat_fold_zero_size(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32} + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor, tensor, tensor) -> tensor + return %0 : tensor +} // -----