Skip to content

Commit

Permalink
Fix handling of call instruction
Browse files Browse the repository at this point in the history
  • Loading branch information
fineg74 committed Jul 14, 2024
1 parent a246761 commit 6dcd8c7
Showing 1 changed file with 51 additions and 35 deletions.
86 changes: 51 additions & 35 deletions llvm/lib/SYCLLowerIR/LowerInvokeSimd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,49 @@ void markFunctionAsESIMD(Function *F) {
}
}

void AdjustAddressSpace(Function *F, uint32_t ArgNo, uint32_t ArgAddrSpace) {
Argument *Arg = F->getArg(ArgNo);
for (User *ArgUse : Arg->users()) {
Instruction *Instr = dyn_cast<Instruction>(ArgUse);
if (!Instr)
continue;
const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(ArgUse);
if (ASC) {
if (ASC->getDestAddressSpace() == ArgAddrSpace)
continue;
}

const CallInst *CI = dyn_cast<CallInst>(ArgUse);
if (CI) {
Function *Callee = CI->getCalledFunction();
if (Callee) {
if (Callee->isDeclaration())
continue;
}

for (uint32_t i = 0; i < CI->getNumOperands(); ++i) {
if (CI->getOperand(i) == Arg) {
AdjustAddressSpace(Callee, i, ArgAddrSpace);
}
}
} else {
for (unsigned int i = 0; i < ArgUse->getNumOperands(); ++i) {
if (ArgUse->getOperand(i) == Arg) {
const Type *Ty = ArgUse->getOperand(i)->getType();

PointerType *NPT = PointerType::get(Ty->getContext(), ArgAddrSpace);

auto *NewInstr = new AddrSpaceCastInst(ArgUse->getOperand(i), NPT);
NewInstr->insertBefore(Instr);
NewInstr->setDebugLoc(Instr->getDebugLoc());

ArgUse->setOperand(i, NewInstr);
}
}
}
}
}

// Process 'invoke_simd(sub_group_obj, f, spmd_args...);' call.
//
// If f is a function name or a function pointer, this call is lowered into
Expand Down Expand Up @@ -325,33 +368,10 @@ bool processInvokeSimdCall(CallInst *InvokeSimd,
if (!SimdF->isDeclaration()) {
const SmallDenseMap<uint32_t, uint32_t> &ArgMap = ArgAddrSpaceMap[SimdF];
for (const auto &MapEntry : ArgMap) {
const uint32_t ArgNo = MapEntry.first;
const uint32_t ArgAddrSpace = MapEntry.second;

Argument *Arg = SimdF->getArg(ArgNo);
for (User *ArgUse : Arg->users()) {
Instruction *Instr = dyn_cast<Instruction>(ArgUse);
if (!Instr)
continue;
const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(ArgUse);
if (ASC) {
if (ASC->getDestAddressSpace() == ArgAddrSpace)
continue;
}
for (unsigned int i = 0; i < ArgUse->getNumOperands(); ++i) {
if (ArgUse->getOperand(i) == Arg) {
const Type *Ty = ArgUse->getOperand(i)->getType();
uint32_t ArgNo = MapEntry.first;
uint32_t ArgAddrSpace = MapEntry.second;

PointerType *NPT = PointerType::get(Ty->getContext(), ArgAddrSpace);

auto *NewInstr = new AddrSpaceCastInst(ArgUse->getOperand(i), NPT);
NewInstr->insertBefore(Instr);
NewInstr->setDebugLoc(Instr->getDebugLoc());

ArgUse->setOperand(i, NewInstr);
}
}
}
AdjustAddressSpace(SimdF, ArgNo, ArgAddrSpace);
}
}

Expand Down Expand Up @@ -477,15 +497,11 @@ PreservedAnalyses SYCLLowerInvokeSimdPass::run(Module &M,
for (uint32_t i = 2; i < CI->arg_size(); ++i) {
const Value *Arg = CI->getArgOperand(i);
if (Arg->getType()->isPointerTy()) {
uint32_t AddressSpace = Arg->getType()->getPointerAddressSpace();
if (AddressSpace == 4) {
const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(Arg);
if (!ASC)
continue;

AddressSpace =
ASC->getOperand(0)->getType()->getPointerAddressSpace();
}
const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(Arg);
if (!ASC)
continue;
uint32_t AddressSpace =
ASC->getOperand(0)->getType()->getPointerAddressSpace();
ArgumentMap[i - 2] = AddressSpace;
}
}
Expand Down

0 comments on commit 6dcd8c7

Please sign in to comment.