Skip to content

Commit

Permalink
TOSA: concat: fix canonicalization that would result in concat with n…
Browse files Browse the repository at this point in the history
…o operands

The fold() of the concat had two issues:
a) It would fold `concat (tensor<0x2>, tensor<0x4>), axis=1 -> tensor<0x6>` into `concat ()` (no operands)
based on the observation that both operands have zero elements. This is invalid as concat needs to have
at least one operand.

b) After that fix, it would fold `concat (tensor<0x2>, tensor<0x4>), axis=1 -> tensor<0x6>`
int `concat (tensor<0x4>), axis=1 -> tensor<0x6>`. This is also invalid; even though the concat
still produces the same number of elements (= none), shape relations are not corret (0x4 input, but 0x6 output).

I fixed that by only removing operands from the concat when
- the operand has zero dim on the concatenation axis (i.e. it doesn't contribute to the result shape)
- there is at least one operand left after removing
  • Loading branch information
mgehre-amd committed Dec 16, 2024
1 parent 63401e3 commit ae05589
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
12 changes: 3 additions & 9 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
return {};
}

static bool hasZeroSize(Type ty) {
auto ranked = dyn_cast<RankedTensorType>(ty);
if (!ranked)
return false;
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
/// Remove operands that have zero elements.
bool changed = false;
for (size_t i = 0; i < getInput1().size(); ) {
auto input = getInput1()[i];
if (hasZeroSize(input.getType())) {
auto input = cast<RankedTensorType>(getInput1()[i].getType());
// Ensure that we have at least one operand left.
if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) {
getInput1Mutable().erase(i);
changed = true;
} else {
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}
// -----

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero_all(%arg0: tensor<?x0xf32>, %arg1: tensor<?x0xf32>) -> tensor<?x0xf32> {
// CHECK: return %arg1
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x0xf32>) -> tensor<?x0xf32>
return %0 : tensor<?x0xf32>
}

// -----

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> {
// CHECK: tosa.concat %arg0, %arg1
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32>
return %0 : tensor<0x6xf32>
}

// -----

// CHECK-LABEL: @concat_fold_zero_size
func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
// CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32}
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}

// -----

Expand Down

0 comments on commit ae05589

Please sign in to comment.