From 84aa02d3fa1f1f614c4f3c144ec118b2f05ae6b0 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 23 Aug 2024 06:52:09 +0100 Subject: [PATCH] [memref] Handle edge case in subview of full static size fold (#105635) It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>> return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` To: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` Which drops the dynamic offset from the `subview` op. --- mlir/include/mlir/IR/BuiltinAttributes.td | 4 ++++ mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 15 +++++++++------ mlir/lib/IR/BuiltinAttributes.cpp | 7 +++++++ mlir/test/Dialect/MemRef/canonicalize.mlir | 13 +++++++++++++ 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index d9295936ee97bd..f0d41754001400 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout", let extraClassDeclaration = [{ /// Print the attribute to the given output stream. void print(raw_ostream &os) const; + + /// Returns true if this layout is static, i.e. the strides and offset all + /// have a known value > 0. + bool hasStaticLayout() const; }]; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 150049e5c5effe..9c021d3613f1c8 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, } OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { - auto resultShapedType = llvm::cast(getResult().getType()); - auto sourceShapedType = llvm::cast(getSource().getType()); - - if (resultShapedType.hasStaticShape() && - resultShapedType == sourceShapedType) { + MemRefType sourceMemrefType = getSource().getType(); + MemRefType resultMemrefType = getResult().getType(); + auto resultLayout = + dyn_cast_if_present(resultMemrefType.getLayout()); + + if (resultMemrefType == sourceMemrefType && + resultMemrefType.hasStaticShape() && + (!resultLayout || resultLayout.hasStaticLayout())) { return getViewSource(); } @@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }); bool allSizesSame = llvm::equal(sizes, srcSizes); if (allOffsetsZero && allStridesOne && allSizesSame && - resultShapedType == sourceShapedType) + resultMemrefType == sourceMemrefType) return getViewSource(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 89b1ed67f5d067..8861a940336133 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const { os << ">"; } +/// Returns true if this layout is static, i.e. the strides and offset all have +/// a known value > 0. +bool StridedLayoutAttr::hasStaticLayout() const { + return !ShapedType::isDynamic(getOffset()) && + !ShapedType::isDynamicShape(getStrides()); +} + /// Returns the strided layout as an affine map. AffineMap StridedLayoutAttr::getAffineMap() const { return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext()); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index b15af9baca7dc7..02110bc2892d05 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4 // ----- +// CHECK-LABEL: func @negative_subview_of_static_full_size +// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>> +// CHECK-SAME: %[[IDX:.+]]: index +// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1] +// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>> +// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>> +func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { + %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>> + return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>> +} + +// ----- + func.func @subview_canonicalize(%arg0 : memref, %arg1 : index, %arg2 : index) -> memref> {