Skip to content

Commit

Permalink
Merge pull request #427 from Xilinx/matthias.fix_concat_zero_fold
Browse files Browse the repository at this point in the history
TOSA: concat: fix canonicalization that would result in concat with no operands
  • Loading branch information
mgehre-amd authored Dec 16, 2024
2 parents c9c2863 + ae05589 commit 14e4586
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
12 changes: 3 additions & 9 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
return {};
}

static bool hasZeroSize(Type ty) {
auto ranked = dyn_cast<RankedTensorType>(ty);
if (!ranked)
return false;
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
/// Remove operands that have zero elements.
bool changed = false;
for (size_t i = 0; i < getInput1().size(); ) {
auto input = getInput1()[i];
if (hasZeroSize(input.getType())) {
auto input = cast<RankedTensorType>(getInput1()[i].getType());
// Ensure that we have at least one operand left.
if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) {
getInput1Mutable().erase(i);
changed = true;
} else {
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}
// -----

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero_all(%arg0: tensor<?x0xf32>, %arg1: tensor<?x0xf32>) -> tensor<?x0xf32> {
// CHECK: return %arg1
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x0xf32>) -> tensor<?x0xf32>
return %0 : tensor<?x0xf32>
}

// -----

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> {
// CHECK: tosa.concat %arg0, %arg1
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32>
return %0 : tensor<0x6xf32>
}

// -----

// CHECK-LABEL: @concat_fold_zero_size
func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
// CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32}
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}

// -----

Expand Down

0 comments on commit 14e4586

Please sign in to comment.