Skip to content

Commit

Permalink
[memref] Handle edge case in subview of full static size fold (llvm#1…
Browse files Browse the repository at this point in the history
…05635)

It is possible to have a subview with a fully static size and a type
that matches the source type, but a dynamic offset that may be
different. However, currently the memref dialect folds:

```mlir
func.func @subview_of_static_full_size(
  %arg0:  memref<16x4xf32,  strided<[4, 1], offset: ?>>, %idx: index)
  -> memref<16x4xf32,  strided<[4, 1], offset: ?>>
{
  %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1]
   : memref<16x4xf32,  strided<[4, 1], offset: ?>>
     to memref<16x4xf32,  strided<[4, 1], offset: ?>>
  return %0 : memref<16x4xf32,  strided<[4, 1], offset: ?>>
}
```

To:

```mlir
func.func @subview_of_static_full_size(
  %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index)
  -> memref<16x4xf32, strided<[4, 1], offset: ?>>
{
  return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}
```

Which drops the dynamic offset from the `subview` op.
  • Loading branch information
MacDue authored Aug 23, 2024
1 parent 96b3166 commit 84aa02d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
let extraClassDeclaration = [{
/// Print the attribute to the given output stream.
void print(raw_ostream &os) const;

/// Returns true if this layout is static, i.e. the strides and offset all
/// have a known value > 0.
bool hasStaticLayout() const;
}];
}

Expand Down
15 changes: 9 additions & 6 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
}

OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());

if (resultShapedType.hasStaticShape() &&
resultShapedType == sourceShapedType) {
MemRefType sourceMemrefType = getSource().getType();
MemRefType resultMemrefType = getResult().getType();
auto resultLayout =
dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());

if (resultMemrefType == sourceMemrefType &&
resultMemrefType.hasStaticShape() &&
(!resultLayout || resultLayout.hasStaticLayout())) {
return getViewSource();
}

Expand All @@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
resultShapedType == sourceShapedType)
resultMemrefType == sourceMemrefType)
return getViewSource();
}

Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
os << ">";
}

/// Returns true if this layout is static, i.e. the strides and offset all have
/// a known value > 0.
bool StridedLayoutAttr::hasStaticLayout() const {
return !ShapedType::isDynamic(getOffset()) &&
!ShapedType::isDynamicShape(getStrides());
}

/// Returns the strided layout as an affine map.
AffineMap StridedLayoutAttr::getAffineMap() const {
return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4

// -----

// CHECK-LABEL: func @negative_subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK-SAME: %[[IDX:.+]]: index
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
%0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
}

// -----

func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
{
Expand Down

0 comments on commit 84aa02d

Please sign in to comment.