diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index cbbd37fe0e06dc..3f4d721b431906 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -409,6 +409,7 @@ static LogicalResult inlineConvertOmpRegions( // Special case for single-block regions that don't create additional blocks: // insert operations without creating additional blocks. if (llvm::hasSingleElement(region)) { + llvm::Instruction *potentialTerminator = builder.GetInsertBlock()->empty() ? nullptr : &builder.GetInsertBlock()->back(); @@ -2515,8 +2516,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, argIndex++; } - bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region", - builder, moduleTranslation); + SmallVector phis; + llvm::BasicBlock *continuationBlock = + convertOmpOpRegions(region, "omp.data.region", builder, + moduleTranslation, bodyGenStatus, &phis); + builder.SetInsertPoint(continuationBlock, + continuationBlock->getFirstInsertionPt()); } break; case BodyGenTy::DupNoPriv: @@ -2525,8 +2530,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, // If device info is available then region has already been generated if (info.DevicePtrInfoMap.empty()) { builder.restoreIP(codeGenIP); - bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region", - builder, moduleTranslation); + SmallVector phis; + llvm::BasicBlock *continuationBlock = + convertOmpOpRegions(region, "omp.data.region", builder, + moduleTranslation, bodyGenStatus, &phis); + builder.SetInsertPoint(continuationBlock, + continuationBlock->getFirstInsertionPt()); } break; } @@ -3543,6 +3552,8 @@ convertTopLevelTargetOp(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { if (isa(op)) return convertOmpTarget(*op, builder, moduleTranslation); + if (isa(op)) + return convertOmpTargetData(op, builder, moduleTranslation); bool interrupted = op->walk([&](omp::TargetOp targetOp) { if (failed(convertOmpTarget(*targetOp, builder, moduleTranslation)))