Skip to content

Commit

Permalink
replace with multiple
Browse files Browse the repository at this point in the history
Apply suggestions from code review

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>

address comments

[WIP] 1:N conversion pattern

update test cases
  • Loading branch information
matthias-springer committed Nov 16, 2024
1 parent e872b86 commit 153310a
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 303 deletions.
35 changes: 31 additions & 4 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
Expand All @@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
rewriter);
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
Expand All @@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConvertToLLVMPattern::match;
Expand Down
63 changes: 63 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult
Expand All @@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
rewrite(op, operands, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
Expand Down Expand Up @@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}

/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
/// try to extract the single value of each range to construct a the inputs
/// for a 1:1 adaptor.
///
/// This function produces a fatal error if at least one range has 0 or
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
SmallVector<Value>
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;

protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
Expand All @@ -589,6 +607,8 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
Expand All @@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
Expand All @@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand All @@ -656,18 +701,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand Down
56 changes: 6 additions & 50 deletions mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,6 @@
using namespace mlir;
using namespace mlir::func;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

/// If the given value can be decomposed with the type converter, decompose it.
/// Otherwise, return the given value.
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
// This function will disappear when the 1:1 and 1:N drivers are merged.
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
Value value,
const TypeConverter *converter) {
// Try to convert the given value's type. If that fails, just return the
// given value.
SmallVector<Type> convertedTypes;
if (failed(converter->convertType(value.getType(), convertedTypes)))
return {value};
if (convertedTypes.empty())
return {};

// If the given value's type is already legal, just return the given value.
TypeRange convertedTypeRange(convertedTypes);
if (convertedTypeRange == TypeRange(value.getType()))
return {value};

// Try to materialize a target conversion. If the materialization did not
// produce values of the requested type, the materialization failed. Just
// return the given value in that case.
SmallVector<Value> result = converter->materializeTargetConversion(
builder, loc, convertedTypeRange, value);
if (result.empty())
return {value};
return result;
}

//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
for (Value operand : adaptor.getOperands()) {
// TODO: We can directly take the values from the adaptor once this is a
// 1:N conversion pattern.
llvm::append_range(newOperands,
decomposeValue(rewriter, operand.getLoc(), operand,
getTypeConverter()));
}
for (ValueRange operand : adaptor.getOperands())
llvm::append_range(newOperands, operand);
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
Expand All @@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(CallOp op, OpAdaptor adaptor,
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {

// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
for (Value operand : adaptor.getOperands()) {
// TODO: We can directly take the values from the adaptor once this is a
// 1:N conversion pattern.
llvm::append_range(newOperands,
decomposeValue(rewriter, operand.getLoc(), operand,
getTypeConverter()));
}
for (ValueRange operand : adaptor.getOperands())
llvm::append_range(newOperands, operand);

// Create the new result types for the new `CallOp` and track the number of
// replacement types for each original op result.
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {

/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
Expand All @@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
// Substitute with the new result types from the corresponding FuncType
// conversion.
rewriter.replaceOpWithNewOp<CallOp>(
callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
callOp, callOp.getCallee(), convertedResults,
getOneToOneAdaptorOperands(adaptor.getOperands()));
return success();
}
};
Expand Down
Loading

0 comments on commit 153310a

Please sign in to comment.