diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index c54770c7fbfaee..662d0b56665f32 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -46,6 +46,27 @@ using namespace Fortran::lower::omp; // Code generation helper functions //===----------------------------------------------------------------------===// +static bool evalHasSiblings(lower::pft::Evaluation &eval) { + auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) { + for (auto &sibling : siblings) + if (&sibling != &eval && !sibling.isEndStmt()) + return true; + + return false; + }; + + return eval.parent.visit(common::visitors{ + [&](const lower::pft::Program &parent) { + return parent.getUnits().size() + parent.getCommonBlocks().size() > 1; + }, + [&](const lower::pft::Evaluation &parent) { + return checkSiblings(*parent.evaluationList); + }, + [&](const auto &parent) { + return checkSiblings(parent.evaluationList); + }}); +} + static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) { mlir::Operation *parentOp = builder.getBlock()->getParentOp(); if (!parentOp) @@ -92,6 +113,38 @@ static void genNestedEvaluations(lower::AbstractConverter &converter, converter.genEval(e); } +static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval, + mlir::omp::TargetOp targetOp) { + if (!targetOp) + return false; + + auto offloadModOp = llvm::cast( + *targetOp->getParentOfType()); + if (offloadModOp.getIsTargetDevice()) + return false; + + auto dir = Fortran::common::visit( + common::visitors{ + [&](const parser::OpenMPBlockConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [&](const parser::OpenMPLoopConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [&](const auto &) { + llvm_unreachable("Unexpected OpenMP construct"); + return llvm::omp::OMPD_unknown; + }, + }, + eval.get().u); + + return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval); +} + //===----------------------------------------------------------------------===// // HostClausesInsertionGuard //===----------------------------------------------------------------------===// @@ -425,27 +478,6 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter, return storeOp; } -static bool evalHasSiblings(lower::pft::Evaluation &eval) { - return eval.parent.visit(common::visitors{ - [&](const lower::pft::Program &parent) { - return parent.getUnits().size() + parent.getCommonBlocks().size() > 1; - }, - [&](const lower::pft::Evaluation &parent) { - for (auto &sibling : *parent.evaluationList) - if (&sibling != &eval && !sibling.isEndStmt()) - return true; - - return false; - }, - [&](const auto &parent) { - for (auto &sibling : parent.evaluationList) - if (&sibling != &eval && !sibling.isEndStmt()) - return true; - - return false; - }}); -} - // This helper function implements the functionality of "promoting" // non-CPTR arguments of use_device_ptr to use_device_addr // arguments (automagic conversion of use_device_ptr -> @@ -1562,8 +1594,6 @@ genTeamsClauses(lower::AbstractConverter &converter, cp.processAllocate(clauseOps); cp.processDefault(); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - cp.processNumTeams(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); // TODO Support delayed privatization. // Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside @@ -2304,17 +2334,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::iterator item) { lower::StatementContext stmtCtx; - auto offloadModOp = llvm::cast( - converter.getModuleOp().getOperation()); mlir::omp::TargetOp targetOp = findParentTargetOp(converter.getFirOpBuilder()); - bool mustEvalOutsideTarget = targetOp && !offloadModOp.getIsTargetDevice(); + bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp); mlir::omp::TeamsOperands clauseOps; mlir::omp::NumTeamsClauseOps numTeamsClauseOps; mlir::omp::ThreadLimitClauseOps threadLimitClauseOps; genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - mustEvalOutsideTarget, clauseOps, numTeamsClauseOps, + evalOutsideTarget, clauseOps, numTeamsClauseOps, threadLimitClauseOps); auto teamsOp = genOpWithBody( @@ -2324,7 +2352,7 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); if (numTeamsClauseOps.numTeamsUpper) { - if (mustEvalOutsideTarget) + if (evalOutsideTarget) targetOp.getNumTeamsUpperMutable().assign( numTeamsClauseOps.numTeamsUpper); else @@ -2332,7 +2360,7 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, } if (threadLimitClauseOps.threadLimit) { - if (mustEvalOutsideTarget) + if (evalOutsideTarget) targetOp.getTeamsThreadLimitMutable().assign( threadLimitClauseOps.threadLimit); else @@ -2412,12 +2440,9 @@ static void genStandaloneParallel(lower::AbstractConverter &converter, ConstructQueue::iterator item) { lower::StatementContext stmtCtx; - auto offloadModOp = - llvm::cast(*converter.getModuleOp()); mlir::omp::TargetOp targetOp = findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = - targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval); + bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp); mlir::omp::ParallelOperands parallelClauseOps; mlir::omp::NumThreadsClauseOps numThreadsClauseOps; @@ -2476,12 +2501,9 @@ static void genCompositeDistributeParallelDo( ConstructQueue::iterator item, DataSharingProcessor &dsp) { lower::StatementContext stmtCtx; - auto offloadModOp = - llvm::cast(*converter.getModuleOp()); mlir::omp::TargetOp targetOp = findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = - targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval); + bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2493,9 +2515,8 @@ static void genCompositeDistributeParallelDo( llvm::SmallVector parallelReductionSyms; llvm::SmallVector parallelReductionTypes; genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - /*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps, - numThreadsClauseOps, parallelReductionTypes, - parallelReductionSyms); + evalOutsideTarget, parallelClauseOps, numThreadsClauseOps, + parallelReductionTypes, parallelReductionSyms); const auto &privateClauseOps = dsp.getPrivateClauseOps(); parallelClauseOps.privateVars = privateClauseOps.privateVars; @@ -2551,12 +2572,9 @@ static void genCompositeDistributeParallelDoSimd( ConstructQueue::iterator item, DataSharingProcessor &dsp) { lower::StatementContext stmtCtx; - auto offloadModOp = - llvm::cast(*converter.getModuleOp()); mlir::omp::TargetOp targetOp = findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = - targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval); + bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2568,9 +2586,8 @@ static void genCompositeDistributeParallelDoSimd( llvm::SmallVector parallelReductionSyms; llvm::SmallVector parallelReductionTypes; genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - /*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps, - numThreadsClauseOps, parallelReductionTypes, - parallelReductionSyms); + evalOutsideTarget, parallelClauseOps, numThreadsClauseOps, + parallelReductionTypes, parallelReductionSyms); const auto &privateClauseOps = dsp.getPrivateClauseOps(); parallelClauseOps.privateVars = privateClauseOps.privateVars; diff --git a/flang/test/Lower/OpenMP/eval-outside-target.f90 b/flang/test/Lower/OpenMP/eval-outside-target.f90 new file mode 100644 index 00000000000000..5d4a8a104c8952 --- /dev/null +++ b/flang/test/Lower/OpenMP/eval-outside-target.f90 @@ -0,0 +1,200 @@ +! The "thread_limit" clause was added to the "target" construct in OpenMP 5.1. +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s --check-prefixes=BOTH,HOST +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE + +! CHECK-LABEL: func.func @_QPteams +subroutine teams() + ! BOTH: omp.target + + ! HOST-SAME: num_teams({{.*}}) + ! HOST-SAME: teams_thread_limit({{.*}}) + + ! DEVICE-NOT: num_teams({{.*}}) + ! DEVICE-NOT: teams_thread_limit({{.*}}) + ! DEVICE-SAME: { + !$omp target + + ! BOTH: omp.teams + + ! HOST-NOT: num_teams({{.*}}) + ! HOST-NOT: thread_limit({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_teams({{.*}}) + ! DEVICE-SAME: thread_limit({{.*}}) + !$omp teams num_teams(1) thread_limit(2) + call foo() + !$omp end teams + + !$omp end target + + ! BOTH: omp.teams + ! BOTH-SAME: num_teams({{.*}}) + ! BOTH-SAME: thread_limit({{.*}}) + !$omp teams num_teams(1) thread_limit(2) + call foo() + !$omp end teams +end subroutine teams + +subroutine parallel() + ! BOTH: omp.target + + ! HOST-SAME: num_threads({{.*}}) + + ! DEVICE-NOT: num_threads({{.*}}) + ! DEVICE-SAME: { + !$omp target + + ! BOTH: omp.parallel + + ! HOST-NOT: num_threads({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_threads({{.*}}) + !$omp parallel num_threads(1) + call foo() + !$omp end parallel + !$omp end target + + ! BOTH: omp.target + ! BOTH-NOT: num_threads({{.*}}) + ! BOTH-SAME: { + !$omp target + call foo() + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + !$omp parallel num_threads(1) + call foo() + !$omp end parallel + !$omp end target + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + !$omp parallel num_threads(1) + call foo() + !$omp end parallel +end subroutine parallel + +subroutine distribute_parallel_do() + ! BOTH: omp.target + + ! HOST-SAME: num_threads({{.*}}) + + ! DEVICE-NOT: num_threads({{.*}}) + ! DEVICE-SAME: { + + ! BOTH: omp.teams + !$omp target teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + + ! HOST-NOT: num_threads({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_threads({{.*}}) + + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end target teams + + ! BOTH: omp.target + ! BOTH-NOT: num_threads({{.*}}) + ! BOTH-SAME: { + ! BOTH: omp.teams + !$omp target teams + call foo() + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end target teams + + ! BOTH: omp.teams + !$omp teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end teams +end subroutine distribute_parallel_do + +subroutine distribute_parallel_do_simd() + ! BOTH: omp.target + + ! HOST-SAME: num_threads({{.*}}) + + ! DEVICE-NOT: num_threads({{.*}}) + ! DEVICE-SAME: { + + ! BOTH: omp.teams + !$omp target teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + + ! HOST-NOT: num_threads({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_threads({{.*}}) + + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end target teams + + ! BOTH: omp.target + ! BOTH-NOT: num_threads({{.*}}) + ! BOTH-SAME: { + ! BOTH: omp.teams + !$omp target teams + call foo() + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end target teams + + ! BOTH: omp.teams + !$omp teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end teams +end subroutine distribute_parallel_do_simd diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 2b4212ce7e7695..5f6007561967e4 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1506,24 +1506,13 @@ static LogicalResult verifyNumTeamsClause(Operation *op, Value lb, Value ub) { return success(); } -template -static OpTy getSingleNestedOpOfType(Region ®ion) { - auto ops = region.getOps(); - return std::distance(ops.begin(), ops.end()) != 1 ? OpTy() : *ops.begin(); -} - LogicalResult TargetOp::verify() { auto teamsOps = getOps(); if (std::distance(teamsOps.begin(), teamsOps.end()) > 1) return emitError("target containing multiple teams constructs"); - if (!isTargetSPMDLoop()) { - if (getTripCount()) - return emitError("trip_count set on non-SPMD target region"); - - if (getNumThreads() && !getSingleNestedOpOfType(getRegion())) - return emitError("num_threads set on non-SPMD or loop target region"); - } + if (!isTargetSPMDLoop() && getTripCount()) + return emitError("trip_count set on non-SPMD target region"); if (teamsOps.empty()) { if (getNumTeamsLower() || getNumTeamsUpper() || getTeamsThreadLimit()) @@ -1721,17 +1710,6 @@ LogicalResult ParallelOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - auto offloadModOp = - llvm::cast(*(*this)->getParentOfType()); - if (!offloadModOp.getIsTargetDevice()) { - auto targetOp = (*this)->getParentOfType(); - if (getNumThreads() && targetOp && - (targetOp.isTargetSPMDLoop() || - getSingleNestedOpOfType(targetOp.getRegion()) == *this)) - return emitError("num_threads argument expected to be attached to parent " - "omp.target operation instead"); - } - if (failed(verifyPrivateVarList(*this))) return failure();