From ed2a10bf0a73f2b4a25b06c19b1477106af49fe6 Mon Sep 17 00:00:00 2001 From: Dominik Adamski Date: Wed, 31 Jul 2024 15:45:16 +0200 Subject: [PATCH] [Flang][OpenMP] Erase trip count for generic kernels (#125) We determine the type of the kernel based on the MLIR code, which changes during the lowering phase. Some kernels, such as those with multiple workshare loops, are initially classified as SPMD kernels, but are later recognized as generic kernels during PFT lowering. In such cases, we need to identify the change in type and clear the trip count if it was previously set. --- flang/lib/Lower/OpenMP/OpenMP.cpp | 33 ++++++++++++++++++------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 73ab9fe9ea2c20..9cefc54fee4597 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1681,21 +1681,26 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, firOpBuilder.getModule().getOperation()); auto targetOp = loopNestOp->getParentOfType(); - if (offloadMod && targetOp && !offloadMod.getIsTargetDevice() && - targetOp.isTargetSPMDLoop()) { - // Lower loop bounds and step, and process collapsing again, putting lowered - // values outside of omp.target this time. This enables calculating and - // accessing the trip count in the host, which is needed when lowering to - // LLVM IR via the OMPIRBuilder. - HostClausesInsertionGuard guard(firOpBuilder); - mlir::omp::CollapseClauseOps collapseClauseOps; - llvm::SmallVector iv; - ClauseProcessor cp(converter, semaCtx, item->clauses); - cp.processCollapse(loc, eval, collapseClauseOps, iv); - targetOp.getTripCountMutable().assign(calculateTripCount( - converter.getFirOpBuilder(), loc, collapseClauseOps)); + if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) { + if (targetOp.isTargetSPMDLoop()) { + // Lower loop bounds and step, and process collapsing again, putting + // lowered values outside of omp.target this time. This enables + // calculating and accessing the trip count in the host, which is needed + // when lowering to LLVM IR via the OMPIRBuilder. + HostClausesInsertionGuard guard(firOpBuilder); + mlir::omp::CollapseClauseOps collapseClauseOps; + llvm::SmallVector iv; + ClauseProcessor cp(converter, semaCtx, item->clauses); + cp.processCollapse(loc, eval, collapseClauseOps, iv); + targetOp.getTripCountMutable().assign(calculateTripCount( + converter.getFirOpBuilder(), loc, collapseClauseOps)); + } else if (targetOp.getTripCountMutable().size()) { + // The MLIR target operation was updated during PFT lowering, + // and it is no longer an SPMD kernel. Erase the trip count because + // as it is now invalid. + targetOp.getTripCountMutable().erase(0); + } } - return loopNestOp; }