Commit 82ab0f7
[mlir][linalg] Fix rank-reduced cases for extract/insert slice in DropUnitDims (llvm#74723)

Inferring the reshape reassociation indices for extract/insert slice ops
from the read sizes of the original slicing op generates an invalid
expand/collapse shape op when the slice is already rank-reduced. Instead,
infer the reassociation from the shape of the slice itself.

Ported from Differential Revision: https://reviews.llvm.org/D147488
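
To make the failure concrete, here is a sketch based on the new extract_slice
test added below (%arg0 is the test's rank-5 input). The op reads with sizes
[1, 1, 3, 1, 3], so a reassociation inferred from those five sizes describes
folding a rank-5 value; but the result is already the rank-reduced
tensor<1x3x3xf32>, so an expand_shape built from that reassociation is
invalid. Inferring from the result shape [1, 3, 3] folds only the remaining
unit dimension:

  // Rank-reduced slice: the op itself already drops two unit dims (rank 5 -> rank 3).
  %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1]
      : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>

  // With the fix, the pattern slices away the remaining unit dim and expands back,
  // using reassociation [[0, 1], [2]] inferred from the result shape [1, 3, 3].
  %1 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1]
      : tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
  %2 = tensor.expand_shape %1 [[0, 1], [2]] : tensor<3x3xf32> into tensor<1x3x3xf32>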
qedawkins authored Dec 16, 2023
1 parent c398fa0 commit 82ab0f7
Showing 2 changed files with 37 additions and 8 deletions.
21 changes: 13 additions & 8 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -572,13 +572,17 @@ struct RankReducedExtractSliceOp
   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType resultType = sliceOp.getType();
-    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    SmallVector<OpFoldResult> targetShape;
+    for (auto size : resultType.getShape())
+      targetShape.push_back(rewriter.getIndexAttr(size));
+    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
     auto rankReducedType = cast<RankedTensorType>(
         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
             reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
@@ -602,13 +606,14 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType sourceType = insertSliceOp.getSourceType();
-    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    SmallVector<OpFoldResult> targetShape;
+    for (auto size : sourceType.getShape())
+      targetShape.push_back(rewriter.getIndexAttr(size));
+    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
+
     Location loc = insertSliceOp.getLoc();
     tensor::CollapseShapeOp reshapedSource;
     {
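
The insert side mirrors this; a sketch based on the new insert_slice test
below (%src and %dest are placeholder names). The rank-3 source cannot be
collapsed by a reassociation inferred from the five read sizes, but its own
shape [1, 3, 3] yields the valid reassociation [[0, 1], [2]]:

  // Rank-reduced insert: a rank-3 source inserted into a rank-5 destination.
  %0 = tensor.insert_slice %src into %dest[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1]
      : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>

  // With the fix, the pattern collapses the source along [[0, 1], [2]]
  // (inferred from its shape [1, 3, 3]) and inserts the collapsed tensor.
  %collapsed = tensor.collapse_shape %src [[0, 1], [2]] : tensor<1x3x3xf32> into tensor<3x3xf32>
  %1 = tensor.insert_slice %collapsed into %dest[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1]
      : tensor<3x3xf32> into tensor<1x1x3x1x3xf32>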
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -489,6 +489,18 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
 
 // -----
 
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
+  %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
+  return %0 : tensor<1x3x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_extract_slice
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice
+// CHECK-SAME: tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
   %0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
   return %0 : tensor<1x3xf32>
@@ -501,6 +513,18 @@ func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>
 
 // -----
 
+func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
+  %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
+  return %0 : tensor<1x1x3x1x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_insert_slice
+// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
+// CHECK-SAME: tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
   affine_map<(i, j, k, l, m) -> ()>,