From 18926666f509104c3f478444b282291ce19fab6a Mon Sep 17 00:00:00 2001 From: SJW <48454132+sjw36@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:46:18 -0500 Subject: [PATCH] [MLIR][SCF] Loop pipelining fails on failed predication (no assert) (#107442) The SCFLoopPipelining allows predication on peeled or loop ops. When the predicationFn returns a nullptr this signifies the op type is unsupported and the pipeliner fails except in `emitPrologue` where it asserts. This patch fixes handling in the prologue to gracefully fail. --- mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index a34542f0161aca..7cecd4942b640f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -77,7 +77,7 @@ struct LoopPipelinerInternal { bool initializeLoopInfo(ForOp op, const PipeliningOption &options); /// Emits the prologue, this creates `maxStage - 1` part which will contain /// operations from stages [0; i], where i is the part index. - void emitPrologue(RewriterBase &rewriter); + LogicalResult emitPrologue(RewriterBase &rewriter); /// Gather liverange information for Values that are used in a different stage /// than its definition. llvm::MapVector analyzeCrossStageValues(); @@ -263,7 +263,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, return clone; } -void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { +LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { // Initialize the iteration argument to the loop initial values. for (auto [arg, operand] : llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { @@ -311,7 +311,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { if (predicates[predicateIdx]) { OpBuilder::InsertionGuard insertGuard(rewriter); newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); - assert(newOp && "failed to predicate op."); + if (newOp == nullptr) + return failure(); } if (annotateFn) annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); @@ -339,6 +340,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { } } } + return success(); } llvm::MapVector @@ -772,7 +774,8 @@ FailureOr mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, *modifiedIR = true; // 1. Emit prologue. - pipeliner.emitPrologue(rewriter); + if (failed(pipeliner.emitPrologue(rewriter))) + return failure(); // 2. Track values used across stages. When a value cross stages it will // need to be passed as loop iteration arguments.