diff --git a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp index 0748c36471f0..7882f30f7e10 100644 --- a/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp +++ b/lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp @@ -240,6 +240,75 @@ struct EliminateDuplicateYieldsInProjectPattern } }; +struct EliminateIdentityYieldsInProjectPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ProjectOp op, + PatternRewriter &rewriter) const override { + MLIRContext *context = op.getContext(); + Operation *terminator = op.getExpressions().front().getTerminator(); + auto inputTupleType = cast(op.getInput().getType()); + auto resultTupleType = cast(op.getResult().getType()); + int64_t numInputFields = inputTupleType.size(); + + // Look for yielded values that are just forwarding input fields. + SmallVector newYields; + SmallVector mapping; + mapping.reserve(resultTupleType.size()); + append_range(mapping, seq(numInputFields)); + for (auto [i, value] : enumerate(terminator->getOperands())) { + // Test if this is a `field_reference` op that refers a top-level field. + auto refOp = value.getDefiningOp(); + if (refOp && refOp.getPosition().size() == 1) { + // Test if it refers to the block argument of the `expression` region. + auto arg = dyn_cast(refOp.getContainer()); + if (arg && arg.getOwner() == &op.getExpressions().front()) { + // This is a references forwarding a top-level field, so we'll express + // that with an `emit` op reordering the result of this op. + mapping.push_back(refOp.getPosition().front()); + continue; + } + } + + // This is not just a forwarding an input field, so we keep it. + mapping.push_back(numInputFields + newYields.size()); + newYields.push_back(value); + } + + if (newYields.size() == terminator->getNumOperands()) + return rewriter.notifyMatchFailure( + op, "does not yield unmodified input fields"); + + // Change the `yield` op to yield only those values we want to keep. + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(terminator); + terminator = rewriter.replaceOpWithNewOp(terminator, newYields); + } + + // Compute deduplicated output field types. + SmallVector outputTypes; + int64_t numNewYields = terminator->getNumOperands(); + outputTypes.reserve(inputTupleType.size() + numNewYields); + append_range(outputTypes, inputTupleType.getTypes()); + append_range(outputTypes, terminator->getOperandTypes()); + auto newOutputType = TupleType::get(context, outputTypes); + + // Create new `project` op with updated region. + auto newOp = + rewriter.create(op.getLoc(), newOutputType, op.getInput()); + rewriter.inlineRegionBefore(op.getExpressions(), newOp.getExpressions(), + newOp.getExpressions().end()); + + // Create `emit` op with a mapping that recreates the fields we removed. + ArrayAttr reverseMappingAttr = rewriter.getI64ArrayAttr(mapping); + rewriter.replaceOpWithNewOp(op, newOp, reverseMappingAttr); + + return success(); + } +}; + /// Pushes duplicates in the mappings of `emit` ops producing either of the two /// inputs through the `cross` op. This works by introducing new emit ops /// without the duplicates, creating a new `cross` op that uses them, and @@ -411,6 +480,7 @@ void populateEmitDeduplicationPatterns(RewritePatternSet &patterns) { patterns.add< // clang-format off EliminateDuplicateYieldsInProjectPattern, + EliminateIdentityYieldsInProjectPattern, PushDuplicatesThroughCrossPattern, PushDuplicatesThroughFilterPattern, PushDuplicatesThroughProjectPattern diff --git a/test/Transforms/Substrait/emit-deduplication.mlir b/test/Transforms/Substrait/emit-deduplication.mlir index 7cce4120364e..d5d1f9874b38 100644 --- a/test/Transforms/Substrait/emit-deduplication.mlir +++ b/test/Transforms/Substrait/emit-deduplication.mlir @@ -201,6 +201,8 @@ substrait.plan version 0 : 42 : 1 { func.func private @f(si32, si32) -> si1 +// XXX(ingomueller): How can we test individual patterns here? + // CHECK-LABEL: substrait.plan // CHECK-NEXT: relation // CHECK-NEXT: %[[V0:.*]] = named_table @@ -257,3 +259,33 @@ substrait.plan version 0 : 42 : 1 { yield %1 : tuple } } + +// ----- + +// `project` op (`EliminateIdentityYieldsInProjectPattern`). + +// CHECK-LABEL: substrait.plan +// CHECK-NEXT: relation +// CHECK-NEXT: %[[V0:.*]] = named_table +// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : {{.*}} { +// CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]): +// CHECK-NEXT: %[[V2:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]] +// CHECK-NEXT: %[[V3:.*]] = func.call @f(%[[V2]]) : +// CHECK-NEXT: yield %[[V3]] : si1 +// CHECK-NEXT: } +// CHECK-NEXT: %[[V4:.*]] = emit [0, 1, 0, 2] from %[[V1]] + +func.func private @f(si32) -> si1 + +substrait.plan version 0 : 42 : 1 { + relation { + %0 = named_table @t1 as ["a", "b"] : tuple + %1 = project %0 : tuple -> tuple { + ^bb0(%arg0: tuple): + %2 = field_reference %arg0[[0]] : tuple + %3 = func.call @f(%2) : (si32) -> si1 + yield %2, %3 : si32, si1 + } + yield %1 : tuple + } +}