diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 66cc1b131a7254..ab02a46f433cdf 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2871,6 +2871,7 @@ class OpenMPIRBuilder { /// \param Loc where the target data construct was encountered. /// \param IsSPMD whether this is an SPMD target launch. /// \param IsOffloadEntry whether it is an offload entry. + /// \param IfCond value of the IF clause for the TARGET construct or nullptr. /// \param CodeGenIP The insertion point where the call to the outlined /// function should be emitted. /// \param EntryInfo The entry information about the function. @@ -2884,7 +2885,7 @@ class OpenMPIRBuilder { /// \param Dependencies A vector of DependData objects that carry // dependency information as passed in the depend clause. InsertPointTy createTarget(const LocationDescription &Loc, bool IsSPMD, - bool IsOffloadEntry, + bool IsOffloadEntry, Value *IfCond, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 1661ea5bcb5295..a31210a4d35c37 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -7266,7 +7266,7 @@ static void emitTargetCall( const OpenMPIRBuilder::TargetKernelDefaultBounds &DefaultBounds, const OpenMPIRBuilder::TargetKernelRuntimeBounds &RuntimeBounds, Function *OutlinedFn, Constant *OutlinedFnID, - SmallVectorImpl &Args, + SmallVectorImpl &Args, Value *IfCond, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, SmallVector Dependencies = {}) { // Generate a function call to the host fallback implementation of the target @@ -7283,9 +7283,7 @@ static void emitTargetCall( bool HasDependencies = Dependencies.size() > 0; bool RequiresOuterTargetTask = HasNoWait || HasDependencies; - // If we don't have an ID for the target region, it means an offload entry - // wasn't created. In this case we just run the host fallback directly. - if (!OutlinedFnID) { + auto &&EmitTargetCallElse = [&]() { if (RequiresOuterTargetTask) { // Arguments that are intended to be directly forwarded to an // emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr @@ -7298,96 +7296,132 @@ static void emitTargetCall( } else { Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP())); } - return; - } - - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); - - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); - OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, - RTArgs, MapInfo, - /*IsNonContiguous=*/true, - /*ForEndCall=*/false); - - SmallVector NumTeamsC; - for (auto [DefNumTeams, RtNumTeams] : - llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) { - NumTeamsC.push_back(RtNumTeams ? RtNumTeams - : Builder.getInt32(DefNumTeams)); - } - - // Calculate number of threads: 0 if no clauses specified, otherwise it is the - // minimum between optional THREAD_LIMIT and MAX_THREADS clauses. Perform a - // type cast to uint32. - auto InitMaxThreadsClause = [&Builder](Value *Clause) { - if (Clause) - Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), - /*isSigned=*/false); - return Clause; }; - auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { - if (Clause) - Result = Result - ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), + auto &&EmitTargetCallThen = [&]() { + OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); + OpenMPIRBuilder::TargetDataRTArgs RTArgs; + OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, + RTArgs, MapInfo, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false); + + SmallVector NumTeamsC; + for (auto [DefNumTeams, RtNumTeams] : + llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) { + NumTeamsC.push_back(RtNumTeams ? RtNumTeams + : Builder.getInt32(DefNumTeams)); + } + + // Calculate number of threads: 0 if no clauses specified, otherwise it is + // the minimum between optional THREAD_LIMIT and MAX_THREADS clauses. + // Perform a type cast to uint32. + auto InitMaxThreadsClause = [&Builder](Value *Clause) { + if (Clause) + Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), + /*isSigned=*/false); + return Clause; + }; + + auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { + if (Clause) + Result = + Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), Result, Clause) : Clause; - }; + }; + + // TODO: Check if this is the correct handling for multi-dim thread_limit. + SmallVector NumThreadsC; + Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads); - // TODO: Check if this is the correct handling for multi-dim thread_limit. - SmallVector NumThreadsC; - Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads); + for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal( + RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) { + Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit); + Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit); - for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal( - RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) { - Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit); - Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit); + CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); + CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); - CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); - CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); + NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0)); + } - NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0)); + unsigned NumTargetItems = Info.NumberOfPtrs; + // TODO: Use correct device ID + Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); + uint32_t SrcLocStrSize; + Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); + Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, + llvm::omp::IdentFlag(0), 0); + + Value *TripCount = RuntimeBounds.LoopTripCount + ? Builder.CreateIntCast(RuntimeBounds.LoopTripCount, + Builder.getInt64Ty(), + /*isSigned=*/false) + : Builder.getInt64(0); + + // TODO: Use correct DynCGGroupMem + Value *DynCGGroupMem = Builder.getInt32(0); + OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount, + NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); + + // The presence of certain clauses on the target directive require the + // explicit generation of the target task. + if (RequiresOuterTargetTask) { + Builder.restoreIP(OMPBuilder.emitTargetTask( + OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, + RTLoc, AllocaIP, Dependencies, HasNoWait)); + } else { + Builder.restoreIP(OMPBuilder.emitKernelLaunch( + Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, + DeviceID, RTLoc, AllocaIP)); + } + }; + + // If we don't have an ID for the target region, it means an offload entry + // wasn't created. In this case we just run the host fallback directly. + if (!OutlinedFnID) { + EmitTargetCallElse(); + return; } - unsigned NumTargetItems = Info.NumberOfPtrs; - // TODO: Use correct device ID - Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF); - uint32_t SrcLocStrSize; - Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); - Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, - llvm::omp::IdentFlag(0), 0); - - Value *TripCount = RuntimeBounds.LoopTripCount - ? Builder.CreateIntCast(RuntimeBounds.LoopTripCount, - Builder.getInt64Ty(), - /*isSigned=*/false) - : Builder.getInt64(0); - - // TODO: Use correct DynCGGroupMem - Value *DynCGGroupMem = Builder.getInt32(0); - OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount, - NumTeamsC, NumThreadsC, DynCGGroupMem, - HasNoWait); - - // The presence of certain clauses on the target directive require the - // explicit generation of the target task. - if (RequiresOuterTargetTask) { - Builder.restoreIP(OMPBuilder.emitTargetTask( - OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID, - RTLoc, AllocaIP, Dependencies, HasNoWait)); - } else { - Builder.restoreIP(OMPBuilder.emitKernelLaunch( - Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, - DeviceID, RTLoc, AllocaIP)); + // If there's no IF clause, only generate the kernel launch code path. + if (!IfCond) { + EmitTargetCallThen(); + return; } + + // Create if-else to handle IF clause. + llvm::BasicBlock *ThenBlock = + BasicBlock::Create(Builder.getContext(), "omp_if.then"); + llvm::BasicBlock *ElseBlock = + BasicBlock::Create(Builder.getContext(), "omp_if.else"); + llvm::BasicBlock *ContBlock = + BasicBlock::Create(Builder.getContext(), "omp_if.end"); + Builder.CreateCondBr(IfCond, ThenBlock, ElseBlock); + + Function *CurFn = Builder.GetInsertBlock()->getParent(); + + // Emit the 'then' code. + OMPBuilder.emitBlock(ThenBlock, CurFn); + EmitTargetCallThen(); + OMPBuilder.emitBranch(ContBlock); + // Emit the 'else' code. + OMPBuilder.emitBlock(ElseBlock, CurFn); + EmitTargetCallElse(); + OMPBuilder.emitBranch(ContBlock); + // Emit the continuation block. + OMPBuilder.emitBlock(ContBlock, CurFn, /*IsFinished=*/true); } OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsSPMD, bool IsOffloadEntry, - InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + Value *IfCond, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultBounds &DefaultBounds, const TargetKernelRuntimeBounds &RuntimeBounds, @@ -7415,7 +7449,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget( // that represents the target region. Do that now. if (!Config.isTargetDevice()) emitTargetCall(*this, Builder, AllocaIP, DefaultBounds, RuntimeBounds, - OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies); + OutlinedFn, OutlinedFnID, Args, IfCond, GenMapInfoCB, + Dependencies); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 1f805eb85da539..a5584e6a52d8ba 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6017,9 +6017,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { RuntimeBounds.TeamsThreadLimit.push_back(nullptr); RuntimeBounds.MaxTeams.push_back(nullptr); Builder.restoreIP(OMPBuilder.createTarget( - OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultBounds, RuntimeBounds, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr, + Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultBounds, + RuntimeBounds, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); OMPBuilder.finalize(); Builder.CreateRetVoid(); @@ -6134,9 +6134,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { RuntimeBounds.TeamsThreadLimit.push_back(nullptr); RuntimeBounds.MaxTeams.push_back(nullptr); Builder.restoreIP(OMPBuilder.createTarget( - Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr, + EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -6290,9 +6290,9 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { RuntimeBounds.TeamsThreadLimit.push_back(nullptr); RuntimeBounds.MaxTeams.push_back(nullptr); Builder.restoreIP(OMPBuilder.createTarget( - Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr, + EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); Builder.CreateRetVoid(); OMPBuilder.finalize(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index ae72894a2f9e31..7a3fb52aa4f612 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3476,10 +3476,6 @@ static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo, static bool targetOpSupported(Operation &opInst) { auto targetOp = cast(opInst); - if (targetOp.getIfExpr()) { - opInst.emitError("If clause not yet supported"); - return false; - } if (targetOp.getDevice()) { opInst.emitError("Device clause not yet supported"); @@ -3955,8 +3951,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, if (Value targetThreadLimit = targetOp.getThreadLimit()) llvmTargetThreadLimit = moduleTranslation.lookupValue(targetThreadLimit); + llvm::Value *ifCond = nullptr; + if (Value targetIfCond = targetOp.getIfExpr()) + ifCond = moduleTranslation.lookupValue(targetIfCond); + builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, allocaIP, + ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, ifCond, allocaIP, builder.saveIP(), entryInfo, defaultBounds, runtimeBounds, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds));