From 7421040b44c0ad43512a54af9acb0b8a6b9a7898 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Thu, 21 Sep 2023 15:35:05 -0700 Subject: [PATCH] [mlir] Move supplemental patterns before op replacement (#66959) This moves the C++ code generated from supplemental patterns before op replacement. It is necessary if the supllemental patterns need to access the source op. --- mlir/tools/mlir-tblgen/RewriterGen.cpp | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 131662d0cda67..78947b70f5cc2 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1173,9 +1173,22 @@ void PatternEmitter::emitRewriteLogic() { os << val << ";\n"; } + auto processSupplementalPatterns = [&]() { + int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); + for (int i = 0, offset = -numSupplementalPatterns; + i < numSupplementalPatterns; ++i) { + DagNode resultTree = pattern.getSupplementalPattern(i); + auto val = handleResultPattern(resultTree, offset++, 0); + if (resultTree.isNativeCodeCall() && + resultTree.getNumReturnsOfNativeCode() == 0) + os << val << ";\n"; + } + }; + if (numExpectedResults == 0) { assert(replStartIndex >= numResultPatterns && "invalid auxiliary vs. replacement pattern division!"); + processSupplementalPatterns(); // No result to replace. Just erase the op. os << "rewriter.eraseOp(op0);\n"; } else { @@ -1197,20 +1210,10 @@ void PatternEmitter::emitRewriteLogic() { " tblgen_repl_values.push_back(v);\n}\n", "\n"); } + processSupplementalPatterns(); os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; } - // Process supplemtal patterns. - int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); - for (int i = 0, offset = -numSupplementalPatterns; - i < numSupplementalPatterns; ++i) { - DagNode resultTree = pattern.getSupplementalPattern(i); - auto val = handleResultPattern(resultTree, offset++, 0); - if (resultTree.isNativeCodeCall() && - resultTree.getNumReturnsOfNativeCode() == 0) - os << val << ";\n"; - } - LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); }