Skip to content

Commit

Permalink
[flang][OpenMP] Extend do concurrent mapping to device. #50
Browse files Browse the repository at this point in the history
For simple loops, we can now choose to map `do concurrent` to either the host (i.e. `omp parallel do`) or the device (i.e. `omp target teams distribute parallel do`).

In order to use this from `flang-new`, you can pass: `-fdo-concurrent-parallel=[none|host|device]`.
  • Loading branch information
ergawy authored May 15, 2024
2 parents a92e557 + 3bb1152 commit 2f58864
Show file tree
Hide file tree
Showing 20 changed files with 670 additions and 248 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct OmpMapMemberIndicesData {
};

mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
mlir::ArrayRef<mlir::Value> bounds,
mlir::ArrayRef<mlir::Value> members,
Expand Down Expand Up @@ -102,6 +102,15 @@ void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);

// Builds, through `builder` and tagged with `loc`, the value giving the trip
// count of the loop nest whose lower bounds, upper bounds and steps are carried
// by `ops`, and returns that value.
// NOTE(review): the definition is not visible from this header; the summary
// above follows the name and signature only -- confirm against the
// implementation.
//
// TODO: consider moving this to the `omp.loop_nest` op. Would be something like
// this:
//
// ```
// mlir::Value LoopNestOp::calculateTripCount(mlir::OpBuilder &builder,
// mlir::OpBuilder::InsertPoint ip)
// ```
mlir::Value calculateTripCount(fir::FirOpBuilder &builder, mlir::Location loc,
const mlir::omp::CollapseClauseOps &ops);
} // namespace omp
} // namespace lower
} // namespace Fortran
Expand Down
3 changes: 2 additions & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace fir {
#define GEN_PASS_DECL_ARRAYVALUECOPY
#define GEN_PASS_DECL_CHARACTERCONVERSION
#define GEN_PASS_DECL_CFGCONVERSION
#define GEN_PASS_DECL_DOCONCURRENTCONVERSIONPASS
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT
#define GEN_PASS_DECL_SIMPLIFYINTRINSICS
Expand Down Expand Up @@ -88,7 +89,7 @@ createFunctionAttrPass(FunctionAttrTypes &functionAttr, bool noInfsFPMath,
bool noNaNsFPMath, bool approxFuncFPMath,
bool noSignedZerosFPMath, bool unsafeFPMath);

std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass();
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass(bool mapToDevice);

void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
bool forceLoopToExecuteOnce = false);
Expand Down
6 changes: 5 additions & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,12 @@ def DoConcurrentConversionPass : Pass<"fopenmp-do-concurrent-conversion", "mlir:
target.
}];

let constructor = "::fir::createDoConcurrentConversionPass()";
let dependentDialects = ["mlir::omp::OpenMPDialect"];

let options = [
Option<"mapTo", "map-to", "std::string", "",
"Try to map `do concurrent` loops to OpenMP (on host or device)">,
];
}

#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
11 changes: 9 additions & 2 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ inline void createHLFIRToFIRPassPipeline(
pm.addPass(hlfir::createConvertHLFIRtoFIRPass());
}

using DoConcurrentMappingKind =
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;

/// Create a pass pipeline for handling certain OpenMP transformations needed
/// prior to FIR lowering.
///
Expand All @@ -333,8 +336,12 @@ inline void createHLFIRToFIRPassPipeline(
/// \param pm - MLIR pass manager that will hold the pipeline definition.
/// \param isTargetDevice - Whether code is being generated for a target device
/// rather than the host device.
inline void createOpenMPFIRPassPipeline(
mlir::PassManager &pm, bool isTargetDevice) {
inline void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
bool isTargetDevice, DoConcurrentMappingKind doConcurrentMappingKind) {
if (doConcurrentMappingKind != DoConcurrentMappingKind::DCMK_None)
pm.addPass(fir::createDoConcurrentConversionPass(
doConcurrentMappingKind == DoConcurrentMappingKind::DCMK_Device));

pm.addPass(fir::createOMPMapInfoFinalizationPass());
pm.addPass(fir::createOMPMarkDeclareTargetPass());
if (isTargetDevice)
Expand Down
45 changes: 19 additions & 26 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,41 +320,34 @@ bool CodeGenAction::beginSourceFileAction() {
// Add OpenMP-related passes
// WARNING: These passes must be run immediately after the lowering to ensure
// that the FIR is correct with respect to OpenMP operations/attributes.
bool isOpenMPEnabled = ci.getInvocation().getFrontendOpts().features.IsEnabled(
bool isOpenMPEnabled =
ci.getInvocation().getFrontendOpts().features.IsEnabled(
Fortran::common::LanguageFeature::OpenMP);

using DoConcurrentMappingKind =
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
DoConcurrentMappingKind doConcurrentMappingKind =
ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();

if (doConcurrentMappingKind != DoConcurrentMappingKind::DCMK_None &&
!isOpenMPEnabled) {
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
"lowering `do concurrent` loops to OpenMP is only supported if "
"OpenMP is enabled");
ci.getDiagnostics().Report(diagID);
}

if (isOpenMPEnabled) {
bool isDevice = false;
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
mlirModule->getOperation()))
isDevice = offloadMod.getIsTargetDevice();

// WARNING: This pipeline must be run immediately after the lowering to
// ensure that the FIR is correct with respect to OpenMP operations/
// attributes.
fir::createOpenMPFIRPassPipeline(pm, isDevice);
}

using DoConcurrentMappingKind =
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
DoConcurrentMappingKind selectedKind = ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
if (selectedKind != DoConcurrentMappingKind::DCMK_None) {
if (!isOpenMPEnabled) {
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
"lowering `do concurrent` loops to OpenMP is only supported if "
"OpenMP is enabled");
ci.getDiagnostics().Report(diagID);
} else {
bool mapToDevice = selectedKind == DoConcurrentMappingKind::DCMK_Device;

if (mapToDevice) {
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
clang::DiagnosticsEngine::Warning,
"TODO: lowering `do concurrent` loops to OpenMP device is not "
"supported yet");
ci.getDiagnostics().Report(diagID);
} else
pm.addPass(fir::createDoConcurrentConversionPass());
}
fir::createOpenMPFIRPassPipeline(pm, isDevice, doConcurrentMappingKind);
}

pm.enableVerifier(/*verifyPasses=*/true);
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//

#include "ClauseProcessor.h"
#include "Clauses.h"

#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
#define FORTRAN_LOWER_CLAUASEPROCESSOR_H

#include "Clauses.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/OpenMP/Utils.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "Clauses.h"
#include "flang/Lower/OpenMP/Clauses.h"

#include "flang/Common/idioms.h"
#include "flang/Evaluate/expression.h"
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

#include "DataSharingProcessor.h"

#include "Utils.h"
#include "flang/Lower/OpenMP/Utils.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H

#include "Clauses.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/OpenMP.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
Expand Down
86 changes: 4 additions & 82 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
#include "flang/Lower/OpenMP.h"

#include "ClauseProcessor.h"
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/OpenMP/Utils.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
Expand Down Expand Up @@ -280,84 +280,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
}
}

/// Emit the arithmetic that computes the total trip count of a (possibly
/// collapsed) loop nest and return the resulting value.
///
/// For every (lower bound, upper bound, step) triple in `ops` a per-loop trip
/// count is computed as `(range /u |step|) + 1`, or 0 when the ordered `end`
/// is (signed) less than the ordered `start`. The per-loop counts are then
/// multiplied together.
///
/// \param converter provides the `fir::FirOpBuilder` used to create all ops.
/// \param loc location attached to every created op.
/// \param ops loop bounds and steps; the three vectors must be non-empty and
///        of equal length (checked by the assert below).
/// \return the product of the per-loop trip counts. The running product starts
///         as an `i32` constant 1 and is widened to the widest operand type
///         seen while multiplying.
static mlir::Value
calculateTripCount(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
const mlir::omp::CollapseClauseOps &ops) {
using namespace mlir::arith;
assert(ops.loopLBVar.size() == ops.loopUBVar.size() &&
ops.loopLBVar.size() == ops.loopStepVar.size() &&
!ops.loopLBVar.empty() && "Invalid bounds or step");

fir::FirOpBuilder &b = converter.getFirOpBuilder();

// Get the bit width of an integer-like type.
// Only IndexType and IntegerType are accepted; anything else is unreachable.
auto widthOf = [](mlir::Type ty) -> unsigned {
if (mlir::isa<mlir::IndexType>(ty)) {
return mlir::IndexType::kInternalStorageBitWidth;
}
if (auto tyInt = mlir::dyn_cast<mlir::IntegerType>(ty)) {
return tyInt.getWidth();
}
llvm_unreachable("Unexpected type");
};

// For a type that is either IntegerType or IndexType, return the
// equivalent IntegerType. In the former case this is a no-op.
auto asIntTy = [&](mlir::Type ty) -> mlir::IntegerType {
if (ty.isIndex()) {
return mlir::IntegerType::get(ty.getContext(), widthOf(ty));
}
assert(ty.isIntOrIndex() && "Unexpected type");
return mlir::cast<mlir::IntegerType>(ty);
};

// For two given values, establish a common signless IntegerType
// that can represent any value of type of x and of type of y,
// and return the pair of x, y converted to the new type.
// Note: the lambda's own `b` parameter shadows the outer builder reference.
auto unifyToSignless =
[&](fir::FirOpBuilder &b, mlir::Value x,
mlir::Value y) -> std::pair<mlir::Value, mlir::Value> {
auto tyX = asIntTy(x.getType()), tyY = asIntTy(y.getType());
unsigned width = std::max(widthOf(tyX), widthOf(tyY));
auto wideTy = mlir::IntegerType::get(b.getContext(), width,
mlir::IntegerType::Signless);
return std::make_pair(b.createConvert(loc, wideTy, x),
b.createConvert(loc, wideTy, y));
};

// Start with signless i32 by default.
auto tripCount = b.createIntegerConstant(loc, b.getI32Type(), 1);

for (auto [origLb, origUb, origStep] :
llvm::zip(ops.loopLBVar, ops.loopUBVar, ops.loopStepVar)) {
// reverseCond is true when the step is (signed) negative; in that case the
// absolute step is `0 - step`.
auto tmpS0 = b.createIntegerConstant(loc, origStep.getType(), 0);
auto [step, step0] = unifyToSignless(b, origStep, tmpS0);
auto reverseCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, step, step0);
auto negStep = b.create<SubIOp>(loc, step0, step);
// Declared as mlir::Value (not auto) so std::tie can reassign it below.
mlir::Value absStep = b.create<SelectOp>(loc, reverseCond, negStep, step);

// For a negative step the bounds are swapped, so that `start`/`end` are the
// values selected for an ascending traversal.
auto [lb, ub] = unifyToSignless(b, origLb, origUb);
auto start = b.create<SelectOp>(loc, reverseCond, ub, lb);
auto end = b.create<SelectOp>(loc, reverseCond, lb, ub);

// rangeCond is true when `end` is (signed) less than `start`; the loop's
// trip count is then forced to 0 further down.
mlir::Value range = b.create<SubIOp>(loc, end, start);
auto rangeCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, end, start);
std::tie(range, absStep) = unifyToSignless(b, range, absStep);
// numSteps = (range /u absStep) + 1
auto numSteps =
b.create<AddIOp>(loc, b.create<DivUIOp>(loc, range, absStep),
b.createIntegerConstant(loc, range.getType(), 1));

// Fold this loop's count into the running product, widening both operands
// to a common signless integer type first.
auto trip0 = b.createIntegerConstant(loc, numSteps.getType(), 0);
auto loopTripCount = b.create<SelectOp>(loc, rangeCond, trip0, numSteps);
auto [totalTC, thisTC] = unifyToSignless(b, tripCount, loopTripCount);
tripCount = b.create<MulIOp>(loc, totalTC, thisTC);
}

return tripCount;
}

static mlir::Operation *
createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
mlir::Location loc, mlir::Value indexVal,
Expand Down Expand Up @@ -1574,8 +1496,8 @@ genLoopNestOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processCollapse(loc, eval, collapseClauseOps, iv);
targetOp.getTripCountMutable().assign(
calculateTripCount(converter, loc, collapseClauseOps));
targetOp.getTripCountMutable().assign(calculateTripCount(
converter.getFirOpBuilder(), loc, collapseClauseOps));
}
return loopNestOp;
}
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H

#include "Clauses.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Semantics/symbol.h"
Expand Down
Loading

0 comments on commit 2f58864

Please sign in to comment.