Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][MLIR][OpenMP] Fix num_teams, num_threads, thread_limit lowering #132

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 64 additions & 47 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//

/// Check whether \p eval has at least one sibling evaluation, skipping end
/// statements. At the program level, other program units and common blocks
/// count as siblings.
static bool evalHasSiblings(lower::pft::Evaluation &eval) {
  // True if any evaluation in \p list, other than \p eval itself, is not an
  // end statement.
  auto hasOtherNonEndStmt = [&eval](const lower::pft::EvaluationList &list) {
    for (auto &e : list) {
      if (&e == &eval || e.isEndStmt())
        continue;
      return true;
    }
    return false;
  };

  return eval.parent.visit(common::visitors{
      [&](const lower::pft::Program &parent) {
        // Program parent: count top-level units and common blocks.
        return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
      },
      [&](const lower::pft::Evaluation &parent) {
        // Evaluation parents hold their child list through a pointer.
        return hasOtherNonEndStmt(*parent.evaluationList);
      },
      [&](const auto &parent) {
        // Other parent kinds expose the evaluation list directly.
        return hasOtherNonEndStmt(parent.evaluationList);
      }});
}

static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
mlir::Operation *parentOp = builder.getBlock()->getParentOp();
if (!parentOp)
Expand Down Expand Up @@ -92,6 +113,38 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
converter.genEval(e);
}

/// Decide whether the NUM_TEAMS, NUM_THREADS and THREAD_LIMIT clauses of the
/// construct represented by \p eval must be evaluated on the host rather than
/// inside of the \p targetOp region they are nested in.
///
/// This is the case when compiling for the host, the construct is nested
/// inside of an omp.target operation, and either the directive belongs to the
/// target set or the construct is the only thing inside of the target region.
static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
                                              mlir::omp::TargetOp targetOp) {
  // Not nested inside of an omp.target operation: nothing to hoist.
  if (!targetOp)
    return false;

  // On the target device these clauses are always evaluated inside of the
  // target region.
  auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
      *targetOp->getParentOfType<mlir::ModuleOp>());
  if (offloadModOp.getIsTargetDevice())
    return false;

  // Extract the directive from the parse-tree construct attached to the
  // evaluation. Only block and loop constructs are expected here.
  auto directive = Fortran::common::visit(
      common::visitors{
          [&](const parser::OpenMPBlockConstruct &blockConstruct) {
            const auto &beginDir =
                std::get<parser::OmpBeginBlockDirective>(blockConstruct.t);
            return std::get<parser::OmpBlockDirective>(beginDir.t).v;
          },
          [&](const parser::OpenMPLoopConstruct &loopConstruct) {
            const auto &beginDir =
                std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
            return std::get<parser::OmpLoopDirective>(beginDir.t).v;
          },
          [&](const auto &) {
            llvm_unreachable("Unexpected OpenMP construct");
            return llvm::omp::OMPD_unknown;
          },
      },
      eval.get<parser::OpenMPConstruct>().u);

  if (llvm::omp::allTargetSet.test(directive))
    return true;

  return !evalHasSiblings(eval);
}

//===----------------------------------------------------------------------===//
// HostClausesInsertionGuard
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -425,27 +478,6 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter,
return storeOp;
}

// Returns true if \p eval has at least one sibling evaluation (another
// evaluation under the same parent) that is not an end statement. At the
// program level, other program units and common blocks count as siblings.
static bool evalHasSiblings(lower::pft::Evaluation &eval) {
  return eval.parent.visit(common::visitors{
      [&](const lower::pft::Program &parent) {
        // Program parent: count top-level units and common blocks.
        return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
      },
      [&](const lower::pft::Evaluation &parent) {
        // Evaluation parents hold their child list through a pointer.
        for (auto &sibling : *parent.evaluationList)
          if (&sibling != &eval && !sibling.isEndStmt())
            return true;

        return false;
      },
      [&](const auto &parent) {
        // Other parent kinds expose the evaluation list directly.
        for (auto &sibling : parent.evaluationList)
          if (&sibling != &eval && !sibling.isEndStmt())
            return true;

        return false;
      }});
}

// This helper function implements the functionality of "promoting"
// non-CPTR arguments of use_device_ptr to use_device_addr
// arguments (automagic conversion of use_device_ptr ->
Expand Down Expand Up @@ -1562,8 +1594,6 @@ genTeamsClauses(lower::AbstractConverter &converter,
cp.processAllocate(clauseOps);
cp.processDefault();
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
// TODO Support delayed privatization.

// Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside
Expand Down Expand Up @@ -2304,17 +2334,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::iterator item) {
lower::StatementContext stmtCtx;

auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
converter.getModuleOp().getOperation());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool mustEvalOutsideTarget = targetOp && !offloadModOp.getIsTargetDevice();
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

mlir::omp::TeamsOperands clauseOps;
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
mustEvalOutsideTarget, clauseOps, numTeamsClauseOps,
evalOutsideTarget, clauseOps, numTeamsClauseOps,
threadLimitClauseOps);

auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
Expand All @@ -2324,15 +2352,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);

if (numTeamsClauseOps.numTeamsUpper) {
if (mustEvalOutsideTarget)
if (evalOutsideTarget)
targetOp.getNumTeamsUpperMutable().assign(
numTeamsClauseOps.numTeamsUpper);
else
teamsOp.getNumTeamsUpperMutable().assign(numTeamsClauseOps.numTeamsUpper);
}

if (threadLimitClauseOps.threadLimit) {
if (mustEvalOutsideTarget)
if (evalOutsideTarget)
targetOp.getTeamsThreadLimitMutable().assign(
threadLimitClauseOps.threadLimit);
else
Expand Down Expand Up @@ -2412,12 +2440,9 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
ConstructQueue::iterator item) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

mlir::omp::ParallelOperands parallelClauseOps;
mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
Expand Down Expand Up @@ -2476,12 +2501,9 @@ static void genCompositeDistributeParallelDo(
ConstructQueue::iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
Expand All @@ -2493,9 +2515,8 @@ static void genCompositeDistributeParallelDo(
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
/*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
numThreadsClauseOps, parallelReductionTypes,
parallelReductionSyms);
evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
parallelReductionTypes, parallelReductionSyms);

const auto &privateClauseOps = dsp.getPrivateClauseOps();
parallelClauseOps.privateVars = privateClauseOps.privateVars;
Expand Down Expand Up @@ -2551,12 +2572,9 @@ static void genCompositeDistributeParallelDoSimd(
ConstructQueue::iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
Expand All @@ -2568,9 +2586,8 @@ static void genCompositeDistributeParallelDoSimd(
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
/*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
numThreadsClauseOps, parallelReductionTypes,
parallelReductionSyms);
evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
parallelReductionTypes, parallelReductionSyms);

const auto &privateClauseOps = dsp.getPrivateClauseOps();
parallelClauseOps.privateVars = privateClauseOps.privateVars;
Expand Down
Loading