Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][OpenMP] Preliminary support for teams reductions #143

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,17 +1698,19 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
loc, llvm::omp::Directive::OMPD_taskwait);
}

static void
genTeamsClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, bool evalOutsideTarget,
mlir::omp::TeamsOperands &clauseOps,
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps) {
static void genTeamsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, bool evalOutsideTarget,
mlir::omp::TeamsOperands &clauseOps,
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);

// Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside
// of an omp.target operation.
Expand All @@ -1723,8 +1725,6 @@ genTeamsClauses(lower::AbstractConverter &converter,
cp.processNumTeams(stmtCtx, numTeamsClauseOps);
cp.processThreadLimit(stmtCtx, threadLimitClauseOps);
}

// cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
}

static void genWsloopClauses(
Expand Down Expand Up @@ -2496,14 +2496,22 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::omp::TeamsOperands clauseOps;
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
evalOutsideTarget, clauseOps, numTeamsClauseOps,
threadLimitClauseOps);
threadLimitClauseOps, reductionTypes, reductionSyms);

auto reductionCallback = [&](mlir::Operation *op) {
genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
};

auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
.setClauses(&item->clauses),
.setClauses(&item->clauses)
.setGenRegionEntryCb(reductionCallback),
queue, item, clauseOps);

if (numTeamsClauseOps.numTeamsUpper) {
Expand Down Expand Up @@ -2721,6 +2729,8 @@ static void genCompositeDistributeParallelDo(
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);

// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
Expand Down Expand Up @@ -2804,6 +2814,8 @@ static void genCompositeDistributeParallelDoSimd(
mlir::omp::SimdOperands simdClauseOps;
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);

// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s
! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
! XFAIL: *

! CHECK: omp.teams
! CHECK-SAME: reduction
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenMP/sections-array-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ subroutine sectionsReduction(x)
! CHECK: omp.parallel {
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 -> %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[VAL_4:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! CHECK: ^bb0(%[[VAL_4]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
! CHECK: omp.section {
! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
! [...]
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/OpenMP/sections-reduction.f90
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ subroutine sectionsReduction(x,y)
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref<f32>, !fir.dscope) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: omp.parallel {
! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref<f32>, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref<f32>) {
! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<f32>, %[[VAL_6:.*]]: !fir.ref<f32>):
! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[VAL_5:.*]] : !fir.ref<f32>, @add_reduction_f32 %[[VAL_4]]#0 -> %[[VAL_6:.*]] : !fir.ref<f32>) {
! CHECK: ^bb0(%[[VAL_5]]: !fir.ref<f32>, %[[VAL_6]]: !fir.ref<f32>):
! CHECK: omp.section {
! CHECK: ^bb0(%[[VAL_7:.*]]: !fir.ref<f32>, %[[VAL_8:.*]]: !fir.ref<f32>):
! CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_7]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
Expand Down Expand Up @@ -71,8 +71,8 @@ subroutine sectionsReduction(x,y)
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.parallel {
! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref<f32>, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref<f32>) {
! CHECK: ^bb0(%[[VAL_23:.*]]: !fir.ref<f32>, %[[VAL_24:.*]]: !fir.ref<f32>):
! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[VAL_23:.*]] : !fir.ref<f32>, @add_reduction_f32 %[[VAL_4]]#0 -> %[[VAL_24:.*]] : !fir.ref<f32>) {
! CHECK: ^bb0(%[[VAL_23]]: !fir.ref<f32>, %[[VAL_24]]: !fir.ref<f32>):
! CHECK: omp.section {
! CHECK: ^bb0(%[[VAL_25:.*]]: !fir.ref<f32>, %[[VAL_26:.*]]: !fir.ref<f32>):
! CHECK: %[[VAL_27:.*]]:2 = hlfir.declare %[[VAL_25]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
Expand Down
74 changes: 28 additions & 46 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,16 +472,19 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
//===----------------------------------------------------------------------===//

static ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser, Region &region,
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
bool parseParens = true) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();

OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren
: OpAsmParser::Delimiter::None;
if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
parser.parseCommaSeparatedList(delimiter, [&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
Expand Down Expand Up @@ -536,17 +539,17 @@ static ParseResult parseParallelRegion(
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;

if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
reductionTypes, reductionByref,
reductionSyms, regionPrivateArgs)))
if (failed(parseClauseWithRegionArgs(parser, reductionVars, reductionTypes,
reductionByref, reductionSyms,
regionPrivateArgs)))
return failure();
}

if (succeeded(parser.parseOptionalKeyword("private"))) {
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
privateTypes, privateByref,
privateSyms, regionPrivateArgs)))
if (failed(parseClauseWithRegionArgs(parser, privateVars, privateTypes,
privateByref, privateSyms,
regionPrivateArgs)))
return failure();
if (llvm::any_of(privateByref.asArrayRef(),
[](bool byref) { return byref; })) {
Expand Down Expand Up @@ -597,45 +600,24 @@ static ParseResult parseReductionVarList(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(reductionVars.emplace_back()) ||
parser.parseColonType(reductionTypes.emplace_back()))
return failure();
isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSyms = ArrayAttr::get(parser.getContext(), reductions);
return success();
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
return parseClauseWithRegionArgs(parser, reductionVars, reductionTypes,
reductionByref, reductionSyms,
regionPrivateArgs, /*parseParens=*/false);
}

/// Print Reduction clause
static void
printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars, TypeRange reductionTypes,
std::optional<DenseBoolArrayAttr> reductionByref,
std::optional<ArrayAttr> reductionSyms) {
auto getByRef = [&](unsigned i) -> const char * {
if (!reductionByref || !*reductionByref)
return "";
assert(reductionByref->empty() || i < reductionByref->size());
if (!reductionByref->empty() && (*reductionByref)[i])
return "byref ";
return "";
};

for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i]
<< " : " << reductionVars[i].getType();
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
if (reductionSyms) {
auto *argsBegin = op->getRegion(0).front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, llvm::StringRef(),
reductionVars, reductionTypes, reductionByref,
reductionSyms);
}
}

Expand Down Expand Up @@ -1850,7 +1832,7 @@ parseWsloop(OpAsmParser &parser, Region &region,
// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
if (failed(parseClauseWithRegionArgs(parser, reductionOperands,
reductionTypes, reductionByRef,
reductionSymbols, privates)))
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,8 @@ static LogicalResult createReductionsAndCleanup(
SmallVector<OwningReductionGen> &owningReductionGens,
SmallVector<OwningAtomicReductionGen> &owningAtomicReductionGens,
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos,
bool isTeamsReduction = false, bool hasDistribute = false) {
bool IsNowait = false, bool isTeamsReduction = false,
bool hasDistribute = false) {
// Process the reductions if required.
if (op.getNumReductionVars() == 0)
return success();
Expand All @@ -884,7 +885,7 @@ static LogicalResult createReductionsAndCleanup(
builder.SetInsertPoint(tempTerminator);
llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
isByRef, op.getNowait(), isTeamsReduction,
isByRef, IsNowait, isTeamsReduction,
hasDistribute);
if (!contInsertPoint.getBlock())
return op->emitOpError() << "failed to convert reductions";
Expand Down Expand Up @@ -1083,7 +1084,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
return createReductionsAndCleanup(
sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef, owningReductionGens,
owningAtomicReductionGens, reductionInfos);
owningAtomicReductionGens, reductionInfos, sectionsOp.getNowait());
}

/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
Expand Down Expand Up @@ -1127,10 +1128,36 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
if (!op.getAllocatorVars().empty() || op.getReductionSyms() ||
!op.getPrivateVars().empty() || op.getPrivateSyms())
if (!op.getAllocatorVars().empty() || !op.getPrivateVars().empty() ||
op.getPrivateSyms())
return op.emitError("unhandled clauses for translation to LLVM IR");

llvm::ArrayRef<bool> isByRef = getIsByRef(op.getReductionByref());
assert(isByRef.size() == op.getNumReductionVars());

SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(op, reductionDecls);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);

SmallVector<llvm::Value *> privateReductionVariables(
op.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;

MutableArrayRef<BlockArgument> reductionArgs = op.getRegion().getArguments();

if (failed(allocAndInitializeReductionVars(
op, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
return failure();

// Store the mapping between reduction variables and their private copies on
// the ModuleTranslation stack. It can then be recovered when translating
// omp.reduce operations in a separate call.
LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
moduleTranslation, reductionVariableMap);

auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
moduleTranslation, allocaIP);
Expand Down Expand Up @@ -1160,7 +1187,17 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
builder.restoreIP(ompBuilder->createTeams(
ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));

return bodyGenStatus;
if (failed(bodyGenStatus))
return bodyGenStatus;

// Process the reductions if required.
SmallVector<OwningReductionGen> owningReductionGens;
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
return createReductionsAndCleanup(op, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables,
isByRef, owningReductionGens,
owningAtomicReductionGens, reductionInfos);
}

static void
Expand Down Expand Up @@ -1430,8 +1467,8 @@ static LogicalResult convertOmpWsloop(
return createReductionsAndCleanup(
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef, owningReductionGens,
owningAtomicReductionGens, reductionInfos, /*isTeamsReduction=*/false,
distributeCodeGen);
owningAtomicReductionGens, reductionInfos, wsloopOp.getNowait(),
/*isTeamsReduction=*/false, distributeCodeGen);
}

static LogicalResult
Expand Down
Loading