Skip to content

Commit

Permalink
[Flang][MLIR][OpenMP] Initial multi-layered derived type member map s…
Browse files Browse the repository at this point in the history
…upport

This PR adds explicit derived type member mapping for nested descriptor types (allocatables) and other types, allowing users to map specific components of a derived type rather than the whole derived type. Currently this would also be the only way to map descriptor types within a derived type, as the automagic mapping of these when mapping
an entire derived type is still a WIP and should follow on from this work.

There's a lot of Fortran tests added in this PR that should give examples of what kind of mappings are handled in this PR. This PR shouldn't (at least from self testing) regress any existing map behaviour just add on to it.
  • Loading branch information
agozillon committed Apr 16, 2024
1 parent 2a2be38 commit c59f85b
Show file tree
Hide file tree
Showing 64 changed files with 3,329 additions and 410 deletions.
4 changes: 2 additions & 2 deletions flang/docs/OpenMP-descriptor-management.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Currently, Flang will lower these descriptor types in the OpenMP lowering (lower
to all other map types, generating an omp.MapInfoOp containing relevant information required for lowering
the OpenMP dialect to LLVM-IR during the final stages of the MLIR lowering. However, after
the lowering to FIR/HLFIR has been performed an OpenMP dialect specific pass for Fortran,
`OMPDescriptorMapInfoGenPass` (Optimizer/OMPDescriptorMapInfoGen.cpp) will expand the
`OMPMapInfoFinalizationPass` (Optimizer/OMPMapInfoFinalization.cpp) will expand the
`omp.MapInfoOp`'s containing descriptors (which currently will be a `BoxType` or `BoxAddrOp`) into multiple
mappings, with one extra per pointer member in the descriptor that is supported on top of the original
descriptor map operation. These pointers members are linked to the parent descriptor by adding them to
Expand All @@ -53,7 +53,7 @@ owning operation's (`omp.TargetOp`, `omp.TargetDataOp` etc.) map operand list an
operation is `IsolatedFromAbove`, it also inserts them as `BlockArgs` to canonicalize the mappings and
simplify lowering.
An example transformation by the `OMPDescriptorMapInfoGenPass`:
An example transformation by the `OMPMapInfoFinalizationPass`:
```

Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ std::unique_ptr<mlir::Pass>
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();

std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
std::unique_ptr<mlir::Pass> createOMPMapInfoFinalizationPass();
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createOMPMarkDeclareTargetPass();
Expand Down
6 changes: 3 additions & 3 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,15 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
let dependentDialects = [ "fir::FIROpsDialect" ];
}

def OMPDescriptorMapInfoGenPass
: Pass<"omp-descriptor-map-info-gen", "mlir::func::FuncOp"> {
def OMPMapInfoFinalizationPass
: Pass<"omp-map-info-finalization", "mlir::func::FuncOp"> {
let summary = "expands OpenMP MapInfo operations containing descriptors";
let description = [{
Expands MapInfo operations containing descriptor types into multiple
MapInfo's for each pointer element in the descriptor that requires
explicit individual mapping by the OpenMP runtime.
}];
let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
let constructor = "::fir::createOMPMapInfoFinalizationPass()";
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}

Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ inline void createHLFIRToFIRPassPipeline(
/// rather than the host device.
inline void createOpenMPFIRPassPipeline(
mlir::PassManager &pm, bool isTargetDevice) {
pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
pm.addPass(fir::createOMPMapInfoFinalizationPass());
pm.addPass(fir::createOMPMarkDeclareTargetPass());
if (isTargetDevice)
pm.addPass(fir::createOMPFunctionFilteringPass());
Expand Down
74 changes: 58 additions & 16 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "flang/Lower/PFTBuilder.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/tools.h"

namespace Fortran {
Expand Down Expand Up @@ -811,9 +812,10 @@ mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
llvm::ArrayRef<mlir::Value> bounds,
llvm::ArrayRef<mlir::Value> members, uint64_t mapType,
llvm::ArrayRef<mlir::Value> members,
mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool isVal) {
bool partialMap) {
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
retTy = baseAddr.getType();
Expand All @@ -823,10 +825,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());

mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
builder.getStringAttr(name));
builder.getStringAttr(name), builder.getBoolAttr(partialMap));

return op;
}
Expand All @@ -835,10 +837,13 @@ bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Map>(
std::map<const Fortran::semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
const Fortran::parser::CharBlock &source) {
using Map = omp::clause::Map;
Expand Down Expand Up @@ -887,6 +892,7 @@ bool ClauseProcessor::processMap(
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
const Fortran::semantics::Symbol *parentSym = nullptr;

Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Expand All @@ -900,27 +906,63 @@ bool ClauseProcessor::processMap(
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;

llvm::SmallVector<int> indices;
if (object.id()->owner().IsDerivedType()) {
if (auto dataRef{ExtractDataRef(object.designator)}) {
parentSym = &dataRef->GetFirstSymbol();
assert(parentSym &&
"Could not find parent symbol during lower of "
"a component member in OpenMP map clause");

indices = generateMemberPlacementIndices(object, semaCtx);
if (Fortran::semantics::IsAllocatableOrObjectPointer(
object.id())) {
llvm::SmallVector<mlir::Value> index;
for (auto idx : indices)
index.push_back(firOpBuilder.createIntegerConstant(
clauseLocation, firOpBuilder.getIndexType(), idx));

auto recordType =
converter.genType(*object.id()->owner().derivedTypeSpec())
.cast<fir::RecordType>();
auto fieldName = converter.getRecordTypeFieldName(*object.id());
mlir::Type fieldType = recordType.getType(fieldName);
mlir::Type designatorType = fir::ReferenceType::get(fieldType);
symAddr = firOpBuilder.create<fir::CoordinateOp>(
clauseLocation, designatorType,
converter.getSymbolAddress(*parentSym), index);
}
}
}

// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value mapOp = createMapInfoOp(
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
asFortran.str(), bounds, {},
asFortran.str(), bounds, {}, mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());

result.mapVars.push_back(mapOp);

if (mapSyms)
if (parentSym) {
parentMemberIndices[parentSym].push_back({indices, mapOp});
} else {
result.mapVars.push_back(mapOp);
mapSyms->push_back(object.id());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
if (mapSymTypes)
mapSymTypes->push_back(symAddr.getType());
if (mapSymTypes)
mapSymTypes->push_back(symAddr.getType());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
}
}
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
mapSymTypes, mapSymLocs, mapSyms);

return clauseFound;
}

bool ClauseProcessor::processReduction(
Expand Down
47 changes: 35 additions & 12 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ class ClauseProcessor {
// store the original type, location and Fortran symbol for the map operands.
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
bool processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms =
nullptr,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
bool
processMap(mlir::Location currentLocation,
const llvm::omp::Directive &directive,
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr) const;
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
Expand Down Expand Up @@ -190,7 +190,12 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
return findRepeatableClause<T>(
std::map<const Fortran::semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;

bool clauseFound = findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand All @@ -208,6 +213,7 @@ bool ClauseProcessor::processMotionClauses(
for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;

Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
Expand All @@ -223,17 +229,34 @@ bool ClauseProcessor::processMotionClauses(
// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value mapOp = createMapInfoOp(
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
asFortran.str(), bounds, {},
asFortran.str(), bounds, {}, mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());

result.mapVars.push_back(mapOp);
if (object.id()->owner().IsDerivedType()) {
if (auto dataRef{ExtractDataRef(object.designator)}) {
const Fortran::semantics::Symbol *parentSym =
&dataRef->GetFirstSymbol();
assert(parentSym &&
"Could not find parent symbol during lower of "
"a component member in OpenMP map clause");
parentMemberIndices[parentSym].push_back(
{generateMemberPlacementIndices(object, semaCtx), mapOp});
}
} else {
result.mapVars.push_back(mapOp);
mapSymbols.push_back(object.id());
}
}
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
nullptr, nullptr, &mapSymbols);
return clauseFound;
}

template <typename... Ts>
Expand Down
11 changes: 7 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1232,8 +1232,9 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
// re-introduce a hard-error rather than a warning in these cases.
promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
useDeviceLocs, useDeviceSyms);
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
stmtCtx, clauseOps);
stmtCtx, clauseOps, &mapSyms);

auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(
currentLocation, clauseOps);
Expand Down Expand Up @@ -1276,7 +1277,8 @@ static OpTy genTargetEnterExitDataUpdateOp(
cp.processMotionClauses<clause::To>(stmtCtx, clauseOps);
cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
} else {
cp.processMap(currentLocation, directive, stmtCtx, clauseOps);
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms);
}

return firOpBuilder.create<OpTy>(currentLocation, clauseOps);
Expand Down Expand Up @@ -1395,6 +1397,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, copyVal.getLoc(), copyVal, mlir::Value{}, name.str(),
bounds, llvm::SmallVector<mlir::Value>{},
mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
Expand Down Expand Up @@ -1460,7 +1463,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
cp.processThreadLimit(stmtCtx, clauseOps);
cp.processDepend(clauseOps);
cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms,
&mapLocs, &mapTypes);
&mapTypes, &mapLocs);
cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs,
devicePtrSyms);
cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs,
Expand Down Expand Up @@ -1561,7 +1564,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,

mlir::Value mapOp = createMapInfoOp(
converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, mlir::Value{},
name.str(), bounds, {},
name.str(), bounds, {}, mlir::DenseIntElementsAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapFlag),
Expand Down
Loading

0 comments on commit c59f85b

Please sign in to comment.