From 8d970ac00a016cf3ee30825ece8c118a13ee7d4b Mon Sep 17 00:00:00 2001 From: Dominik Adamski Date: Tue, 16 Apr 2024 02:25:46 -0500 Subject: [PATCH] [WIP] Use device ptr implementation Patch cherry-picked into ATD branch Co-authored-by: Raghu Maddhipatla --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 45 +++++++++++++++++-- flang/lib/Lower/OpenMP/ClauseProcessor.h | 3 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 5 +-- .../Transforms/OMPDescriptorMapInfoGen.cpp | 31 ++++++++----- flang/test/Lower/OpenMP/FIR/target.f90 | 8 ++-- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 10 ++++- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 26 ++++++++--- 7 files changed, 99 insertions(+), 29 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 80d0199610ffff..7b5a743609979c 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1001,7 +1001,8 @@ bool ClauseProcessor::processEnter( } bool ClauseProcessor::processUseDeviceAddr( - mlir::omp::UseDeviceClauseOps &result, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl &useDeviceSyms) @@ -1009,8 +1010,46 @@ bool ClauseProcessor::processUseDeviceAddr( return findRepeatableClause( [&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars, - useDeviceTypes, useDeviceLocs, useDeviceSyms); + // addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, + // useDeviceLocs, useDeviceSymbols); + const Fortran::parser::CharBlock source; + mlir::Location clauseLocation = converter.genLocation(source); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + for (const omp::Object &object : clause.v) { + llvm::SmallVector bounds; + std::stringstream asFortran; + + Fortran::lower::AddrAndBoundsInfo info = + Fortran::lower::gatherDataOperandAddrAndBounds< + mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( + converter, firOpBuilder, semaCtx, stmtCtx, *object.id(), + object.ref(), clauseLocation, asFortran, bounds, + treatIndexAsSection); + + auto origSymbol = converter.getSymbolAddress(*object.id()); + mlir::Value symAddr = info.addr; + if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) + symAddr = origSymbol; + + // Explicit map captures are captured ByRef by default, + // optimisation passes may alter this to ByCopy or other capture + // types to optimise + mlir::Value mapOp = createMapInfoOp( + firOpBuilder, clauseLocation, symAddr, mlir::Value{}, + asFortran.str(), bounds, {}, + static_cast< + std::underlying_type_t>( + mapTypeBits), + mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); + + useDeviceSyms.push_back(object.id()); + useDeviceTypes.push_back(symAddr.getType()); + useDeviceLocs.push_back(symAddr.getLoc()); + operands.push_back(mapOp); + } }); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 5ddf1f2e64a9fe..553c355a3bc029 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -133,7 +133,8 @@ class ClauseProcessor { const; bool processTo(llvm::SmallVectorImpl &result) const; bool - processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result, + processUseDeviceAddr(Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &useDeviceTypes, llvm::SmallVectorImpl &useDeviceLocs, llvm::SmallVectorImpl diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index d107c48a6422f7..21a65397cc5f18 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1210,15 +1210,14 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector useDeviceTypes; llvm::SmallVector useDeviceLocs; llvm::SmallVector useDeviceSyms; - + llvm::SmallVector deviceAddrOperands; ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); - cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, + cp.processUseDeviceAddr(stmtCtx, deviceAddrOperands, useDeviceTypes, useDeviceLocs, useDeviceSyms); - // This function implements the deprecated functionality of use_device_ptr // that allows users to provide non-CPTR arguments to it with the caveat // that the compiler will treat them as use_device_addr. A lot of legacy diff --git a/flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp b/flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp index 6ffcf0746c76fc..321e8b67717546 100644 --- a/flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp +++ b/flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp @@ -92,13 +92,10 @@ class OMPDescriptorMapInfoGenPass // TODO: map the addendum segment of the descriptor, similarly to the // above base address/data pointer member. - if (auto mapClauseOwner = - llvm::dyn_cast(target)) { + auto addOperands = [&](mlir::OperandRange &operandsArr, mlir::MutableOperandRange &mutableOpRange, auto directiveOp) { llvm::SmallVector newMapOps; - mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapOperands(); - - for (size_t i = 0; i < mapOperandsArr.size(); ++i) { - if (mapOperandsArr[i] == op) { + for (size_t i = 0; i < operandsArr.size(); ++i) { + if (operandsArr[i] == op) { // Push new implicit maps generated for the descriptor. newMapOps.push_back(baseAddr); @@ -107,13 +104,25 @@ class OMPDescriptorMapInfoGenPass // as the printing and later processing currently requires a 1:1 // mapping of BlockArgs to MapInfoOp's at the same placement in // each array (BlockArgs and MapOperands). - if (auto targetOp = llvm::dyn_cast(target)) - targetOp.getRegion().insertArgument(i, baseAddr.getType(), loc); + if (directiveOp) { + directiveOp.getRegion().insertArgument(i, baseAddr.getType(), loc); + } } - newMapOps.push_back(mapOperandsArr[i]); + newMapOps.push_back(operandsArr[i]); + } + mutableOpRange.assign(newMapOps); + }; + if(auto mapClauseOwner = llvm::dyn_cast(target)){ + mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapOperands(); + mlir::MutableOperandRange mapMutableOpRange = mapClauseOwner.getMapOperandsMutable(); + mlir::omp::TargetOp targetOp = llvm::dyn_cast(target); + addOperands(mapOperandsArr, mapMutableOpRange, targetOp); + } + if(auto targetDataOp = llvm::dyn_cast(target)) { + mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddr(); + mlir::MutableOperandRange useDevAddrMutableOpRange = targetDataOp.getUseDeviceAddrMutable(); + addOperands(useDevAddrArr, useDevAddrMutableOpRange, targetDataOp); } - mapClauseOwner.getMapOperandsMutable().assign(newMapOps); - } mlir::Value newDescParentMapOp = builder.create( op->getLoc(), op.getResult().getType(), descriptor, diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90 index c36dad5f4ff30e..8ab1cd39437a32 100644 --- a/flang/test/Lower/OpenMP/FIR/target.f90 +++ b/flang/test/Lower/OpenMP/FIR/target.f90 @@ -450,12 +450,14 @@ end subroutine omp_target_device_ptr subroutine omp_target_device_addr integer, pointer :: a !CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box> {bindc_name = "a", uniq_name = "_QFomp_target_device_addrEa"} + !CHECK: %[[USE_DEVICE_ADDR_MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr> {name = ""} + !CHECK: %[[USE_DEVICE_ADDR_MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(tofrom) capture(ByRef) members(%[[USE_DEVICE_ADDR_MAP_MEMBERS]] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} !CHECK: %[[MAP_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr> {name = ""} !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} - !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[VAL_0]] : !fir.ref>>) { + !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[USE_DEVICE_ADDR_MAP_MEMBERS]], %[[USE_DEVICE_ADDR_MAP]] : {{.*}}) { !$omp target data map(tofrom: a) use_device_addr(a) - !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref>>): - !CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref>> + !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.llvm_ptr>): + !CHECK: {{.*}} = fir.load %[[VAL_0]] : !fir.ref>> a = 10 !CHECK: omp.terminator !$omp end target data diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index e8e1ad435e6cec..9039763e20af7c 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6946,9 +6946,17 @@ void OpenMPIRBuilder::emitOffloadingArrays( Value *BP = Builder.CreateConstInBoundsGEP2_32( ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray, 0, I); + // AMD SEGFAULT HERE + if(!BP){ + printf("BP is NULL \n"); + } + if(!BPVal){ + // SEGFAULT is because of NULL here + printf("BPVal is NULL \n"); + } Builder.CreateAlignedStore(BPVal, BP, M.getDataLayout().getPrefTypeAlign(PtrTy)); - + // AMD END SEGFAULT HERE if (Info.requiresDevicePointerInfo()) { if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) { CodeGenIP = Builder.saveIP(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index e1107f55b82c39..93631425d72068 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2511,7 +2511,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, const SmallVector &devPtrOperands = {}, - const SmallVector &devAddrOperands = {}, + MapInfoData useDeviceAddrData = {}, bool isTargetParams = false) { // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can @@ -2622,7 +2622,18 @@ static void genMapInfos(llvm::IRBuilderBase &builder, }; addDevInfos(devPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer); - addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address); + // addDevInfos(devAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address); + + for (size_t i = 0; i < useDeviceAddrData.MapClause.size(); ++i) { + auto mapFlag = useDeviceAddrData.Types[i]; + combinedInfo.BasePointers.emplace_back(useDeviceAddrData.BasePointers[i]); + combinedInfo.Pointers.emplace_back(useDeviceAddrData.Pointers[i]); + combinedInfo.DevicePointers.emplace_back(llvm::OpenMPIRBuilder::DeviceInfoTy::Address); + combinedInfo.Names.emplace_back(useDeviceAddrData.Names[i]); + combinedInfo.Types.emplace_back( + mapFlag | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM); + combinedInfo.Sizes.emplace_back(builder.getInt64(0)); + } } static LogicalResult @@ -2719,8 +2730,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; MapInfoData mapData; + MapInfoData useDeviceAddrData; collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, DL, builder); + collectMapDataFromMapOperands(useDeviceAddrData, useDevAddrOperands, moduleTranslation, DL, + builder); // Fill up the arrays with all the mapped variables. llvm::OpenMPIRBuilder::MapInfosTy combinedInfo; @@ -2729,7 +2743,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, builder.restoreIP(codeGenIP); if (auto dataOp = dyn_cast(op)) { genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData, - useDevPtrOperands, useDevAddrOperands); + useDevPtrOperands, useDeviceAddrData); } else { genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData); } @@ -2758,12 +2772,10 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, info.DevicePtrInfoMap[mapOpValue].second); argIndex++; } - - for (auto &devAddrOp : useDevAddrOperands) { - llvm::Value *mapOpValue = moduleTranslation.lookupValue(devAddrOp); + for (size_t i = 0; i < useDeviceAddrData.MapClause.size(); ++i) { const auto &arg = region.front().getArgument(argIndex); auto *LI = builder.CreateLoad( - builder.getPtrTy(), info.DevicePtrInfoMap[mapOpValue].second); + builder.getPtrTy(), info.DevicePtrInfoMap[useDeviceAddrData.OriginalValue[i]].second); moduleTranslation.mapValue(arg, LI); argIndex++; }