Skip to content

Commit

Permalink
Implement parent bounds resizing when multiple subscripts
Browse files Browse the repository at this point in the history
  • Loading branch information
agozillon committed Sep 20, 2024
1 parent a648bd2 commit 5d62d62
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 75 deletions.
33 changes: 25 additions & 8 deletions flang/include/flang/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ using DeclareTargetCapturePair =
//
// map(tofrom: top_dtype%nested%i2, top_dtype%k, top_dtype%nested%array_i2)
//
// We would end up with an OmpMapMemberIndicesData populated like below:
// We would end up with an OmpMapParentAndMemberData populated like below:
//
// memberPlacementIndices:
// Vector 1: 3, 0
Expand All @@ -86,17 +86,33 @@ using DeclareTargetCapturePair =
// Entry 2: omp.map.info for "top_dtype%k"
// Entry 3: omp.map.info for "top_dtype%nested%array_i2"
//
// And this OmpMapMemberIndicesData would be accessed via the parent
// And this OmpMapParentAndMemberData would be accessed via the parent
// symbol for top_dtype. Other parent derived type instances that have
// members mapped would have there own OmpMapMemberIndicesData entry
// members mapped would have there own OmpMapParentAndMemberData entry
// accessed via their own symbol.
struct OmpMapMemberIndicesData {
struct OmpMapParentAndMemberData {
// The indices representing the component members placement in its derived
// type parents hierarchy.
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;

// Placement of the member in the member vector.
llvm::SmallVector<mlir::omp::MapInfoOp> memberMap;

// The list of associated parent object symbols. used to track data we
// need for various parent processing tasks when performing member
// mapping, the main example currently being re-evaluating the parent
// maps bounds at the final step of map processing, where we need to
// keep a hold of all of the omp::Object's which contain array bounds
// for the respective parent to calculate the final bounds from.
//
// As an Example:
//
// !$omp target map(tofrom: alloca_dtype_arr(2)%array_i,
// alloca_dtype_arr(3)%array_i)
//
// parentObjList will contain alloca_dtype_arr(3) as well as
// alloca_dtype_arr(2).
ObjectList parentObjList;
};

void generateMemberPlacementIndices(
Expand All @@ -106,7 +122,7 @@ void generateMemberPlacementIndices(
bool isMemberOrParentAllocatableOrPointer(
const Object &object, Fortran::semantics::SemanticsContext &semaCtx);

bool isDuplicateMemberMapInfo(OmpMapMemberIndicesData &parentMembers,
bool isDuplicateMemberMapInfo(OmpMapParentAndMemberData &parentMembers,
llvm::SmallVectorImpl<int64_t> &memberIndices);

mlir::omp::MapInfoOp createMapInfoOp(
Expand All @@ -120,21 +136,22 @@ mlir::Value createParentSymAndGenIntermediateMaps(
mlir::Location clauseLocation, Fortran::lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
omp::ObjectList &objectList, llvm::SmallVector<int64_t> &indices,
OmpMapMemberIndicesData &parentMemberIndices, std::string asFortran,
OmpMapParentAndMemberData &parentMemberIndices, std::string asFortran,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits);

omp::ObjectList gatherObjects(omp::Object obj,
semantics::SemanticsContext &semaCtx);

void addChildIndexAndMapToParent(const omp::Object &object,
OmpMapMemberIndicesData &parentMemberIndices,
OmpMapParentAndMemberData &parentMemberIndices,
mlir::omp::MapInfoOp &mapOp,
semantics::SemanticsContext &semaCtx);

void insertChildMapInfoIntoParent(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
std::map<omp::Object, OmpMapMemberIndicesData> &parentMemberIndices,
Fortran::lower::StatementContext &stmtCtx,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
Expand Down
24 changes: 13 additions & 11 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,12 +889,13 @@ void ClauseProcessor::processMapObjects(
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<omp::Object, OmpMapMemberIndicesData> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const omp::Object &object : objects) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Expand All @@ -913,7 +914,10 @@ void ClauseProcessor::processMapObjects(
assert(!objectList.empty() &&
"could not find parent objects of derived type member");
parentObj = objectList[0];
parentMemberIndices.emplace(parentObj.value(), OmpMapMemberIndicesData{});
auto insert = parentMemberIndices.emplace(parentObj.value(),
OmpMapParentAndMemberData{});
insert.first->second.parentObjList.push_back(parentObj.value());

if (isMemberOrParentAllocatableOrPointer(object, semaCtx)) {
llvm::SmallVector<int64_t> indices;
generateMemberPlacementIndices(object, indices, semaCtx);
Expand Down Expand Up @@ -965,7 +969,7 @@ bool ClauseProcessor::processMap(
llvm::SmallVector<const semantics::Symbol *> localMapSyms;
llvm::SmallVectorImpl<const semantics::Symbol *> *ptrMapSyms =
mapSyms ? mapSyms : &localMapSyms;
std::map<omp::Object, OmpMapMemberIndicesData> parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause, const parser::CharBlock &source) {
Expand Down Expand Up @@ -1018,7 +1022,7 @@ bool ClauseProcessor::processMap(
mapSymLocs, mapSymTypes);
});

insertChildMapInfoIntoParent(converter, semaCtx, parentMemberIndices,
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.mapVars, mapSymTypes, mapSymLocs,
ptrMapSyms);
return clauseFound;
Expand Down Expand Up @@ -1082,7 +1086,7 @@ bool ClauseProcessor::processUseDeviceAddr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<omp::Object, OmpMapMemberIndicesData> parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const parser::CharBlock &source) {
Expand All @@ -1095,9 +1099,8 @@ bool ClauseProcessor::processUseDeviceAddr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});


insertChildMapInfoIntoParent(converter, semaCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDeviceAddrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
}
Expand All @@ -1107,7 +1110,7 @@ bool ClauseProcessor::processUseDevicePtr(
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
std::map<omp::Object, OmpMapMemberIndicesData> parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;

bool clauseFound = findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
Expand All @@ -1121,8 +1124,7 @@ bool ClauseProcessor::processUseDevicePtr(
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
});


insertChildMapInfoIntoParent(converter, semaCtx, parentMemberIndices,
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.useDevicePtrVars, &useDeviceTypes,
&useDeviceLocs, &useDeviceSyms);
return clauseFound;
Expand Down
6 changes: 3 additions & 3 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class ClauseProcessor {
lower::StatementContext &stmtCtx, mlir::Location clauseLocation,
const omp::ObjectList &objects,
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
std::map<omp::Object, OmpMapMemberIndicesData> &parentMemberIndices,
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
llvm::SmallVectorImpl<mlir::Value> &mapVars,
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
Expand All @@ -192,7 +192,7 @@ class ClauseProcessor {
template <typename T>
bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
mlir::omp::MapClauseOps &result) {
std::map<omp::Object, OmpMapMemberIndicesData> parentMemberIndices;
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
llvm::SmallVector<const semantics::Symbol *> mapSymbols;

bool clauseFound = findRepeatableClause<T>(
Expand All @@ -213,7 +213,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
});

insertChildMapInfoIntoParent(
converter, semaCtx, parentMemberIndices, result.mapVars,
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
return clauseFound;
}
Expand Down
Loading

0 comments on commit 5d62d62

Please sign in to comment.