[flang][OpenMP] Extend do concurrent mapping to device. #50

Merged
merged 10 commits into from
May 15, 2024

Nit: I suppose moving flang/include/Lower/OpenMP.h to flang/include/Lower/OpenMP/OpenMP.h would be in order as well (just to keep things consistent between the lib/ and include/ directories). If you do, remember to update the include guard on that file.

Author

To reduce the pain for anyone merging from upstream, I prefer to avoid doing that for now. Maybe we should replicate the header reorganization upstream first?

File renamed without changes.
@@ -59,7 +59,7 @@ struct OmpMapMemberIndicesData {
};

mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,

I'm not against changing this to the more generic mlir::OpBuilder, but I wonder whether we should keep it the same as most of the other flang utilities, for consistency. I'll leave that up to you!

Author

That will come at the small additional cost of having to replicate FirOpBuilder::createIntegerConstant's logic. It is not much work, but I would prefer to keep the FirOpBuilder parameter.

mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
mlir::ArrayRef<mlir::Value> bounds,
mlir::ArrayRef<mlir::Value> members,
@@ -102,6 +102,8 @@ void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);

mlir::Value calculateTripCount(fir::FirOpBuilder &builder, mlir::Location loc,
const mlir::omp::CollapseClauseOps &ops);
} // namespace omp
} // namespace lower
} // namespace Fortran
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
@@ -38,6 +38,7 @@ namespace fir {
#define GEN_PASS_DECL_ARRAYVALUECOPY
#define GEN_PASS_DECL_CHARACTERCONVERSION
#define GEN_PASS_DECL_CFGCONVERSION
#define GEN_PASS_DECL_DOCONCURRENTCONVERSIONPASS
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT
#define GEN_PASS_DECL_SIMPLIFYINTRINSICS
@@ -89,6 +90,7 @@ createFunctionAttrPass(FunctionAttrTypes &functionAttr, bool noInfsFPMath,
bool noSignedZerosFPMath, bool unsafeFPMath);

std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass();
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass(bool mapToDevice);

void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
bool forceLoopToExecuteOnce = false);
6 changes: 5 additions & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.td
@@ -416,8 +416,12 @@ def DoConcurrentConversionPass : Pass<"fopenmp-do-concurrent-conversion", "mlir:
target.
}];

let constructor = "::fir::createDoConcurrentConversionPass()";
let dependentDialects = ["mlir::omp::OpenMPDialect"];

let options = [
Option<"mapTo", "map-to", "std::string", "",
"Try to map `do concurrent` loops to OpenMP (on host or device)">,
];
}

#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
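
The `map-to` option above is declared as a plain `std::string` with an empty default, so the pass has to interpret the value itself. Below is a minimal sketch of one way to do that. It assumes the accepted spellings are `host` and `device` (the option description only says "on host or device"). The `MapTarget` enum and the `parseMapTo` helper are hypothetical names for illustration and are not taken from the patch.

```cpp
#include <optional>
#include <string>

// Hypothetical enum for illustration; the patch itself stores a std::string.
enum class MapTarget { None, Host, Device };

// Interpret the value of the `map-to` pass option.
// Returns std::nullopt for an unrecognized value so the caller can diagnose it.
std::optional<MapTarget> parseMapTo(const std::string &value) {
  if (value.empty())
    return MapTarget::None; // matches the "" default declared in Passes.td
  if (value == "host") // assumed spelling
    return MapTarget::Host;
  if (value == "device") // assumed spelling
    return MapTarget::Device;
  return std::nullopt;
}
```

Returning an empty optional rather than silently falling back to the host keeps a misspelled option value from changing where the loop runs.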
37 changes: 16 additions & 21 deletions flang/lib/Frontend/FrontendActions.cpp
@@ -320,22 +320,14 @@ bool CodeGenAction::beginSourceFileAction() {
// Add OpenMP-related passes
// WARNING: These passes must be run immediately after the lowering to ensure
// that the FIR is correct with respect to OpenMP operations/attributes.
bool isOpenMPEnabled = ci.getInvocation().getFrontendOpts().features.IsEnabled(
bool isOpenMPEnabled =
ci.getInvocation().getFrontendOpts().features.IsEnabled(
Fortran::common::LanguageFeature::OpenMP);
if (isOpenMPEnabled) {
bool isDevice = false;
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
mlirModule->getOperation()))
isDevice = offloadMod.getIsTargetDevice();
// WARNING: This pipeline must be run immediately after the lowering to
// ensure that the FIR is correct with respect to OpenMP operations/
// attributes.
fir::createOpenMPFIRPassPipeline(pm, isDevice);
}

using DoConcurrentMappingKind =
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
DoConcurrentMappingKind selectedKind = ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
DoConcurrentMappingKind selectedKind =
ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
if (selectedKind != DoConcurrentMappingKind::DCMK_None) {
if (!isOpenMPEnabled) {
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
@@ -345,18 +337,21 @@ bool CodeGenAction::beginSourceFileAction() {
ci.getDiagnostics().Report(diagID);
} else {
bool mapToDevice = selectedKind == DoConcurrentMappingKind::DCMK_Device;

if (mapToDevice) {
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
"TODO: lowering `do concurrent` loops to OpenMP device is not "
"supported yet");
ci.getDiagnostics().Report(diagID);
} else
pm.addPass(fir::createDoConcurrentConversionPass());
pm.addPass(fir::createDoConcurrentConversionPass(mapToDevice));
}
}

if (isOpenMPEnabled) {
bool isDevice = false;
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
mlirModule->getOperation()))
isDevice = offloadMod.getIsTargetDevice();
// WARNING: This pipeline must be run immediately after the lowering to
// ensure that the FIR is correct with respect to OpenMP operations/
// attributes.
fir::createOpenMPFIRPassPipeline(pm, isDevice);
}

pm.enableVerifier(/*verifyPasses=*/true);
pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
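
The net effect of this hunk is an ordering change: the `do concurrent` conversion is now scheduled before the OpenMP FIR pipeline, and the device case no longer stops at a TODO warning. The sketch below models that scheduling decision in standalone C++, with no MLIR dependency. The pass labels are placeholder strings, `DCMK_Host` is assumed to be the third enumerator (only `DCMK_None` and `DCMK_Device` appear in the diff), and `PipelinePlan` and `planOpenMPPasses` are hypothetical names.

```cpp
#include <string>
#include <vector>

// Mirrors the enum used in the diff; DCMK_Host is an assumption.
enum class DoConcurrentMappingKind { DCMK_None, DCMK_Host, DCMK_Device };

struct PipelinePlan {
  std::vector<std::string> passes; // placeholder pass labels, in scheduling order
  bool diagnosed = false;          // true if the "OpenMP not enabled" diagnostic fires
};

// Model of the control flow in CodeGenAction::beginSourceFileAction()
// after this change.
PipelinePlan planOpenMPPasses(DoConcurrentMappingKind kind,
                              bool isOpenMPEnabled) {
  PipelinePlan plan;
  if (kind != DoConcurrentMappingKind::DCMK_None) {
    if (!isOpenMPEnabled) {
      // Mapping was requested without OpenMP: report and skip the conversion.
      plan.diagnosed = true;
    } else {
      bool mapToDevice = kind == DoConcurrentMappingKind::DCMK_Device;
      plan.passes.push_back(mapToDevice ? "do-concurrent-conversion(device)"
                                        : "do-concurrent-conversion(host)");
    }
  }
  // The OpenMP FIR pipeline now runs after the conversion, so it also sees
  // the OpenMP operations that the conversion produced.
  if (isOpenMPEnabled)
    plan.passes.push_back("openmp-fir-pipeline");
  return plan;
}
```

For a device mapping with OpenMP enabled, the plan holds the conversion followed by the OpenMP pipeline; with OpenMP disabled it holds no passes and only records the diagnostic.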

2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//

#include "ClauseProcessor.h"
#include "Clauses.h"

#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -12,12 +12,12 @@
#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
#define FORTRAN_LOWER_CLAUASEPROCESSOR_H

#include "Clauses.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/OpenMP/Utils.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/Clauses.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "Clauses.h"
#include "flang/Lower/OpenMP/Clauses.h"

#include "flang/Common/idioms.h"
#include "flang/Evaluate/expression.h"
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -12,7 +12,7 @@

#include "DataSharingProcessor.h"

#include "Utils.h"
#include "flang/Lower/OpenMP/Utils.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -12,9 +12,9 @@
#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H

#include "Clauses.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/OpenMP.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
86 changes: 4 additions & 82 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -13,15 +13,15 @@
#include "flang/Lower/OpenMP.h"

#include "ClauseProcessor.h"
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/OpenMP/Utils.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -280,84 +280,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
}
}

static mlir::Value
calculateTripCount(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
const mlir::omp::CollapseClauseOps &ops) {
using namespace mlir::arith;
assert(ops.loopLBVar.size() == ops.loopUBVar.size() &&
ops.loopLBVar.size() == ops.loopStepVar.size() &&
!ops.loopLBVar.empty() && "Invalid bounds or step");

fir::FirOpBuilder &b = converter.getFirOpBuilder();

// Get the bit width of an integer-like type.
auto widthOf = [](mlir::Type ty) -> unsigned {
if (mlir::isa<mlir::IndexType>(ty)) {
return mlir::IndexType::kInternalStorageBitWidth;
}
if (auto tyInt = mlir::dyn_cast<mlir::IntegerType>(ty)) {
return tyInt.getWidth();
}
llvm_unreachable("Unexpected type");
};

// For a type that is either IntegerType or IndexType, return the
// equivalent IntegerType. In the former case this is a no-op.
auto asIntTy = [&](mlir::Type ty) -> mlir::IntegerType {
if (ty.isIndex()) {
return mlir::IntegerType::get(ty.getContext(), widthOf(ty));
}
assert(ty.isIntOrIndex() && "Unexpected type");
return mlir::cast<mlir::IntegerType>(ty);
};

// For two given values, establish a common signless IntegerType
// that can represent any value of type of x and of type of y,
// and return the pair of x, y converted to the new type.
auto unifyToSignless =
[&](fir::FirOpBuilder &b, mlir::Value x,
mlir::Value y) -> std::pair<mlir::Value, mlir::Value> {
auto tyX = asIntTy(x.getType()), tyY = asIntTy(y.getType());
unsigned width = std::max(widthOf(tyX), widthOf(tyY));
auto wideTy = mlir::IntegerType::get(b.getContext(), width,
mlir::IntegerType::Signless);
return std::make_pair(b.createConvert(loc, wideTy, x),
b.createConvert(loc, wideTy, y));
};

// Start with signless i32 by default.
auto tripCount = b.createIntegerConstant(loc, b.getI32Type(), 1);

for (auto [origLb, origUb, origStep] :
llvm::zip(ops.loopLBVar, ops.loopUBVar, ops.loopStepVar)) {
auto tmpS0 = b.createIntegerConstant(loc, origStep.getType(), 0);
auto [step, step0] = unifyToSignless(b, origStep, tmpS0);
auto reverseCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, step, step0);
auto negStep = b.create<SubIOp>(loc, step0, step);
mlir::Value absStep = b.create<SelectOp>(loc, reverseCond, negStep, step);

auto [lb, ub] = unifyToSignless(b, origLb, origUb);
auto start = b.create<SelectOp>(loc, reverseCond, ub, lb);
auto end = b.create<SelectOp>(loc, reverseCond, lb, ub);

mlir::Value range = b.create<SubIOp>(loc, end, start);
auto rangeCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, end, start);
std::tie(range, absStep) = unifyToSignless(b, range, absStep);
// numSteps = (range /u absStep) + 1
auto numSteps =
b.create<AddIOp>(loc, b.create<DivUIOp>(loc, range, absStep),
b.createIntegerConstant(loc, range.getType(), 1));

auto trip0 = b.createIntegerConstant(loc, numSteps.getType(), 0);
auto loopTripCount = b.create<SelectOp>(loc, rangeCond, trip0, numSteps);
auto [totalTC, thisTC] = unifyToSignless(b, tripCount, loopTripCount);
tripCount = b.create<MulIOp>(loc, totalTC, thisTC);
}

return tripCount;
}

static mlir::Operation *
createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
mlir::Location loc, mlir::Value indexVal,
@@ -1574,8 +1496,8 @@ genLoopNestOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processCollapse(loc, eval, collapseClauseOps, iv);
targetOp.getTripCountMutable().assign(
calculateTripCount(converter, loc, collapseClauseOps));
targetOp.getTripCountMutable().assign(calculateTripCount(
converter.getFirOpBuilder(), loc, collapseClauseOps));
}
return loopNestOp;
}
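
The `calculateTripCount` helper that this hunk moves out of OpenMP.cpp emits MLIR arithmetic for the collapsed loop nest. The same arithmetic can be followed on plain integers. The sketch below is a standalone model under simplifying assumptions: every bound and step is a 64-bit signed integer (the real code unifies arbitrary integer and index widths), and the `LoopRange` struct, the `tripCount` name, and the zero-step assertion are additions for illustration that do not appear in the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one (lower bound, upper bound, step) triple.
struct LoopRange {
  std::int64_t lb, ub, step;
};

// Product of the per-loop trip counts, following the steps of the MLIR version.
std::uint64_t tripCount(const std::vector<LoopRange> &loops) {
  assert(!loops.empty() && "Invalid bounds or step");
  std::uint64_t total = 1;
  for (const LoopRange &l : loops) {
    assert(l.step != 0 && "step must be non-zero"); // added check, not in the patch
    bool reverse = l.step < 0;
    // absStep = |step|, computed in unsigned arithmetic to avoid signed overflow.
    std::uint64_t absStep = reverse ? 0 - static_cast<std::uint64_t>(l.step)
                                    : static_cast<std::uint64_t>(l.step);
    // For a negative step the loop runs from ub up to lb once direction is normalized.
    std::int64_t start = reverse ? l.ub : l.lb;
    std::int64_t end = reverse ? l.lb : l.ub;
    bool emptyRange = end < start;
    std::uint64_t range =
        static_cast<std::uint64_t>(end) - static_cast<std::uint64_t>(start);
    // numSteps = (range /u absStep) + 1
    std::uint64_t numSteps = range / absStep + 1;
    total *= emptyRange ? 0 : numSteps;
  }
  return total;
}
```

For example, `do i = 10, 1, -3` normalizes to start 1, end 10, and step 3, giving 9 / 3 + 1 = 4 iterations (10, 7, 4, 1). A zero-trip loop such as `do i = 5, 1` contributes 0 and zeroes the whole product.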
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -13,7 +13,7 @@
#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H

#include "Clauses.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Semantics/symbol.h"