Skip to content

Commit

Permalink
[Substrait] Add pattern that eleminates yielding of input fields.
Browse files Browse the repository at this point in the history
This PR adds a pattern to the emit deduplication pass that removes
values yielded from `project` that are `field_reference`s of top-level
input fields and expresses them through an `emit` op instead.

Signed-off-by: Ingo Müller <ingomueller@google.com>
  • Loading branch information
ingomueller-net committed May 31, 2024
1 parent ccddb3c commit f12a311
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
70 changes: 70 additions & 0 deletions lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,75 @@ struct EliminateDuplicateYieldsInProjectPattern
}
};

struct EliminateIdentityYieldsInProjectPattern
: public OpRewritePattern<ProjectOp> {
using OpRewritePattern<ProjectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ProjectOp op,
PatternRewriter &rewriter) const override {
MLIRContext *context = op.getContext();
Operation *terminator = op.getExpressions().front().getTerminator();
auto inputTupleType = cast<TupleType>(op.getInput().getType());
auto resultTupleType = cast<TupleType>(op.getResult().getType());
int64_t numInputFields = inputTupleType.size();

// Look for yielded values that are just forwarding input fields.
SmallVector<Value> newYields;
SmallVector<int64_t> mapping;
mapping.reserve(resultTupleType.size());
append_range(mapping, seq(numInputFields));
for (auto [i, value] : enumerate(terminator->getOperands())) {
// Test if this is a `field_reference` op that refers a top-level field.
auto refOp = value.getDefiningOp<FieldReferenceOp>();
if (refOp && refOp.getPosition().size() == 1) {
// Test if it refers to the block argument of the `expression` region.
auto arg = dyn_cast<BlockArgument>(refOp.getContainer());
if (arg && arg.getOwner() == &op.getExpressions().front()) {
// This is a references forwarding a top-level field, so we'll express
// that with an `emit` op reordering the result of this op.
mapping.push_back(refOp.getPosition().front());
continue;
}
}

// This is not just a forwarding an input field, so we keep it.
mapping.push_back(numInputFields + newYields.size());
newYields.push_back(value);
}

if (newYields.size() == terminator->getNumOperands())
return rewriter.notifyMatchFailure(
op, "does not yield unmodified input fields");

// Change the `yield` op to yield only those values we want to keep.
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(terminator);
terminator = rewriter.replaceOpWithNewOp<YieldOp>(terminator, newYields);
}

// Compute deduplicated output field types.
SmallVector<Type> outputTypes;
int64_t numNewYields = terminator->getNumOperands();
outputTypes.reserve(inputTupleType.size() + numNewYields);
append_range(outputTypes, inputTupleType.getTypes());
append_range(outputTypes, terminator->getOperandTypes());
auto newOutputType = TupleType::get(context, outputTypes);

// Create new `project` op with updated region.
auto newOp =
rewriter.create<ProjectOp>(op.getLoc(), newOutputType, op.getInput());
rewriter.inlineRegionBefore(op.getExpressions(), newOp.getExpressions(),
newOp.getExpressions().end());

// Create `emit` op with a mapping that recreates the fields we removed.
ArrayAttr reverseMappingAttr = rewriter.getI64ArrayAttr(mapping);
rewriter.replaceOpWithNewOp<EmitOp>(op, newOp, reverseMappingAttr);

return success();
}
};

/// Pushes duplicates in the mappings of `emit` ops producing either of the two
/// inputs through the `cross` op. This works by introducing new emit ops
/// without the duplicates, creating a new `cross` op that uses them, and
Expand Down Expand Up @@ -411,6 +480,7 @@ void populateEmitDeduplicationPatterns(RewritePatternSet &patterns) {
patterns.add<
// clang-format off
EliminateDuplicateYieldsInProjectPattern,
EliminateIdentityYieldsInProjectPattern,
PushDuplicatesThroughCrossPattern,
PushDuplicatesThroughFilterPattern,
PushDuplicatesThroughProjectPattern
Expand Down
32 changes: 32 additions & 0 deletions test/Transforms/Substrait/emit-deduplication.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ substrait.plan version 0 : 42 : 1 {

func.func private @f(si32, si32) -> si1

// XXX(ingomueller): How can we test individual patterns here?

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
Expand Down Expand Up @@ -257,3 +259,33 @@ substrait.plan version 0 : 42 : 1 {
yield %1 : tuple<si32, si1, si1>
}
}

// -----

// `project` op (`EliminateIdentityYieldsInProjectPattern`).

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : {{.*}} {
// CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]):
// CHECK-NEXT: %[[V2:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]]
// CHECK-NEXT: %[[V3:.*]] = func.call @f(%[[V2]]) :
// CHECK-NEXT: yield %[[V3]] : si1
// CHECK-NEXT: }
// CHECK-NEXT: %[[V4:.*]] = emit [0, 1, 0, 2] from %[[V1]]

func.func private @f(si32) -> si1

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a", "b"] : tuple<si32, si1>
%1 = project %0 : tuple<si32, si1> -> tuple<si32, si1, si32, si1> {
^bb0(%arg0: tuple<si32, si1>):
%2 = field_reference %arg0[[0]] : tuple<si32, si1>
%3 = func.call @f(%2) : (si32) -> si1
yield %2, %3 : si32, si1
}
yield %1 : tuple<si32, si1, si32, si1>
}
}

0 comments on commit f12a311

Please sign in to comment.