Skip to content

Commit

Permalink
[Flang][MLIR][OpenMP] Fix num_teams, num_threads, thread_limit lowering
Browse files Browse the repository at this point in the history
This patch fixes lowering of the num_teams, num_threads and thread_limit
clauses when they appear inside of a target region and the compilation targets
the host device.

The current approach requires these to be attached to the parent MLIR
omp.target operation. However, some incorrect checks based on the
`evalHasSiblings()` helper function would result in these clauses being
attached to the `omp.teams` or `omp.parallel` operation instead, triggering
a verifier error.

In this patch, these checks are updated so that they no longer break when
lowering combined `target teams [X]` constructs. Also, the `genTeamsClauses()`
function is fixed to avoid processing the num_teams and thread_limit clauses
twice, a duplication that likely resulted from a recent merge.
  • Loading branch information
skatrak committed Aug 5, 2024
1 parent 0058159 commit 34956b0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 48 deletions.
111 changes: 64 additions & 47 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//

// Return true if \p eval has at least one sibling evaluation, ignoring end
// statements. A construct with no siblings is the only statement nested
// inside of its parent.
static bool evalHasSiblings(lower::pft::Evaluation &eval) {
  // A sibling is any evaluation in the parent's list other than \p eval
  // itself, skipping end statements.
  auto isSibling = [&eval](const lower::pft::Evaluation &e) {
    return &e != &eval && !e.isEndStmt();
  };

  return eval.parent.visit(common::visitors{
      [&](const lower::pft::Program &parent) {
        // At the program level, any other unit or common block counts as a
        // sibling.
        return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
      },
      [&](const lower::pft::Evaluation &parent) {
        // Evaluation parents expose their children through indirection.
        for (auto &child : *parent.evaluationList)
          if (isSibling(child))
            return true;

        return false;
      },
      [&](const auto &parent) {
        // Remaining parent kinds own the evaluation list directly.
        for (auto &child : parent.evaluationList)
          if (isSibling(child))
            return true;

        return false;
      }});
}

static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
mlir::Operation *parentOp = builder.getBlock()->getParentOp();
if (!parentOp)
Expand Down Expand Up @@ -92,6 +113,38 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
converter.genEval(e);
}

// Decide whether the num_teams, num_threads and thread_limit clauses of the
// teams/parallel construct represented by \p eval, nested inside of
// \p targetOp, must be evaluated outside of the target region and attached to
// the parent omp.target operation instead.
//
// Returns true only when compiling for the host device (not a target device)
// and either the construct is part of a combined `target ...` construct or it
// is the only evaluation nested inside the target region.
static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
                                              mlir::omp::TargetOp targetOp) {
  // Not enclosed in a target region: clauses stay on their own operation.
  if (!targetOp)
    return false;

  // When generating code for the target device, clauses are never hoisted.
  auto offloadIface = llvm::cast<mlir::omp::OffloadModuleInterface>(
      *targetOp->getParentOfType<mlir::ModuleOp>());
  if (offloadIface.getIsTargetDevice())
    return false;

  // Extract the directive from the block or loop construct being lowered.
  auto leafDir = Fortran::common::visit(
      common::visitors{
          [&](const parser::OpenMPBlockConstruct &blockConstruct) {
            return std::get<parser::OmpBlockDirective>(
                       std::get<parser::OmpBeginBlockDirective>(
                           blockConstruct.t)
                           .t)
                .v;
          },
          [&](const parser::OpenMPLoopConstruct &loopConstruct) {
            return std::get<parser::OmpLoopDirective>(
                       std::get<parser::OmpBeginLoopDirective>(loopConstruct.t)
                           .t)
                .v;
          },
          [&](const auto &) {
            llvm_unreachable("Unexpected OpenMP construct");
            return llvm::omp::OMPD_unknown;
          },
      },
      eval.get<parser::OpenMPConstruct>().u);

  // Combined `target ...` constructs always hoist these clauses; otherwise
  // only do so when the construct is alone inside the target region.
  return llvm::omp::allTargetSet.test(leafDir) || !evalHasSiblings(eval);
}

//===----------------------------------------------------------------------===//
// HostClausesInsertionGuard
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -412,27 +465,6 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter,
return storeOp;
}

// Return true if \p eval has at least one sibling evaluation, ignoring end
// statements. A construct with no siblings is the only statement nested
// inside of its parent.
static bool evalHasSiblings(lower::pft::Evaluation &eval) {
// Dispatch on the dynamic type of the parent node.
return eval.parent.visit(common::visitors{
[&](const lower::pft::Program &parent) {
// At the program level, any other unit or common block counts as a sibling.
return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
},
[&](const lower::pft::Evaluation &parent) {
// Evaluation parents expose their children through indirection.
for (auto &sibling : *parent.evaluationList)
if (&sibling != &eval && !sibling.isEndStmt())
return true;

return false;
},
[&](const auto &parent) {
// Remaining parent kinds own the evaluation list directly.
for (auto &sibling : parent.evaluationList)
if (&sibling != &eval && !sibling.isEndStmt())
return true;

return false;
}});
}

// This helper function implements the functionality of "promoting"
// non-CPTR arguments of use_device_ptr to use_device_addr
// arguments (automagic conversion of use_device_ptr ->
Expand Down Expand Up @@ -1549,8 +1581,6 @@ genTeamsClauses(lower::AbstractConverter &converter,
cp.processAllocate(clauseOps);
cp.processDefault();
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
// TODO Support delayed privatization.

// Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside
Expand Down Expand Up @@ -2291,17 +2321,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::iterator item) {
lower::StatementContext stmtCtx;

auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
converter.getModuleOp().getOperation());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool mustEvalOutsideTarget = targetOp && !offloadModOp.getIsTargetDevice();
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

mlir::omp::TeamsOperands clauseOps;
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
mustEvalOutsideTarget, clauseOps, numTeamsClauseOps,
evalOutsideTarget, clauseOps, numTeamsClauseOps,
threadLimitClauseOps);

auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
Expand All @@ -2311,15 +2339,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);

if (numTeamsClauseOps.numTeamsUpper) {
if (mustEvalOutsideTarget)
if (evalOutsideTarget)
targetOp.getNumTeamsUpperMutable().assign(
numTeamsClauseOps.numTeamsUpper);
else
teamsOp.getNumTeamsUpperMutable().assign(numTeamsClauseOps.numTeamsUpper);
}

if (threadLimitClauseOps.threadLimit) {
if (mustEvalOutsideTarget)
if (evalOutsideTarget)
targetOp.getTeamsThreadLimitMutable().assign(
threadLimitClauseOps.threadLimit);
else
Expand Down Expand Up @@ -2399,12 +2427,9 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
ConstructQueue::iterator item) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

mlir::omp::ParallelOperands parallelClauseOps;
mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
Expand Down Expand Up @@ -2463,12 +2488,9 @@ static void genCompositeDistributeParallelDo(
ConstructQueue::iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
Expand All @@ -2480,9 +2502,8 @@ static void genCompositeDistributeParallelDo(
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
/*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
numThreadsClauseOps, parallelReductionTypes,
parallelReductionSyms);
evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
parallelReductionTypes, parallelReductionSyms);

const auto &privateClauseOps = dsp.getPrivateClauseOps();
parallelClauseOps.privateVars = privateClauseOps.privateVars;
Expand Down Expand Up @@ -2538,12 +2559,9 @@ static void genCompositeDistributeParallelDoSimd(
ConstructQueue::iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
Expand All @@ -2555,9 +2573,8 @@ static void genCompositeDistributeParallelDoSimd(
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
/*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
numThreadsClauseOps, parallelReductionTypes,
parallelReductionSyms);
evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
parallelReductionTypes, parallelReductionSyms);

const auto &privateClauseOps = dsp.getPrivateClauseOps();
parallelClauseOps.privateVars = privateClauseOps.privateVars;
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1777,7 +1777,10 @@ LogicalResult TeamsOp::verify() {
auto offloadModOp =
llvm::cast<OffloadModuleInterface>(*(*this)->getParentOfType<ModuleOp>());
if (targetOp && !offloadModOp.getIsTargetDevice()) {
if (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit())
// Only disallow num_teams and thread_limit if this is the only omp.teams
// inside the target region.
if (getSingleNestedOpOfType<TeamsOp>(targetOp.getRegion()) == *this &&
(getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit()))
return emitError("num_teams and thread_limit arguments expected to be "
"attached to parent omp.target operation");
} else {
Expand Down

0 comments on commit 34956b0

Please sign in to comment.