diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 591983bc7ec0de..662d0b56665f32 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -47,23 +47,23 @@ using namespace Fortran::lower::omp; //===----------------------------------------------------------------------===// static bool evalHasSiblings(lower::pft::Evaluation &eval) { + auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) { + for (auto &sibling : siblings) + if (&sibling != &eval && !sibling.isEndStmt()) + return true; + + return false; + }; + return eval.parent.visit(common::visitors{ [&](const lower::pft::Program &parent) { return parent.getUnits().size() + parent.getCommonBlocks().size() > 1; }, [&](const lower::pft::Evaluation &parent) { - for (auto &sibling : *parent.evaluationList) - if (&sibling != &eval && !sibling.isEndStmt()) - return true; - - return false; + return checkSiblings(*parent.evaluationList); }, [&](const auto &parent) { - for (auto &sibling : parent.evaluationList) - if (&sibling != &eval && !sibling.isEndStmt()) - return true; - - return false; + return checkSiblings(parent.evaluationList); }}); } diff --git a/flang/test/Lower/OpenMP/eval-outside-target.f90 b/flang/test/Lower/OpenMP/eval-outside-target.f90 new file mode 100644 index 00000000000000..5d4a8a104c8952 --- /dev/null +++ b/flang/test/Lower/OpenMP/eval-outside-target.f90 @@ -0,0 +1,200 @@ +! The "thread_limit" clause was added to the "target" construct in OpenMP 5.1. +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s --check-prefixes=BOTH,HOST +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE + +! CHECK-LABEL: func.func @_QPteams +subroutine teams() + ! BOTH: omp.target + + ! HOST-SAME: num_teams({{.*}}) + ! HOST-SAME: teams_thread_limit({{.*}}) + + ! DEVICE-NOT: num_teams({{.*}}) + ! DEVICE-NOT: teams_thread_limit({{.*}}) + ! DEVICE-SAME: { + !$omp target + + ! BOTH: omp.teams + + ! HOST-NOT: num_teams({{.*}}) + ! HOST-NOT: thread_limit({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_teams({{.*}}) + ! DEVICE-SAME: thread_limit({{.*}}) + !$omp teams num_teams(1) thread_limit(2) + call foo() + !$omp end teams + + !$omp end target + + ! BOTH: omp.teams + ! BOTH-SAME: num_teams({{.*}}) + ! BOTH-SAME: thread_limit({{.*}}) + !$omp teams num_teams(1) thread_limit(2) + call foo() + !$omp end teams +end subroutine teams + +subroutine parallel() + ! BOTH: omp.target + + ! HOST-SAME: num_threads({{.*}}) + + ! DEVICE-NOT: num_threads({{.*}}) + ! DEVICE-SAME: { + !$omp target + + ! BOTH: omp.parallel + + ! HOST-NOT: num_threads({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_threads({{.*}}) + !$omp parallel num_threads(1) + call foo() + !$omp end parallel + !$omp end target + + ! BOTH: omp.target + ! BOTH-NOT: num_threads({{.*}}) + ! BOTH-SAME: { + !$omp target + call foo() + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + !$omp parallel num_threads(1) + call foo() + !$omp end parallel + !$omp end target + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + !$omp parallel num_threads(1) + call foo() + !$omp end parallel +end subroutine parallel + +subroutine distribute_parallel_do() + ! BOTH: omp.target + + ! HOST-SAME: num_threads({{.*}}) + + ! DEVICE-NOT: num_threads({{.*}}) + ! DEVICE-SAME: { + + ! BOTH: omp.teams + !$omp target teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + + ! HOST-NOT: num_threads({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_threads({{.*}}) + + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end target teams + + ! BOTH: omp.target + ! BOTH-NOT: num_threads({{.*}}) + ! BOTH-SAME: { + ! BOTH: omp.teams + !$omp target teams + call foo() + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end target teams + + ! BOTH: omp.teams + !$omp teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end teams +end subroutine distribute_parallel_do + +subroutine distribute_parallel_do_simd() + ! BOTH: omp.target + + ! HOST-SAME: num_threads({{.*}}) + + ! DEVICE-NOT: num_threads({{.*}}) + ! DEVICE-SAME: { + + ! BOTH: omp.teams + !$omp target teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + + ! HOST-NOT: num_threads({{.*}}) + ! HOST-SAME: { + + ! DEVICE-SAME: num_threads({{.*}}) + + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end target teams + + ! BOTH: omp.target + ! BOTH-NOT: num_threads({{.*}}) + ! BOTH-SAME: { + ! BOTH: omp.teams + !$omp target teams + call foo() + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end target teams + + ! BOTH: omp.teams + !$omp teams + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end teams +end subroutine distribute_parallel_do_simd diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 7e6e22b7f88104..5f6007561967e4 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1506,24 +1506,13 @@ static LogicalResult verifyNumTeamsClause(Operation *op, Value lb, Value ub) { return success(); } -template -static OpTy getSingleNestedOpOfType(Region ®ion) { - auto ops = region.getOps(); - return std::distance(ops.begin(), ops.end()) != 1 ? OpTy() : *ops.begin(); -} - LogicalResult TargetOp::verify() { auto teamsOps = getOps(); if (std::distance(teamsOps.begin(), teamsOps.end()) > 1) return emitError("target containing multiple teams constructs"); - if (!isTargetSPMDLoop()) { - if (getTripCount()) - return emitError("trip_count set on non-SPMD target region"); - - if (getNumThreads() && !getSingleNestedOpOfType(getRegion())) - return emitError("num_threads set on non-SPMD or loop target region"); - } + if (!isTargetSPMDLoop() && getTripCount()) + return emitError("trip_count set on non-SPMD target region"); if (teamsOps.empty()) { if (getNumTeamsLower() || getNumTeamsUpper() || getTeamsThreadLimit()) @@ -1721,17 +1710,6 @@ LogicalResult ParallelOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - auto offloadModOp = - llvm::cast(*(*this)->getParentOfType()); - if (!offloadModOp.getIsTargetDevice()) { - auto targetOp = (*this)->getParentOfType(); - if (getNumThreads() && targetOp && - (targetOp.isTargetSPMDLoop() || - getSingleNestedOpOfType(targetOp.getRegion()) == *this)) - return emitError("num_threads argument expected to be attached to parent " - "omp.target operation instead"); - } - if (failed(verifyPrivateVarList(*this))) return failure(); @@ -1777,10 +1755,7 @@ LogicalResult TeamsOp::verify() { auto offloadModOp = llvm::cast(*(*this)->getParentOfType()); if (targetOp && !offloadModOp.getIsTargetDevice()) { - // Only disallow num_teams and thread_limit if this is the only omp.teams - // inside the target region. - if (getSingleNestedOpOfType(targetOp.getRegion()) == *this && - (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit())) + if (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit()) return emitError("num_teams and thread_limit arguments expected to be " "attached to parent omp.target operation"); } else {