Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use device ptr implementation #60

Open
wants to merge 1 commit into
base: amd-trunk-dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,16 +1001,55 @@ bool ClauseProcessor::processEnter(
}

bool ClauseProcessor::processUseDeviceAddr(
mlir::omp::UseDeviceClauseOps &result,
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
useDeviceTypes, useDeviceLocs, useDeviceSyms);
// addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
// useDeviceLocs, useDeviceSymbols);
const Fortran::parser::CharBlock source;
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
for (const omp::Object &object : clause.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;

Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

auto origSymbol = converter.getSymbolAddress(*object.id());
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;

// Explicit map captures are captured ByRef by default,
// optimisation passes may alter this to ByCopy or other capture
// types to optimise
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
asFortran.str(), bounds, {},
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());

useDeviceSyms.push_back(object.id());
useDeviceTypes.push_back(symAddr.getType());
useDeviceLocs.push_back(symAddr.getLoc());
operands.push_back(mapOp);
}
});
}

Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ class ClauseProcessor {
const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result,
processUseDeviceAddr(Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
Expand Down
5 changes: 2 additions & 3 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1210,15 +1210,14 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;

llvm::SmallVector<mlir::Value> deviceAddrOperands;
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs,
useDeviceSyms);
cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
cp.processUseDeviceAddr(stmtCtx, deviceAddrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSyms);

// This function implements the deprecated functionality of use_device_ptr
// that allows users to provide non-CPTR arguments to it with the caveat
// that the compiler will treat them as use_device_addr. A lot of legacy
Expand Down
31 changes: 20 additions & 11 deletions flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,10 @@ class OMPDescriptorMapInfoGenPass
// TODO: map the addendum segment of the descriptor, similarly to the
// above base address/data pointer member.

if (auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
auto addOperands = [&](mlir::OperandRange &operandsArr, mlir::MutableOperandRange &mutableOpRange, auto directiveOp) {
llvm::SmallVector<mlir::Value> newMapOps;
mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapOperands();

for (size_t i = 0; i < mapOperandsArr.size(); ++i) {
if (mapOperandsArr[i] == op) {
for (size_t i = 0; i < operandsArr.size(); ++i) {
if (operandsArr[i] == op) {
// Push new implicit maps generated for the descriptor.
newMapOps.push_back(baseAddr);

Expand All @@ -107,13 +104,25 @@ class OMPDescriptorMapInfoGenPass
// as the printing and later processing currently requires a 1:1
// mapping of BlockArgs to MapInfoOp's at the same placement in
// each array (BlockArgs and MapOperands).
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target))
targetOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
if (directiveOp) {
directiveOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
}
}
newMapOps.push_back(mapOperandsArr[i]);
newMapOps.push_back(operandsArr[i]);
}
mutableOpRange.assign(newMapOps);
};
if(auto mapClauseOwner = llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)){
mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapOperands();
mlir::MutableOperandRange mapMutableOpRange = mapClauseOwner.getMapOperandsMutable();
mlir::omp::TargetOp targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target);
addOperands(mapOperandsArr, mapMutableOpRange, targetOp);
}
if(auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddr();
mlir::MutableOperandRange useDevAddrMutableOpRange = targetDataOp.getUseDeviceAddrMutable();
addOperands(useDevAddrArr, useDevAddrMutableOpRange, targetDataOp);
}
mapClauseOwner.getMapOperandsMutable().assign(newMapOps);
}

mlir::Value newDescParentMapOp = builder.create<mlir::omp::MapInfoOp>(
op->getLoc(), op.getResult().getType(), descriptor,
Expand Down
8 changes: 5 additions & 3 deletions flang/test/Lower/OpenMP/FIR/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -450,12 +450,14 @@ end subroutine omp_target_device_ptr
subroutine omp_target_device_addr
integer, pointer :: a
!CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "a", uniq_name = "_QFomp_target_device_addrEa"}
!CHECK: %[[USE_DEVICE_ADDR_MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[USE_DEVICE_ADDR_MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[USE_DEVICE_ADDR_MAP_MEMBERS]] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[USE_DEVICE_ADDR_MAP_MEMBERS]], %[[USE_DEVICE_ADDR_MAP]] : {{.*}}) {
!$omp target data map(tofrom: a) use_device_addr(a)
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
!CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.llvm_ptr<!fir.ref<i32>>):
!CHECK: {{.*}} = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
a = 10
!CHECK: omp.terminator
!$omp end target data
Expand Down
10 changes: 9 additions & 1 deletion llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6946,9 +6946,17 @@ void OpenMPIRBuilder::emitOffloadingArrays(
Value *BP = Builder.CreateConstInBoundsGEP2_32(
ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
0, I);
// AMD SEGFAULT HERE
if(!BP){
printf("BP is NULL \n");
}
if(!BPVal){
// SEGFAULT is because of NULL here
printf("BPVal is NULL \n");
}
Builder.CreateAlignedStore(BPVal, BP,
M.getDataLayout().getPrefTypeAlign(PtrTy));

// AMD END SEGFAULT HERE
if (Info.requiresDevicePointerInfo()) {
if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
CodeGenIP = Builder.saveIP();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2511,7 +2511,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
MapInfoData &mapData,
const SmallVector<Value> &devPtrOperands = {},
const SmallVector<Value> &devAddrOperands = {},
MapInfoData useDeviceAddrData = {},
bool isTargetParams = false) {
// We wish to modify some of the methods in which arguments are
// passed based on their capture type by the target region, this can
Expand Down Expand Up @@ -2622,7 +2622,18 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
};

addDevInfos(devPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
// addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);

for (size_t i = 0; i < useDeviceAddrData.MapClause.size(); ++i) {
auto mapFlag = useDeviceAddrData.Types[i];
combinedInfo.BasePointers.emplace_back(useDeviceAddrData.BasePointers[i]);
combinedInfo.Pointers.emplace_back(useDeviceAddrData.Pointers[i]);
combinedInfo.DevicePointers.emplace_back(llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
combinedInfo.Names.emplace_back(useDeviceAddrData.Names[i]);
combinedInfo.Types.emplace_back(
mapFlag | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
combinedInfo.Sizes.emplace_back(builder.getInt64(0));
}
}

static LogicalResult
Expand Down Expand Up @@ -2719,8 +2730,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;

MapInfoData mapData;
MapInfoData useDeviceAddrData;
collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, DL,
builder);
collectMapDataFromMapOperands(useDeviceAddrData, useDevAddrOperands, moduleTranslation, DL,
builder);

// Fill up the arrays with all the mapped variables.
llvm::OpenMPIRBuilder::MapInfosTy combinedInfo;
Expand All @@ -2729,7 +2743,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
builder.restoreIP(codeGenIP);
if (auto dataOp = dyn_cast<omp::TargetDataOp>(op)) {
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData,
useDevPtrOperands, useDevAddrOperands);
useDevPtrOperands, useDeviceAddrData);
} else {
genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData);
}
Expand Down Expand Up @@ -2758,12 +2772,10 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
info.DevicePtrInfoMap[mapOpValue].second);
argIndex++;
}

for (auto &devAddrOp : useDevAddrOperands) {
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp);
for (size_t i = 0; i < useDeviceAddrData.MapClause.size(); ++i) {
const auto &arg = region.front().getArgument(argIndex);
auto *LI = builder.CreateLoad(
builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second);
builder.getPtrTy(), info.DevicePtrInfoMap[useDeviceAddrData.OriginalValue[i]].second);
moduleTranslation.mapValue(arg, LI);
argIndex++;
}
Expand Down