Skip to content

Commit

Permalink
Initial support for mapping allocatables/pointers in derived types an…
Browse files Browse the repository at this point in the history
…d related map syntax (#160)

* Apply upstream allocatable member mapping with some minor modifications

* fix rebase issues

* Fix tests
  • Loading branch information
agozillon authored Sep 24, 2024
1 parent e7b6197 commit 968dac4
Show file tree
Hide file tree
Showing 28 changed files with 2,080 additions and 450 deletions.
11 changes: 11 additions & 0 deletions flang/include/flang/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ struct IdTyTemplate {
return designator == other.designator;
}

// Defining an "ordering" which allows types derived from this to be
// utilised in maps and other containers that require comparison
// operators for ordering
bool operator<(const IdTyTemplate &other) const {
return symbol < other.symbol;
}

operator bool() const { return symbol != nullptr; }
};

Expand All @@ -76,6 +83,10 @@ struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
Fortran::semantics::Symbol *sym() const { return identity.symbol; }
const std::optional<ExprTy> &ref() const { return identity.designator; }

bool operator<(const ObjectT<IdTy, ExprTy> &other) const {
return identity < other.identity;
}

IdTy identity;
};
} // namespace tomp::type
Expand Down
121 changes: 97 additions & 24 deletions flang/include/flang/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/CommandLine.h"
#include <cstdint>

extern llvm::cl::opt<bool> treatIndexAsSection;
extern llvm::cl::opt<bool> enableDelayedPrivatization;
Expand All @@ -34,6 +35,7 @@ struct OmpObjectList;
} // namespace parser

namespace lower {
class StatementContext;
namespace pft {
struct Evaluation;
}
Expand All @@ -49,38 +51,111 @@ using DeclareTargetCapturePair =
// and index data when lowering OpenMP map clauses. Keeps track of the
// placement of the component in the derived type hierarchy it rests within,
// alongside the generated mlir::omp::MapInfoOp for the mapped component.
struct OmpMapMemberIndicesData {
//
// As an example of what the contents of this data structure may be like,
// when provided the following derived type and map of that type:
//
// type :: bottom_layer
// real(8) :: i2
// real(4) :: array_i2(10)
// real(4) :: array_j2(10)
// end type bottom_layer
//
// type :: top_layer
// real(4) :: i
// integer(4) :: array_i(10)
// real(4) :: j
// type(bottom_layer) :: nested
// integer, allocatable :: array_j(:)
// integer(4) :: k
// end type top_layer
//
// type(top_layer) :: top_dtype
//
// map(tofrom: top_dtype%nested%i2, top_dtype%k, top_dtype%nested%array_i2)
//
// We would end up with an OmpMapParentAndMemberData populated like below:
//
// memberPlacementIndices:
// Vector 1: 3, 0
// Vector 2: 5
// Vector 3: 3, 1
//
// memberMap:
// Entry 1: omp.map.info for "top_dtype%nested%i2"
// Entry 2: omp.map.info for "top_dtype%k"
// Entry 3: omp.map.info for "top_dtype%nested%array_i2"
//
// And this OmpMapParentAndMemberData would be accessed via the parent
// symbol for top_dtype. Other parent derived type instances that have
// members mapped would have there own OmpMapParentAndMemberData entry
// accessed via their own symbol.
struct OmpMapParentAndMemberData {
// The indices representing the component members placement in its derived
// type parents hierarchy.
llvm::SmallVector<int> memberPlacementIndices;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;

// Placement of the member in the member vector.
mlir::omp::MapInfoOp memberMap;
llvm::SmallVector<mlir::omp::MapInfoOp> memberMap;

// The list of associated parent object symbols. used to track data we
// need for various parent processing tasks when performing member
// mapping, the main example currently being re-evaluating the parent
// maps bounds at the final step of map processing, where we need to
// keep a hold of all of the omp::Object's which contain array bounds
// for the respective parent to calculate the final bounds from.
//
// As an Example:
//
// !$omp target map(tofrom: alloca_dtype_arr(2)%array_i,
// alloca_dtype_arr(3)%array_i)
//
// parentObjList will contain alloca_dtype_arr(3) as well as
// alloca_dtype_arr(2).
ObjectList parentObjList;
};

mlir::omp::MapInfoOp
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
mlir::ArrayRef<mlir::Value> bounds,
mlir::ArrayRef<mlir::Value> members,
mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap = false);

void addChildIndexAndMapToParent(
const omp::Object &object,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx);
void generateMemberPlacementIndices(
const Object &object, llvm::SmallVectorImpl<int64_t> &indices,
Fortran::semantics::SemanticsContext &semaCtx);

bool isMemberOrParentAllocatableOrPointer(
const Object &object, Fortran::semantics::SemanticsContext &semaCtx);

bool isDuplicateMemberMapInfo(OmpMapParentAndMemberData &parentMembers,
llvm::SmallVectorImpl<int64_t> &memberIndices);

mlir::omp::MapInfoOp createMapInfoOp(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
mlir::Value varPtrPtr, std::string name, mlir::ArrayRef<mlir::Value> bounds,
mlir::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
mlir::Type retTy, bool partialMap = false);

mlir::Value createParentSymAndGenIntermediateMaps(
mlir::Location clauseLocation, Fortran::lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
omp::ObjectList &objectList, llvm::SmallVector<int64_t> &indices,
OmpMapParentAndMemberData &parentMemberIndices, std::string asFortran,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits);

omp::ObjectList gatherObjects(omp::Object obj,
semantics::SemanticsContext &semaCtx);

void addChildIndexAndMapToParent(const omp::Object &object,
OmpMapParentAndMemberData &parentMemberIndices,
mlir::omp::MapInfoOp &mapOp,
semantics::SemanticsContext &semaCtx);

void insertChildMapInfoIntoParent(
lower::AbstractConverter &converter,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs);
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols);

mlir::Type getLoopVarType(lower::AbstractConverter &converter,
std::size_t loopVarTypeSize);
Expand All @@ -94,8 +169,6 @@ void gatherFuncAndVarSyms(

int64_t getCollapseValue(const List<Clause> &clauses);

semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject);

void genObjectList(const ObjectList &objects,
lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);
Expand Down
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
llvm::ArrayRef<mlir::Value> lenParams,
bool asTarget = false);

/// Create a two dimensional ArrayAttr containing integer data as
/// IntegerAttrs, effectively: ArrayAttr<ArrayAttr<IntegerAttr>>>.
mlir::ArrayAttr create2DIntegerArrayAttr(
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &intData);

/// Create a temporary using `fir.alloca`. This function does not hoist.
/// It is the callers responsibility to set the insertion point if
/// hoisting is required.
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ bool CodeGenAction::beginSourceFileAction() {
return false;
}


// Print initial full MLIR module, before lowering or transformations, if
// -save-temps has been specified.
if (!saveMLIRTempFile(ci.getInvocation(), *mlirModule, getCurrentFile(),
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Lower/DirectivesCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,10 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
// If it is a scalar subscript, then the upper bound
// is equal to the lower bound, and the extent is one.
ubound = lbound;
extent = one;
if (treatIndexAsSection)
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
else
extent = one;
} else {
asFortran << ':';
Fortran::semantics::MaybeExpr upper =
Expand Down
78 changes: 46 additions & 32 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,16 +889,17 @@ void ClauseProcessor::processMapObjects(
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
std::optional<omp::Object> parentObj;

lower::AddrAndBoundsInfo info =
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
Expand All @@ -907,28 +908,47 @@ void ClauseProcessor::processMapObjects(
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

mlir::Value baseOp = info.rawInput;
if (object.sym()->owner().IsDerivedType()) {
omp::ObjectList objectList = gatherObjects(object, semaCtx);
assert(!objectList.empty() &&
"could not find parent objects of derived type member");
parentObj = objectList[0];
auto insert = parentMemberIndices.emplace(parentObj.value(),
OmpMapParentAndMemberData{});
insert.first->second.parentObjList.push_back(parentObj.value());

if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
llvm::SmallVector<int64_t> indices;
generateMemberPlacementIndices(object, indices, semaCtx);
baseOp = createParentSymAndGenIntermediateMaps(
clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
parentMemberIndices[parentObj.value()], asFortran.str(),
mapTypeBits);
}
}

// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value baseOp = info.rawInput;
auto location = mlir::NameLoc::get(
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
baseOp.getLoc());
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, location, baseOp,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

if (object.sym()->owner().IsDerivedType()) {
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
if (parentObj.has_value()) {
addChildIndexAndMapToParent(
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
} else {
mapVars.push_back(mapOp);
if (mapSyms)
mapSyms->push_back(object.sym());
mapSyms->push_back(object.sym());
if (mapSymTypes)
mapSymTypes->push_back(baseOp.getType());
if (mapSymLocs)
Expand All @@ -949,9 +969,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
mapSyms ? mapSyms : &localMapSyms;
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause, const parser::CharBlock &source) {
Expand Down Expand Up @@ -997,23 +1015,22 @@ bool ClauseProcessor::processMap(
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
}

processMapObjects(stmtCtx, clauseLocation,
std::get<omp::ObjectList>(clause.t), mapTypeBits,
parentMemberIndices, result.mapVars, ptrMapSyms,
mapSymLocs, mapSymTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
*ptrMapSyms, mapSymTypes, mapSymLocs);

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.mapVars, mapSymTypes, mapSymLocs,
ptrMapSyms);
return clauseFound;
}

bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
llvm::SmallVector<const semantics::Symbol *> mapSymbols;

auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
Expand All @@ -1034,9 +1051,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
clauseFound =
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
mapSymbols,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
insertChildMapInfoIntoParent(
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
return clauseFound;
}

Expand Down Expand Up @@ -1110,9 +1127,7 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const parser::CharBlock &source) {
Expand All @@ -1125,9 +1140,9 @@ bool ClauseProcessor::processUseDeviceAddr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDeviceAddrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand All @@ -1136,9 +1151,8 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const parser::CharBlock &source) {
Expand All @@ -1151,9 +1165,9 @@ bool ClauseProcessor::processUseDevicePtr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDevicePtrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDevicePtrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand Down
Loading

0 comments on commit 968dac4

Please sign in to comment.