diff --git a/llvm/lib/SYCLLowerIR/LowerInvokeSimd.cpp b/llvm/lib/SYCLLowerIR/LowerInvokeSimd.cpp index ff402e9c4158f..b8684da375fe4 100644 --- a/llvm/lib/SYCLLowerIR/LowerInvokeSimd.cpp +++ b/llvm/lib/SYCLLowerIR/LowerInvokeSimd.cpp @@ -262,6 +262,49 @@ void markFunctionAsESIMD(Function *F) { } } +void AdjustAddressSpace(Function *F, uint32_t ArgNo, uint32_t ArgAddrSpace) { + Argument *Arg = F->getArg(ArgNo); + for (User *ArgUse : Arg->users()) { + Instruction *Instr = dyn_cast(ArgUse); + if (!Instr) + continue; + const AddrSpaceCastInst *ASC = dyn_cast(ArgUse); + if (ASC) { + if (ASC->getDestAddressSpace() == ArgAddrSpace) + continue; + } + + const CallInst *CI = dyn_cast(ArgUse); + if (CI) { + Function *Callee = CI->getCalledFunction(); + if (Callee) { + if (Callee->isDeclaration()) + continue; + } + + for (uint32_t i = 0; i < CI->getNumOperands(); ++i) { + if (CI->getOperand(i) == Arg) { + AdjustAddressSpace(Callee, i, ArgAddrSpace); + } + } + } else { + for (unsigned int i = 0; i < ArgUse->getNumOperands(); ++i) { + if (ArgUse->getOperand(i) == Arg) { + const Type *Ty = ArgUse->getOperand(i)->getType(); + + PointerType *NPT = PointerType::get(Ty->getContext(), ArgAddrSpace); + + auto *NewInstr = new AddrSpaceCastInst(ArgUse->getOperand(i), NPT); + NewInstr->insertBefore(Instr); + NewInstr->setDebugLoc(Instr->getDebugLoc()); + + ArgUse->setOperand(i, NewInstr); + } + } + } + } +} + // Process 'invoke_simd(sub_group_obj, f, spmd_args...);' call. // // If f is a function name or a function pointer, this call is lowered into @@ -325,33 +368,10 @@ bool processInvokeSimdCall(CallInst *InvokeSimd, if (!SimdF->isDeclaration()) { const SmallDenseMap &ArgMap = ArgAddrSpaceMap[SimdF]; for (const auto &MapEntry : ArgMap) { - const uint32_t ArgNo = MapEntry.first; - const uint32_t ArgAddrSpace = MapEntry.second; - - Argument *Arg = SimdF->getArg(ArgNo); - for (User *ArgUse : Arg->users()) { - Instruction *Instr = dyn_cast(ArgUse); - if (!Instr) - continue; - const AddrSpaceCastInst *ASC = dyn_cast(ArgUse); - if (ASC) { - if (ASC->getDestAddressSpace() == ArgAddrSpace) - continue; - } - for (unsigned int i = 0; i < ArgUse->getNumOperands(); ++i) { - if (ArgUse->getOperand(i) == Arg) { - const Type *Ty = ArgUse->getOperand(i)->getType(); + uint32_t ArgNo = MapEntry.first; + uint32_t ArgAddrSpace = MapEntry.second; - PointerType *NPT = PointerType::get(Ty->getContext(), ArgAddrSpace); - - auto *NewInstr = new AddrSpaceCastInst(ArgUse->getOperand(i), NPT); - NewInstr->insertBefore(Instr); - NewInstr->setDebugLoc(Instr->getDebugLoc()); - - ArgUse->setOperand(i, NewInstr); - } - } - } + AdjustAddressSpace(SimdF, ArgNo, ArgAddrSpace); } } @@ -477,15 +497,11 @@ PreservedAnalyses SYCLLowerInvokeSimdPass::run(Module &M, for (uint32_t i = 2; i < CI->arg_size(); ++i) { const Value *Arg = CI->getArgOperand(i); if (Arg->getType()->isPointerTy()) { - uint32_t AddressSpace = Arg->getType()->getPointerAddressSpace(); - if (AddressSpace == 4) { - const AddrSpaceCastInst *ASC = dyn_cast(Arg); - if (!ASC) - continue; - - AddressSpace = - ASC->getOperand(0)->getType()->getPointerAddressSpace(); - } + const AddrSpaceCastInst *ASC = dyn_cast(Arg); + if (!ASC) + continue; + uint32_t AddressSpace = + ASC->getOperand(0)->getType()->getPointerAddressSpace(); ArgumentMap[i - 2] = AddressSpace; } }