Skip to content

Commit

Permalink
[Substrait] Add pattern removing duplicate yields in project.
Browse files Browse the repository at this point in the history
This PR adds a pattern to the emit deduplication pass that removes
duplicates introduced by the `project` op by yielding values more than
once.

Signed-off-by: Ingo Müller <ingomueller@google.com>
  • Loading branch information
ingomueller-net committed May 31, 2024
1 parent 360dba3 commit ccddb3c
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
71 changes: 71 additions & 0 deletions lib/Dialect/Substrait/Transforms/EmitDeduplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,76 @@ void deduplicateRegionArgs(Region &region, ArrayAttr newMapping,
region.getArgument(0).setType(newElementType);
}

struct EliminateDuplicateYieldsInProjectPattern
: public OpRewritePattern<ProjectOp> {
using OpRewritePattern<ProjectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ProjectOp op,
PatternRewriter &rewriter) const override {
MLIRContext *context = op.getContext();
Operation *terminator = op.getExpressions().front().getTerminator();
int64_t numOriginalYields = terminator->getNumOperands();
auto inputTupleType = cast<TupleType>(op.getInput().getType());

// Determine duplicate values in `yield` and remember the first ocurrence of
// each value.
llvm::DenseMap<Value, int64_t> valuePositions;
for (Value value : terminator->getOperands())
valuePositions.try_emplace(value, valuePositions.size());

if (valuePositions.size() == numOriginalYields)
return rewriter.notifyMatchFailure(op, "does not yield duplicate values");

// Create a mapping from the de-duplicated values that re-establishes the
// original emit order. The input fields are just forwarded, so create
// identity prefix.
SmallVector<int64_t> reverseMapping;
reverseMapping.reserve(inputTupleType.size() + numOriginalYields);
append_range(reverseMapping, iota_range<int64_t>(0, inputTupleType.size(),
/*Inclusive=*/false));

// Reverse mapping: The fields added by the `expression` regions are now
// de-duplicated, so we need to reverse the effect of the deduplication,
// taking the prefix into account.
for (Value value : terminator->getOperands()) {
int64_t pos = valuePositions[value];
reverseMapping.push_back(inputTupleType.size() + pos);
}

// Remove duplicate values in `yield` op of the `expressions` region.
{
SmallVector<Value> values;
values.reserve(valuePositions.size());
for (auto [value, pos] : valuePositions)
values.push_back(value);

PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(terminator);
terminator = rewriter.replaceOpWithNewOp<YieldOp>(terminator, values);
}

// Compute deduplicated output field types.
SmallVector<Type> outputTypes;
int64_t numNewYields = terminator->getNumOperands();
outputTypes.reserve(inputTupleType.size() + numNewYields);
append_range(outputTypes, inputTupleType.getTypes());
append_range(outputTypes, terminator->getOperandTypes());
auto newOutputType = TupleType::get(context, outputTypes);

// Create new `project` op with updated region and output type.
auto newOp =
rewriter.create<ProjectOp>(op.getLoc(), newOutputType, op.getInput());
rewriter.inlineRegionBefore(op.getExpressions(), newOp.getExpressions(),
newOp.getExpressions().end());

// Create `emit` op with the reverse mapping.
ArrayAttr reverseMappingAttr = rewriter.getI64ArrayAttr(reverseMapping);
rewriter.replaceOpWithNewOp<EmitOp>(op, newOp, reverseMappingAttr);

return success();
}
};

/// Pushes duplicates in the mappings of `emit` ops producing either of the two
/// inputs through the `cross` op. This works by introducing new emit ops
/// without the duplicates, creating a new `cross` op that uses them, and
Expand Down Expand Up @@ -340,6 +410,7 @@ void populateEmitDeduplicationPatterns(RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<
// clang-format off
EliminateDuplicateYieldsInProjectPattern,
PushDuplicatesThroughCrossPattern,
PushDuplicatesThroughFilterPattern,
PushDuplicatesThroughProjectPattern
Expand Down
30 changes: 30 additions & 0 deletions test/Transforms/Substrait/emit-deduplication.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,33 @@ substrait.plan version 0 : 42 : 1 {
yield %2 : tuple<si32, si32, si1>
}
}

// -----

// `project` op (`EliminateDuplicateYieldsInProjectPattern`).

// CHECK-LABEL: substrait.plan
// CHECK-NEXT: relation
// CHECK-NEXT: %[[V0:.*]] = named_table
// CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : {{.*}} {
// CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: [[TYPE:.*]]):
// CHECK-NEXT: %[[V2:.*]] = field_reference %[[ARG0]]{{\[}}[0]] : [[TYPE]]
// CHECK-NEXT: %[[V3:.*]] = func.call @f(%[[V2]]) :
// CHECK-NEXT: yield %[[V3]] : si1
// CHECK-NEXT: }
// CHECK-NEXT: %[[V4:.*]] = emit [0, 1, 1] from %[[V1]]

func.func private @f(si32) -> si1

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = project %0 : tuple<si32> -> tuple<si32, si1, si1> {
^bb0(%arg : tuple<si32>):
%2 = field_reference %arg[[0]] : tuple<si32>
%3 = func.call @f(%2) : (si32) -> si1
yield %3, %3 : si1, si1
}
yield %1 : tuple<si32, si1, si1>
}
}

0 comments on commit ccddb3c

Please sign in to comment.