diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 291e282c8b0b55..c0d02dd4c5edb2 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1678,21 +1678,26 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, firOpBuilder.getModule().getOperation()); auto targetOp = loopNestOp->getParentOfType(); - if (offloadMod && targetOp && !offloadMod.getIsTargetDevice() && - targetOp.isTargetSPMDLoop()) { - // Lower loop bounds and step, and process collapsing again, putting lowered - // values outside of omp.target this time. This enables calculating and - // accessing the trip count in the host, which is needed when lowering to - // LLVM IR via the OMPIRBuilder. - HostClausesInsertionGuard guard(firOpBuilder); - mlir::omp::CollapseClauseOps collapseClauseOps; - llvm::SmallVector iv; - ClauseProcessor cp(converter, semaCtx, item->clauses); - cp.processCollapse(loc, eval, collapseClauseOps, iv); - targetOp.getTripCountMutable().assign(calculateTripCount( - converter.getFirOpBuilder(), loc, collapseClauseOps)); + if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) { + if (targetOp.isTargetSPMDLoop()) { + // Lower loop bounds and step, and process collapsing again, putting + // lowered values outside of omp.target this time. This enables + // calculating and accessing the trip count in the host, which is needed + // when lowering to LLVM IR via the OMPIRBuilder. + HostClausesInsertionGuard guard(firOpBuilder); + mlir::omp::CollapseClauseOps collapseClauseOps; + llvm::SmallVector iv; + ClauseProcessor cp(converter, semaCtx, item->clauses); + cp.processCollapse(loc, eval, collapseClauseOps, iv); + targetOp.getTripCountMutable().assign(calculateTripCount( + converter.getFirOpBuilder(), loc, collapseClauseOps)); + } else if (targetOp.getTripCountMutable().size()) { + // The MLIR target operation was updated during PFT lowering, + // and it is no longer an SPMD kernel. Erase the trip count because + // as it is now invalid. + targetOp.getTripCountMutable().erase(0); + } } - return loopNestOp; }