Skip to content

Commit

Permalink
[Flang][OpenMP] Erase trip count for generic kernels (#125)
Browse files Browse the repository at this point in the history
We determine the type of the kernel based on the MLIR code,
which changes during the lowering phase. Some kernels,
such as those with multiple workshare loops, are initially
classified as SPMD kernels, but are later recognized as generic
kernels during PFT lowering. In such cases,
we need to identify the change in type and clear the trip count
if it was previously set.
  • Loading branch information
DominikAdamski authored Jul 31, 2024
1 parent b3b35ea commit ed2a10b
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1681,21 +1681,26 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
firOpBuilder.getModule().getOperation());
auto targetOp = loopNestOp->getParentOfType<mlir::omp::TargetOp>();

if (offloadMod && targetOp && !offloadMod.getIsTargetDevice() &&
targetOp.isTargetSPMDLoop()) {
// Lower loop bounds and step, and process collapsing again, putting lowered
// values outside of omp.target this time. This enables calculating and
// accessing the trip count in the host, which is needed when lowering to
// LLVM IR via the OMPIRBuilder.
HostClausesInsertionGuard guard(firOpBuilder);
mlir::omp::CollapseClauseOps collapseClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processCollapse(loc, eval, collapseClauseOps, iv);
targetOp.getTripCountMutable().assign(calculateTripCount(
converter.getFirOpBuilder(), loc, collapseClauseOps));
if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) {
if (targetOp.isTargetSPMDLoop()) {
// Lower loop bounds and step, and process collapsing again, putting
// lowered values outside of omp.target this time. This enables
// calculating and accessing the trip count in the host, which is needed
// when lowering to LLVM IR via the OMPIRBuilder.
HostClausesInsertionGuard guard(firOpBuilder);
mlir::omp::CollapseClauseOps collapseClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processCollapse(loc, eval, collapseClauseOps, iv);
targetOp.getTripCountMutable().assign(calculateTripCount(
converter.getFirOpBuilder(), loc, collapseClauseOps));
} else if (targetOp.getTripCountMutable().size()) {
// The MLIR target operation was updated during PFT lowering,
// and it is no longer an SPMD kernel. Erase the trip count because
// as it is now invalid.
targetOp.getTripCountMutable().erase(0);
}
}

return loopNestOp;
}

Expand Down

0 comments on commit ed2a10b

Please sign in to comment.