diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index fc69c6b870adf3b..fc503c7866e3b23 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -46,6 +46,27 @@ using namespace Fortran::lower::omp;
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
+static bool evalHasSiblings(lower::pft::Evaluation &eval) {
+  return eval.parent.visit(common::visitors{
+      [&](const lower::pft::Program &parent) {
+        return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
+      },
+      [&](const lower::pft::Evaluation &parent) {
+        for (auto &sibling : *parent.evaluationList)
+          if (&sibling != &eval && !sibling.isEndStmt())
+            return true;
+
+        return false;
+      },
+      [&](const auto &parent) {
+        for (auto &sibling : parent.evaluationList)
+          if (&sibling != &eval && !sibling.isEndStmt())
+            return true;
+
+        return false;
+      }});
+}
+
 static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
   mlir::Operation *parentOp = builder.getBlock()->getParentOp();
   if (!parentOp)
@@ -92,6 +113,38 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
     converter.genEval(e);
 }
 
+static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
+                                              mlir::omp::TargetOp targetOp) {
+  if (!targetOp)
+    return false;
+
+  auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
+      *targetOp->getParentOfType<mlir::ModuleOp>());
+  if (offloadModOp.getIsTargetDevice())
+    return false;
+
+  auto dir = Fortran::common::visit(
+      common::visitors{
+          [&](const parser::OpenMPBlockConstruct &c) {
+            return std::get<parser::OmpBlockDirective>(
+                       std::get<parser::OmpBeginBlockDirective>(c.t).t)
+                .v;
+          },
+          [&](const parser::OpenMPLoopConstruct &c) {
+            return std::get<parser::OmpLoopDirective>(
+                       std::get<parser::OmpBeginLoopDirective>(c.t).t)
+                .v;
+          },
+          [&](const auto &) {
+            llvm_unreachable("Unexpected OpenMP construct");
+            return llvm::omp::OMPD_unknown;
+          },
+      },
+      eval.get<parser::OpenMPConstruct>().u);
+
+  return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval);
+}
+
 //===----------------------------------------------------------------------===//
 // HostClausesInsertionGuard
 //===----------------------------------------------------------------------===//
@@ -412,27 +465,6 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter,
   return storeOp;
 }
 
-static bool evalHasSiblings(lower::pft::Evaluation &eval) {
-  return eval.parent.visit(common::visitors{
-      [&](const lower::pft::Program &parent) {
-        return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
-      },
-      [&](const lower::pft::Evaluation &parent) {
-        for (auto &sibling : *parent.evaluationList)
-          if (&sibling != &eval && !sibling.isEndStmt())
-            return true;
-
-        return false;
-      },
-      [&](const auto &parent) {
-        for (auto &sibling : parent.evaluationList)
-          if (&sibling != &eval && !sibling.isEndStmt())
-            return true;
-
-        return false;
-      }});
-}
-
 // This helper function implements the functionality of "promoting"
 // non-CPTR arguments of use_device_ptr to use_device_addr
 // arguments (automagic conversion of use_device_ptr ->
@@ -1549,8 +1581,6 @@ genTeamsClauses(lower::AbstractConverter &converter,
   cp.processAllocate(clauseOps);
   cp.processDefault();
   cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
-  cp.processNumTeams(stmtCtx, clauseOps);
-  cp.processThreadLimit(stmtCtx, clauseOps);
 
   // TODO Support delayed privatization.
   // Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside
@@ -2291,17 +2321,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
            ConstructQueue::iterator item) {
   lower::StatementContext stmtCtx;
 
-  auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
-      converter.getModuleOp().getOperation());
   mlir::omp::TargetOp targetOp =
       findParentTargetOp(converter.getFirOpBuilder());
-  bool mustEvalOutsideTarget = targetOp && !offloadModOp.getIsTargetDevice();
+  bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);
 
   mlir::omp::TeamsOperands clauseOps;
   mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
   mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
   genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
-                  mustEvalOutsideTarget, clauseOps, numTeamsClauseOps,
+                  evalOutsideTarget, clauseOps, numTeamsClauseOps,
                   threadLimitClauseOps);
 
   auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
@@ -2311,7 +2339,7 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       queue, item, clauseOps);
 
   if (numTeamsClauseOps.numTeamsUpper) {
-    if (mustEvalOutsideTarget)
+    if (evalOutsideTarget)
       targetOp.getNumTeamsUpperMutable().assign(
           numTeamsClauseOps.numTeamsUpper);
     else
@@ -2319,7 +2347,7 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   }
 
   if (threadLimitClauseOps.threadLimit) {
-    if (mustEvalOutsideTarget)
+    if (evalOutsideTarget)
       targetOp.getTeamsThreadLimitMutable().assign(
           threadLimitClauseOps.threadLimit);
     else
@@ -2399,12 +2427,9 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
                                   ConstructQueue::iterator item) {
   lower::StatementContext stmtCtx;
 
-  auto offloadModOp =
-      llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
   mlir::omp::TargetOp targetOp =
       findParentTargetOp(converter.getFirOpBuilder());
-  bool evalOutsideTarget =
-      targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
+  bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);
 
   mlir::omp::ParallelOperands
       parallelClauseOps;
   mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
@@ -2463,12 +2488,9 @@ static void genCompositeDistributeParallelDo(
     ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   lower::StatementContext stmtCtx;
 
-  auto offloadModOp =
-      llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
   mlir::omp::TargetOp targetOp =
       findParentTargetOp(converter.getFirOpBuilder());
-  bool evalOutsideTarget =
-      targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
+  bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2480,9 +2502,8 @@ static void genCompositeDistributeParallelDo(
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
   llvm::SmallVector<mlir::Type> parallelReductionTypes;
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
-                     /*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
-                     numThreadsClauseOps, parallelReductionTypes,
-                     parallelReductionSyms);
+                     evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
+                     parallelReductionTypes, parallelReductionSyms);
 
   const auto &privateClauseOps = dsp.getPrivateClauseOps();
   parallelClauseOps.privateVars = privateClauseOps.privateVars;
@@ -2538,12 +2559,9 @@ static void genCompositeDistributeParallelDoSimd(
     ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   lower::StatementContext stmtCtx;
 
-  auto offloadModOp =
-      llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
   mlir::omp::TargetOp targetOp =
       findParentTargetOp(converter.getFirOpBuilder());
-  bool evalOutsideTarget =
-      targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
+  bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2555,9 +2573,8 @@ static void genCompositeDistributeParallelDoSimd(
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
   llvm::SmallVector<mlir::Type> parallelReductionTypes;
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
-                     /*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
-                     numThreadsClauseOps, parallelReductionTypes,
-                     parallelReductionSyms);
+                     evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
+                     parallelReductionTypes, parallelReductionSyms);
 
   const auto &privateClauseOps = dsp.getPrivateClauseOps();
   parallelClauseOps.privateVars = privateClauseOps.privateVars;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2b4212ce7e76950..7e6e22b7f881042 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1777,7 +1777,10 @@ LogicalResult TeamsOp::verify() {
   auto offloadModOp =
       llvm::cast<OffloadModuleInterface>(*(*this)->getParentOfType<ModuleOp>());
   if (targetOp && !offloadModOp.getIsTargetDevice()) {
-    if (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit())
+    // Only disallow num_teams and thread_limit if this is the only omp.teams
+    // inside the target region.
+    if (getSingleNestedOpOfType<TeamsOp>(targetOp.getRegion()) == *this &&
+        (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit()))
      return emitError("num_teams and thread_limit arguments expected to be "
                       "attached to parent omp.target operation");
   } else {