Skip to content

Commit

Permalink
[Flang][OpenMP] PFT-based detection of target SPMD (#144)
Browse files Browse the repository at this point in the history
This patch improves the fix in #125 to detect target SPMD kernels during Flang
lowering to MLIR. It transitions from a MLIR-based check to a PFT-based check,
which is a more resilient alternative since the MLIR representation is in
process of being built where it's being checked.
  • Loading branch information
skatrak authored Aug 19, 2024
1 parent 3d730dc commit 6f99163
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 41 deletions.
196 changes: 155 additions & 41 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,68 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//

static bool evalHasSiblings(lower::pft::Evaluation &eval) {
/// Get the directive enumeration value corresponding to the given OpenMP
/// construct PFT node.
llvm::omp::Directive
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
return common::visit(
common::visitors{
[](const parser::OpenMPAllocatorsConstruct &c) {
return llvm::omp::OMPD_allocators;
},
[](const parser::OpenMPAtomicConstruct &c) {
return llvm::omp::OMPD_atomic;
},
[](const parser::OpenMPBlockConstruct &c) {
return std::get<parser::OmpBlockDirective>(
std::get<parser::OmpBeginBlockDirective>(c.t).t)
.v;
},
[](const parser::OpenMPCriticalConstruct &c) {
return llvm::omp::OMPD_critical;
},
[](const parser::OpenMPDeclarativeAllocate &c) {
return llvm::omp::OMPD_allocate;
},
[](const parser::OpenMPExecutableAllocate &c) {
return llvm::omp::OMPD_allocate;
},
[](const parser::OpenMPLoopConstruct &c) {
return std::get<parser::OmpLoopDirective>(
std::get<parser::OmpBeginLoopDirective>(c.t).t)
.v;
},
[](const parser::OpenMPSectionConstruct &c) {
return llvm::omp::OMPD_section;
},
[](const parser::OpenMPSectionsConstruct &c) {
return std::get<parser::OmpSectionsDirective>(
std::get<parser::OmpBeginSectionsDirective>(c.t).t)
.v;
},
[](const parser::OpenMPStandaloneConstruct &c) {
return common::visit(
common::visitors{
[](const parser::OpenMPSimpleStandaloneConstruct &c) {
return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
.v;
},
[](const parser::OpenMPFlushConstruct &c) {
return llvm::omp::OMPD_flush;
},
[](const parser::OpenMPCancelConstruct &c) {
return llvm::omp::OMPD_cancel;
},
[](const parser::OpenMPCancellationPointConstruct &c) {
return llvm::omp::OMPD_cancellation_point;
}},
c.u);
}},
ompConstruct.u);
}

/// Check whether the parent of the given evaluation contains other evaluations.
static bool evalHasSiblings(const lower::pft::Evaluation &eval) {
auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) {
for (auto &sibling : siblings)
if (&sibling != &eval && !sibling.isEndStmt())
Expand All @@ -67,6 +128,80 @@ static bool evalHasSiblings(lower::pft::Evaluation &eval) {
}});
}

/// Check whether a given evaluation points to an OpenMP loop construct that
/// represents a target SPMD kernel. For this to be true, it must be a `target
/// teams distribute parallel do [simd]` or equivalent construct.
///
/// Currently, this is limited to cases where all relevant OpenMP constructs are
/// either combined or directly nested within the same function. Also, the
/// composite `distribute parallel do` is not identified if split into two
/// explicit nested loops (a `distribute` loop and a `parallel do` loop).
static bool isTargetSPMDLoop(const lower::pft::Evaluation &eval) {
using namespace llvm::omp;

const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
if (!ompEval)
return false;

switch (extractOmpDirective(*ompEval)) {
case OMPD_distribute_parallel_do:
case OMPD_distribute_parallel_do_simd: {
// It will return true only if one of these are true:
// - It has a 'target teams' parent and no siblings.
// - It has a 'teams' parent and no siblings, and the 'teams' has a
// 'target' parent and no siblings.
if (evalHasSiblings(eval))
return false;

const auto *parentEval = eval.parent.getIf<lower::pft::Evaluation>();
if (!parentEval)
return false;

const auto *parentOmpEval = parentEval->getIf<parser::OpenMPConstruct>();
if (!parentOmpEval)
return false;

auto parentDir = extractOmpDirective(*parentOmpEval);
if (parentDir == OMPD_target_teams)
return true;

if (parentDir != OMPD_teams)
return false;

if (evalHasSiblings(*parentEval))
return false;

const auto *parentOfParentEval =
parentEval->parent.getIf<lower::pft::Evaluation>();
if (!parentEval)
return false;

const auto *parentOfParentOmpEval =
parentOfParentEval->getIf<parser::OpenMPConstruct>();
return parentOfParentOmpEval &&
extractOmpDirective(*parentOfParentOmpEval) == OMPD_target;
}
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd: {
// Check there's a 'target' parent and no siblings.
if (evalHasSiblings(eval))
return false;

const auto *parentEval = eval.parent.getIf<lower::pft::Evaluation>();
if (!parentEval)
return false;

const auto *parentOmpEval = parentEval->getIf<parser::OpenMPConstruct>();
return parentOmpEval && extractOmpDirective(*parentOmpEval) == OMPD_target;
}
case OMPD_target_teams_distribute_parallel_do:
case OMPD_target_teams_distribute_parallel_do_simd:
return true;
default:
return false;
}
}

static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
mlir::Operation *parentOp = builder.getBlock()->getParentOp();
if (!parentOp)
Expand Down Expand Up @@ -113,8 +248,9 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
converter.genEval(e);
}

static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
mlir::omp::TargetOp targetOp) {
static bool
mustEvalTeamsThreadsOutsideTarget(const lower::pft::Evaluation &eval,
mlir::omp::TargetOp targetOp) {
if (!targetOp)
return false;

Expand All @@ -123,25 +259,8 @@ static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
if (offloadModOp.getIsTargetDevice())
return false;

auto dir = Fortran::common::visit(
common::visitors{
[&](const parser::OpenMPBlockConstruct &c) {
return std::get<parser::OmpBlockDirective>(
std::get<parser::OmpBeginBlockDirective>(c.t).t)
.v;
},
[&](const parser::OpenMPLoopConstruct &c) {
return std::get<parser::OmpLoopDirective>(
std::get<parser::OmpBeginLoopDirective>(c.t).t)
.v;
},
[&](const auto &) {
llvm_unreachable("Unexpected OpenMP construct");
return llvm::omp::OMPD_unknown;
},
},
eval.get<parser::OpenMPConstruct>().u);

llvm::omp::Directive dir =
extractOmpDirective(eval.get<parser::OpenMPConstruct>());
return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval);
}

Expand Down Expand Up @@ -1722,25 +1841,20 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
firOpBuilder.getModule().getOperation());
auto targetOp = loopNestOp->getParentOfType<mlir::omp::TargetOp>();

if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) {
if (targetOp.isTargetSPMDLoop()) {
// Lower loop bounds and step, and process collapsing again, putting
// lowered values outside of omp.target this time. This enables
// calculating and accessing the trip count in the host, which is needed
// when lowering to LLVM IR via the OMPIRBuilder.
HostClausesInsertionGuard guard(firOpBuilder);
mlir::omp::LoopRelatedOps loopRelatedOps;
llvm::SmallVector<const semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processCollapse(loc, eval, loopRelatedOps, iv);
targetOp.getTripCountMutable().assign(
calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps));
} else if (targetOp.getTripCountMutable().size()) {
// The MLIR target operation was updated during PFT lowering,
// and it is no longer an SPMD kernel. Erase the trip count because
// as it is now invalid.
targetOp.getTripCountMutable().erase(0);
}
if (offloadMod && !offloadMod.getIsTargetDevice() && isTargetSPMDLoop(eval)) {
assert(targetOp && "must have omp.target parent");

// Lower loop bounds and step, and process collapsing again, putting lowered
// values outside of omp.target this time. This enables calculating and
// accessing the trip count in the host, which is needed when lowering to
// LLVM IR via the OMPIRBuilder.
HostClausesInsertionGuard guard(firOpBuilder);
mlir::omp::LoopRelatedOps loopRelatedOps;
llvm::SmallVector<const semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processCollapse(loc, eval, loopRelatedOps, iv);
targetOp.getTripCountMutable().assign(
calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps));
}
return loopNestOp;
}
Expand Down
Loading

0 comments on commit 6f99163

Please sign in to comment.