Skip to content

Commit

Permalink
[MLIR][OpenMP][OMPIRBuilder] Support omp.target IF translation to LLVM IR
Browse files Browse the repository at this point in the history

This patch adds missing MLIR to LLVM IR translation support for the `if` clause
on `omp.target` operations. This completes the missing piece for Fortran
support of `!$omp target if(...)`.

The implementation updates `emitTargetCall()` in the OMPIRBuilder to follow
clang's support for the `if` clause in `CGOpenMPRuntime::emitTargetCall()`.
  • Loading branch information
skatrak committed Sep 17, 2024
1 parent 57548b0 commit 90560e7
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 95 deletions.
3 changes: 2 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2871,6 +2871,7 @@ class OpenMPIRBuilder {
/// \param Loc where the target data construct was encountered.
/// \param IsSPMD whether this is an SPMD target launch.
/// \param IsOffloadEntry whether it is an offload entry.
/// \param IfCond value of the IF clause for the TARGET construct or nullptr.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
Expand All @@ -2884,7 +2885,7 @@ class OpenMPIRBuilder {
/// \param Dependencies A vector of DependData objects that carry
/// dependency information as passed in the depend clause.
InsertPointTy createTarget(const LocationDescription &Loc, bool IsSPMD,
bool IsOffloadEntry,
bool IsOffloadEntry, Value *IfCond,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
Expand Down
195 changes: 115 additions & 80 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7266,7 +7266,7 @@ static void emitTargetCall(
const OpenMPIRBuilder::TargetKernelDefaultBounds &DefaultBounds,
const OpenMPIRBuilder::TargetKernelRuntimeBounds &RuntimeBounds,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
SmallVectorImpl<Value *> &Args, Value *IfCond,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
// Generate a function call to the host fallback implementation of the target
Expand All @@ -7283,9 +7283,7 @@ static void emitTargetCall(
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
auto &&EmitTargetCallElse = [&]() {
if (RequiresOuterTargetTask) {
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are passed as nullptr, since OutlinedFnID=nullptr
Expand All @@ -7298,96 +7296,132 @@ static void emitTargetCall(
} else {
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
return;
}

OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefNumTeams, RtNumTeams] :
llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
NumTeamsC.push_back(RtNumTeams ? RtNumTeams
: Builder.getInt32(DefNumTeams));
}

// Calculate number of threads: 0 if no clauses specified, otherwise it is the
// minimum between optional THREAD_LIMIT and MAX_THREADS clauses. Perform a
// type cast to uint32.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
};

auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result = Result
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
auto &&EmitTargetCallThen = [&]() {
OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefNumTeams, RtNumTeams] :
llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
NumTeamsC.push_back(RtNumTeams ? RtNumTeams
: Builder.getInt32(DefNumTeams));
}

// Calculate number of threads: 0 if no clauses specified, otherwise it is
// the minimum between optional THREAD_LIMIT and MAX_THREADS clauses.
// Perform a type cast to uint32.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
};

auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result =
Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
Result, Clause)
: Clause;
};
};

// TODO: Check if this is the correct handling for multi-dim thread_limit.
SmallVector<Value *, 3> NumThreadsC;
Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);

// TODO: Check if this is the correct handling for multi-dim thread_limit.
SmallVector<Value *, 3> NumThreadsC;
Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);
for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);

for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
}

NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);

Value *TripCount = RuntimeBounds.LoopTripCount
? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);

// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
}
};

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
EmitTargetCallElse();
return;
}

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);

Value *TripCount = RuntimeBounds.LoopTripCount
? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);

// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC, DynCGGroupMem,
HasNoWait);

// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
// If there's no IF clause, only generate the kernel launch code path.
if (!IfCond) {
EmitTargetCallThen();
return;
}

// Create if-else to handle IF clause.
llvm::BasicBlock *ThenBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.then");
llvm::BasicBlock *ElseBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.else");
llvm::BasicBlock *ContBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.end");
Builder.CreateCondBr(IfCond, ThenBlock, ElseBlock);

Function *CurFn = Builder.GetInsertBlock()->getParent();

// Emit the 'then' code.
OMPBuilder.emitBlock(ThenBlock, CurFn);
EmitTargetCallThen();
OMPBuilder.emitBranch(ContBlock);
// Emit the 'else' code.
OMPBuilder.emitBlock(ElseBlock, CurFn);
EmitTargetCallElse();
OMPBuilder.emitBranch(ContBlock);
// Emit the continuation block.
OMPBuilder.emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsSPMD, bool IsOffloadEntry,
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
Value *IfCond, InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultBounds &DefaultBounds,
const TargetKernelRuntimeBounds &RuntimeBounds,
Expand Down Expand Up @@ -7415,7 +7449,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, DefaultBounds, RuntimeBounds,
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies);
OutlinedFn, OutlinedFnID, Args, IfCond, GenMapInfoCB,
Dependencies);
return Builder.saveIP();
}

Expand Down
18 changes: 9 additions & 9 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6017,9 +6017,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
RuntimeBounds.MaxTeams.push_back(nullptr);
Builder.restoreIP(OMPBuilder.createTarget(
OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultBounds, RuntimeBounds, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultBounds,
RuntimeBounds, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
OMPBuilder.finalize();
Builder.CreateRetVoid();

Expand Down Expand Up @@ -6134,9 +6134,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
RuntimeBounds.MaxTeams.push_back(nullptr);
Builder.restoreIP(OMPBuilder.createTarget(
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -6290,9 +6290,9 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
RuntimeBounds.MaxTeams.push_back(nullptr);
Builder.restoreIP(OMPBuilder.createTarget(
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3476,10 +3476,6 @@ static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,

static bool targetOpSupported(Operation &opInst) {
auto targetOp = cast<omp::TargetOp>(opInst);
if (targetOp.getIfExpr()) {
opInst.emitError("If clause not yet supported");
return false;
}

if (targetOp.getDevice()) {
opInst.emitError("Device clause not yet supported");
Expand Down Expand Up @@ -3955,8 +3951,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (Value targetThreadLimit = targetOp.getThreadLimit())
llvmTargetThreadLimit = moduleTranslation.lookupValue(targetThreadLimit);

llvm::Value *ifCond = nullptr;
if (Value targetIfCond = targetOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(targetIfCond);

builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, allocaIP,
ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, ifCond, allocaIP,
builder.saveIP(), entryInfo, defaultBounds, runtimeBounds, kernelInput,
genMapInfoCB, bodyCB, argAccessorCB, dds));

Expand Down

0 comments on commit 90560e7

Please sign in to comment.