Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support for mapping allocatables/pointers in derived types and related map syntax #160

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions flang/include/flang/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ struct IdTyTemplate {
return designator == other.designator;
}

// Defining an "ordering" which allows types derived from this to be
// utilised in maps and other containers that require comparison
// operators for ordering
bool operator<(const IdTyTemplate &other) const {
return symbol < other.symbol;
}

operator bool() const { return symbol != nullptr; }
};

Expand All @@ -76,6 +83,10 @@ struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
Fortran::semantics::Symbol *sym() const { return identity.symbol; }
const std::optional<ExprTy> &ref() const { return identity.designator; }

bool operator<(const ObjectT<IdTy, ExprTy> &other) const {
return identity < other.identity;
}

IdTy identity;
};
} // namespace tomp::type
Expand Down
121 changes: 97 additions & 24 deletions flang/include/flang/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/CommandLine.h"
#include <cstdint>

extern llvm::cl::opt<bool> treatIndexAsSection;
extern llvm::cl::opt<bool> enableDelayedPrivatization;
Expand All @@ -34,6 +35,7 @@ struct OmpObjectList;
} // namespace parser

namespace lower {
class StatementContext;
namespace pft {
struct Evaluation;
}
Expand All @@ -49,38 +51,111 @@ using DeclareTargetCapturePair =
// and index data when lowering OpenMP map clauses. Keeps track of the
// placement of the component in the derived type hierarchy it rests within,
// alongside the generated mlir::omp::MapInfoOp for the mapped component.
struct OmpMapMemberIndicesData {
//
// As an example of what the contents of this data structure may be like,
// when provided the following derived type and map of that type:
//
// type :: bottom_layer
// real(8) :: i2
// real(4) :: array_i2(10)
// real(4) :: array_j2(10)
// end type bottom_layer
//
// type :: top_layer
// real(4) :: i
// integer(4) :: array_i(10)
// real(4) :: j
// type(bottom_layer) :: nested
// integer, allocatable :: array_j(:)
// integer(4) :: k
// end type top_layer
//
// type(top_layer) :: top_dtype
//
// map(tofrom: top_dtype%nested%i2, top_dtype%k, top_dtype%nested%array_i2)
//
// We would end up with an OmpMapParentAndMemberData populated like below:
//
// memberPlacementIndices:
// Vector 1: 3, 0
// Vector 2: 5
// Vector 3: 3, 1
//
// memberMap:
// Entry 1: omp.map.info for "top_dtype%nested%i2"
// Entry 2: omp.map.info for "top_dtype%k"
// Entry 3: omp.map.info for "top_dtype%nested%array_i2"
//
// And this OmpMapParentAndMemberData would be accessed via the parent
// symbol for top_dtype. Other parent derived type instances that have
// members mapped would have there own OmpMapParentAndMemberData entry
// accessed via their own symbol.
struct OmpMapParentAndMemberData {
// The indices representing the component members placement in its derived
// type parents hierarchy.
llvm::SmallVector<int> memberPlacementIndices;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;

// Placement of the member in the member vector.
mlir::omp::MapInfoOp memberMap;
llvm::SmallVector<mlir::omp::MapInfoOp> memberMap;

// The list of associated parent object symbols. used to track data we
// need for various parent processing tasks when performing member
// mapping, the main example currently being re-evaluating the parent
// maps bounds at the final step of map processing, where we need to
// keep a hold of all of the omp::Object's which contain array bounds
// for the respective parent to calculate the final bounds from.
//
// As an Example:
//
// !$omp target map(tofrom: alloca_dtype_arr(2)%array_i,
// alloca_dtype_arr(3)%array_i)
//
// parentObjList will contain alloca_dtype_arr(3) as well as
// alloca_dtype_arr(2).
ObjectList parentObjList;
};

mlir::omp::MapInfoOp
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
mlir::ArrayRef<mlir::Value> bounds,
mlir::ArrayRef<mlir::Value> members,
mlir::DenseIntElementsAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap = false);

void addChildIndexAndMapToParent(
const omp::Object &object,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx);
void generateMemberPlacementIndices(
const Object &object, llvm::SmallVectorImpl<int64_t> &indices,
Fortran::semantics::SemanticsContext &semaCtx);

bool isMemberOrParentAllocatableOrPointer(
const Object &object, Fortran::semantics::SemanticsContext &semaCtx);

bool isDuplicateMemberMapInfo(OmpMapParentAndMemberData &parentMembers,
llvm::SmallVectorImpl<int64_t> &memberIndices);

mlir::omp::MapInfoOp createMapInfoOp(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
mlir::Value varPtrPtr, std::string name, mlir::ArrayRef<mlir::Value> bounds,
mlir::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
mlir::Type retTy, bool partialMap = false);

mlir::Value createParentSymAndGenIntermediateMaps(
mlir::Location clauseLocation, Fortran::lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
omp::ObjectList &objectList, llvm::SmallVector<int64_t> &indices,
OmpMapParentAndMemberData &parentMemberIndices, std::string asFortran,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits);

omp::ObjectList gatherObjects(omp::Object obj,
semantics::SemanticsContext &semaCtx);

void addChildIndexAndMapToParent(const omp::Object &object,
OmpMapParentAndMemberData &parentMemberIndices,
mlir::omp::MapInfoOp &mapOp,
semantics::SemanticsContext &semaCtx);

void insertChildMapInfoIntoParent(
lower::AbstractConverter &converter,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs);
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols);

mlir::Type getLoopVarType(lower::AbstractConverter &converter,
std::size_t loopVarTypeSize);
Expand All @@ -94,8 +169,6 @@ void gatherFuncAndVarSyms(

int64_t getCollapseValue(const List<Clause> &clauses);

semantics::Symbol *getOmpObjectSymbol(const parser::OmpObject &ompObject);

void genObjectList(const ObjectList &objects,
lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);
Expand Down
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
llvm::ArrayRef<mlir::Value> lenParams,
bool asTarget = false);

/// Create a two dimensional ArrayAttr containing integer data as
/// IntegerAttrs, effectively: ArrayAttr<ArrayAttr<IntegerAttr>>>.
mlir::ArrayAttr create2DIntegerArrayAttr(
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &intData);

/// Create a temporary using `fir.alloca`. This function does not hoist.
/// It is the callers responsibility to set the insertion point if
/// hoisting is required.
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ bool CodeGenAction::beginSourceFileAction() {
return false;
}


// Print initial full MLIR module, before lowering or transformations, if
// -save-temps has been specified.
if (!saveMLIRTempFile(ci.getInvocation(), *mlirModule, getCurrentFile(),
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Lower/DirectivesCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,10 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
// If it is a scalar subscript, then the upper bound
// is equal to the lower bound, and the extent is one.
ubound = lbound;
extent = one;
if (treatIndexAsSection)
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
else
extent = one;
} else {
asFortran << ':';
Fortran::semantics::MaybeExpr upper =
Expand Down
78 changes: 46 additions & 32 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,16 +889,17 @@ void ClauseProcessor::processMapObjects(
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
std::optional<omp::Object> parentObj;

lower::AddrAndBoundsInfo info =
lower::gatherDataOperandAddrAndBounds<mlir::omp::MapBoundsOp,
Expand All @@ -907,28 +908,47 @@ void ClauseProcessor::processMapObjects(
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

mlir::Value baseOp = info.rawInput;
if (object.sym()->owner().IsDerivedType()) {
omp::ObjectList objectList = gatherObjects(object, semaCtx);
assert(!objectList.empty() &&
"could not find parent objects of derived type member");
parentObj = objectList[0];
auto insert = parentMemberIndices.emplace(parentObj.value(),
OmpMapParentAndMemberData{});
insert.first->second.parentObjList.push_back(parentObj.value());

if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
llvm::SmallVector<int64_t> indices;
generateMemberPlacementIndices(object, indices, semaCtx);
baseOp = createParentSymAndGenIntermediateMaps(
clauseLocation, converter, semaCtx, stmtCtx, objectList, indices,
parentMemberIndices[parentObj.value()], asFortran.str(),
mapTypeBits);
}
}

// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value baseOp = info.rawInput;
auto location = mlir::NameLoc::get(
mlir::StringAttr::get(firOpBuilder.getContext(), asFortran.str()),
baseOp.getLoc());
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
firOpBuilder, location, baseOp,
/*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
/*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
/*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, baseOp.getType());

if (object.sym()->owner().IsDerivedType()) {
addChildIndexAndMapToParent(object, parentMemberIndices, mapOp, semaCtx);
if (parentObj.has_value()) {
addChildIndexAndMapToParent(
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
} else {
mapVars.push_back(mapOp);
if (mapSyms)
mapSyms->push_back(object.sym());
mapSyms->push_back(object.sym());
if (mapSymTypes)
mapSymTypes->push_back(baseOp.getType());
if (mapSymLocs)
Expand All @@ -949,9 +969,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
mapSyms ? mapSyms : &localMapSyms;
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause, const parser::CharBlock &source) {
Expand Down Expand Up @@ -997,23 +1015,22 @@ bool ClauseProcessor::processMap(
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
}

processMapObjects(stmtCtx, clauseLocation,
std::get<omp::ObjectList>(clause.t), mapTypeBits,
parentMemberIndices, result.mapVars, ptrMapSyms,
mapSymLocs, mapSymTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
*ptrMapSyms, mapSymTypes, mapSymLocs);

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.mapVars, mapSymTypes, mapSymLocs,
ptrMapSyms);
return clauseFound;
}

bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
llvm::SmallVector<const semantics::Symbol *> mapSymbols;

auto callbackFn = [&](const auto &clause, const parser::CharBlock &source) {
Expand All @@ -1034,9 +1051,9 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
clauseFound =
findRepeatableClause<omp::clause::From>(callbackFn) || clauseFound;

insertChildMapInfoIntoParent(converter, parentMemberIndices, result.mapVars,
mapSymbols,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr);
insertChildMapInfoIntoParent(
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
return clauseFound;
}

Expand Down Expand Up @@ -1110,9 +1127,7 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const parser::CharBlock &source) {
Expand All @@ -1125,9 +1140,9 @@ bool ClauseProcessor::processUseDeviceAddr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDeviceAddrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand All @@ -1136,9 +1151,8 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<const semantics::Symbol *,
llvm::SmallVector<OmpMapMemberIndicesData>>
parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const parser::CharBlock &source) {
Expand All @@ -1151,9 +1165,9 @@ bool ClauseProcessor::processUseDevicePtr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});

insertChildMapInfoIntoParent(converter, parentMemberIndices,
result.useDevicePtrVars, useDeviceSyms,
&useDeviceTypes, &useDeviceLocs);
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDevicePtrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}

Expand Down
Loading