From 8d8996dd1e5ded4da4c87ccbb103576a3c52cd15 Mon Sep 17 00:00:00 2001 From: Vyacheslav Levytskyy Date: Tue, 15 Oct 2024 18:42:51 +0200 Subject: [PATCH] [SPIRV] Implement type deduction and reference to function declarations for indirect calls using SPV_INTEL_function_pointers (#111159) This PR improves implementation of SPV_INTEL_function_pointers and type inference for phi-nodes and indirect calls. --- llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 18 +- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 21 ++ llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 257 ++++++++++++++---- llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 14 + llvm/lib/Target/SPIRV/SPIRVUtils.h | 3 + .../fp-simple-hierarchy.ll | 88 ++++++ .../SPV_INTEL_function_pointers/fp_const.ll | 37 ++- .../fp_two_calls.ll | 31 ++- .../CodeGen/SPIRV/instructions/select-phi.ll | 8 +- 9 files changed, 389 insertions(+), 88 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 55b41627802096..8210e20ce5b10e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -78,6 +78,11 @@ class SPIRVAsmPrinter : public AsmPrinter { void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); + bool isHidden() { + return MF->getFunction() + .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) + .isValid(); + } void emitInstruction(const MachineInstr *MI) override; void emitFunctionEntryLabel() override {} @@ -131,7 +136,7 @@ void SPIRVAsmPrinter::emitFunctionHeader() { TII = ST->getInstrInfo(); const Function &F = MF->getFunction(); - if (isVerbose()) { + if (isVerbose() && !isHidden()) { OutStreamer->getCommentOS() << "-- Begin function " << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n'; @@ -149,11 +154,18 @@ void SPIRVAsmPrinter::outputOpFunctionEnd() { // Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap. void SPIRVAsmPrinter::emitFunctionBodyEnd() { + // Do not emit anything if it's an internal service function. + if (isHidden()) + return; outputOpFunctionEnd(); MAI->BBNumToRegMap.clear(); } void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { + // Do not emit anything if it's an internal service function. + if (isHidden()) + return; + MCInst LabelInst; LabelInst.setOpcode(SPIRV::OpLabel); LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB))); @@ -162,7 +174,9 @@ void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { } void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { - assert(!MBB.empty() && "MBB is empty!"); + // Do not emit anything if it's an internal service function. + if (MBB.empty()) + return; // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so // OpLabel should be output after them. diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 27a9cb0ba9b8c0..f8ce02a13c0f67 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -36,6 +36,13 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val, ArrayRef VRegs, FunctionLoweringInfo &FLI, Register SwiftErrorVReg) const { + // Ignore if called from the internal service function + if (MIRBuilder.getMF() + .getFunction() + .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) + .isValid()) + return true; + // Maybe run postponed production of types for function pointers if (IndirectCalls.size() > 0) { produceIndirectPtrTypes(MIRBuilder); @@ -280,6 +287,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, ArrayRef> VRegs, FunctionLoweringInfo &FLI) const { + // Discard the internal service function + if (F.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME).isValid()) + return true; + assert(GR && "Must initialize the SPIRV type registry before lowering args."); GR->setCurrentFunc(MIRBuilder.getMF()); @@ -576,6 +587,16 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); } + // Ignore the call if it's called from the internal service function + if (MIRBuilder.getMF() + .getFunction() + .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) + .isValid()) { + // insert a no-op + MIRBuilder.buildTrap(); + return true; + } + unsigned CallOp; if (Info.CB->isIndirectCall()) { if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 370df24bc7af9e..8b7e9c48de6c75 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -69,6 +69,7 @@ class SPIRVEmitIntrinsics SPIRVGlobalRegistry *GR = nullptr; Function *F = nullptr; bool TrackConstants = true; + bool HaveFunPtrs = false; DenseMap AggrConsts; DenseMap AggrConstTypes; DenseSet AggrStores; @@ -147,6 +148,10 @@ class SPIRVEmitIntrinsics void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy, CallInst *AssignCI); + bool runOnFunction(Function &F); + bool postprocessTypes(); + bool processFunctionPointers(Module &M); + public: static char ID; SPIRVEmitIntrinsics() : ModulePass(ID) { @@ -173,8 +178,6 @@ class SPIRVEmitIntrinsics StringRef getPassName() const override { return "SPIRV emit intrinsics"; } bool runOnModule(Module &M) override; - bool runOnFunction(Function &F); - bool postprocessTypes(); void getAnalysisUsage(AnalysisUsage &AU) const override { ModulePass::getAnalysisUsage(AU); @@ -384,7 +387,8 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep( // Traverse User instructions to deduce an element pointer type of the operand. Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep( Value *Op, std::unordered_set &Visited, bool UnknownElemTypeI8) { - if (!Op || !isPointerTy(Op->getType())) + if (!Op || !isPointerTy(Op->getType()) || isa(Op) || + isa(Op)) return nullptr; if (auto ElemTy = getPointeeType(Op->getType())) @@ -481,12 +485,25 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( if (isPointerTy(Op->getType())) Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8); } else if (auto *Ref = dyn_cast(I)) { - for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) { + Type *BestTy = nullptr; + unsigned MaxN = 1; + DenseMap PhiTys; + for (int i = Ref->getNumIncomingValues() - 1; i >= 0; --i) { Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited, UnknownElemTypeI8); - if (Ty) - break; + if (!Ty) + continue; + auto It = PhiTys.try_emplace(Ty, 1); + if (!It.second) { + ++It.first->second; + if (It.first->second > MaxN) { + MaxN = It.first->second; + BestTy = Ty; + } + } } + if (BestTy) + Ty = BestTy; } else if (auto *Ref = dyn_cast(I)) { for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) { Ty = deduceElementTypeByUsersDeep(Op, Visited, UnknownElemTypeI8); @@ -642,6 +659,93 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I, return nullptr; } +// Try to deduce element type for a call base. Returns false if this is an +// indirect function invocation, and true otherwise. +static bool deduceOperandElementTypeCalledFunction( + SPIRVGlobalRegistry *GR, Instruction *I, + SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI, + SmallVector> &Ops, Type *&KnownElemTy) { + Function *CalledF = CI->getCalledFunction(); + if (!CalledF) + return false; + std::string DemangledName = + getOclOrSpirvBuiltinDemangledName(CalledF->getName()); + if (DemangledName.length() > 0 && + !StringRef(DemangledName).starts_with("llvm.")) { + auto [Grp, Opcode, ExtNo] = + SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet); + if (Opcode == SPIRV::OpGroupAsyncCopy) { + for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2; ++i) { + Value *Op = CI->getArgOperand(i); + if (!isPointerTy(Op->getType())) + continue; + ++PtrCnt; + if (Type *ElemTy = GR->findDeducedElementType(Op)) + KnownElemTy = ElemTy; // src will rewrite dest if both are defined + Ops.push_back(std::make_pair(Op, i)); + } + } else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) { + if (CI->arg_size() < 2) + return true; + Value *Op = CI->getArgOperand(0); + if (!isPointerTy(Op->getType())) + return true; + switch (Opcode) { + case SPIRV::OpAtomicLoad: + case SPIRV::OpAtomicCompareExchangeWeak: + case SPIRV::OpAtomicCompareExchange: + case SPIRV::OpAtomicExchange: + case SPIRV::OpAtomicIAdd: + case SPIRV::OpAtomicISub: + case SPIRV::OpAtomicOr: + case SPIRV::OpAtomicXor: + case SPIRV::OpAtomicAnd: + case SPIRV::OpAtomicUMin: + case SPIRV::OpAtomicUMax: + case SPIRV::OpAtomicSMin: + case SPIRV::OpAtomicSMax: { + KnownElemTy = getAtomicElemTy(GR, I, Op); + if (!KnownElemTy) + return true; + Ops.push_back(std::make_pair(Op, 0)); + } break; + } + } + } + return true; +} + +// Try to deduce element type for a function pointer. +static void deduceOperandElementTypeFunctionPointer( + SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI, + SmallVector> &Ops, Type *&KnownElemTy) { + Value *Op = CI->getCalledOperand(); + if (!Op || !isPointerTy(Op->getType())) + return; + Ops.push_back(std::make_pair(Op, std::numeric_limits::max())); + FunctionType *FTy = CI->getFunctionType(); + bool IsNewFTy = false; + SmallVector ArgTys; + for (Value *Arg : CI->args()) { + Type *ArgTy = Arg->getType(); + if (ArgTy->isPointerTy()) + if (Type *ElemTy = GR->findDeducedElementType(Arg)) { + IsNewFTy = true; + ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy)); + } + ArgTys.push_back(ArgTy); + } + Type *RetTy = FTy->getReturnType(); + if (I->getType()->isPointerTy()) + if (Type *ElemTy = GR->findDeducedElementType(I)) { + IsNewFTy = true; + RetTy = + TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType())); + } + KnownElemTy = + IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy; +} + // If the Instruction has Pointer operands with unresolved types, this function // tries to deduce them. If the Instruction has Pointer operands with known // types which differ from expected, this function tries to insert a bitcast to @@ -747,54 +851,12 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I, KnownElemTy = ElemTy1; Ops.push_back(std::make_pair(Op0, 0)); } - } else if (auto *CI = dyn_cast(I)) { - if (Function *CalledF = CI->getCalledFunction()) { - std::string DemangledName = - getOclOrSpirvBuiltinDemangledName(CalledF->getName()); - if (DemangledName.length() > 0 && - !StringRef(DemangledName).starts_with("llvm.")) { - auto [Grp, Opcode, ExtNo] = - SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet); - if (Opcode == SPIRV::OpGroupAsyncCopy) { - for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2; - ++i) { - Value *Op = CI->getArgOperand(i); - if (!isPointerTy(Op->getType())) - continue; - ++PtrCnt; - if (Type *ElemTy = GR->findDeducedElementType(Op)) - KnownElemTy = ElemTy; // src will rewrite dest if both are defined - Ops.push_back(std::make_pair(Op, i)); - } - } else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) { - if (CI->arg_size() < 2) - return; - Value *Op = CI->getArgOperand(0); - if (!isPointerTy(Op->getType())) - return; - switch (Opcode) { - case SPIRV::OpAtomicLoad: - case SPIRV::OpAtomicCompareExchangeWeak: - case SPIRV::OpAtomicCompareExchange: - case SPIRV::OpAtomicExchange: - case SPIRV::OpAtomicIAdd: - case SPIRV::OpAtomicISub: - case SPIRV::OpAtomicOr: - case SPIRV::OpAtomicXor: - case SPIRV::OpAtomicAnd: - case SPIRV::OpAtomicUMin: - case SPIRV::OpAtomicUMax: - case SPIRV::OpAtomicSMin: - case SPIRV::OpAtomicSMax: { - KnownElemTy = getAtomicElemTy(GR, I, Op); - if (!KnownElemTy) - return; - Ops.push_back(std::make_pair(Op, 0)); - } break; - } - } - } - } + } else if (CallInst *CI = dyn_cast(I)) { + if (!CI->isIndirectCall()) + deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops, + KnownElemTy); + else if (HaveFunPtrs) + deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy); } // There is no enough info to deduce types or all is valid. @@ -844,7 +906,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I, B.getInt32(getPointerAddressSpace(OpTy))}; CallInst *PtrCastI = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args); - I->setOperand(OpIt.second, PtrCastI); + if (OpIt.second == std::numeric_limits::max()) + dyn_cast(I)->setCalledOperand(PtrCastI); + else + I->setOperand(OpIt.second, PtrCastI); buildAssignPtr(B, KnownElemTy, PtrCastI); } } @@ -1671,6 +1736,82 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) { } } +static FunctionType *getFunctionPointerElemType(Function *F, + SPIRVGlobalRegistry *GR) { + FunctionType *FTy = F->getFunctionType(); + bool IsNewFTy = false; + SmallVector ArgTys; + for (Argument &Arg : F->args()) { + Type *ArgTy = Arg.getType(); + if (ArgTy->isPointerTy()) + if (Type *ElemTy = GR->findDeducedElementType(&Arg)) { + IsNewFTy = true; + ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy)); + } + ArgTys.push_back(ArgTy); + } + return IsNewFTy + ? FunctionType::get(FTy->getReturnType(), ArgTys, FTy->isVarArg()) + : FTy; +} + +bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) { + SmallVector Worklist; + for (auto &F : M) { + if (F.isIntrinsic()) + continue; + if (F.isDeclaration()) { + for (User *U : F.users()) { + CallInst *CI = dyn_cast(U); + if (!CI || CI->getCalledFunction() != &F) { + Worklist.push_back(&F); + break; + } + } + } else { + if (F.user_empty()) + continue; + Type *FPElemTy = GR->findDeducedElementType(&F); + if (!FPElemTy) + FPElemTy = getFunctionPointerElemType(&F, GR); + for (User *U : F.users()) { + IntrinsicInst *II = dyn_cast(U); + if (!II || II->arg_size() != 3 || II->getOperand(0) != &F) + continue; + if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type || + II->getIntrinsicID() == Intrinsic::spv_ptrcast) { + updateAssignType(II, &F, PoisonValue::get(FPElemTy)); + break; + } + } + } + } + if (Worklist.empty()) + return false; + + std::string ServiceFunName = SPIRV_BACKEND_SERVICE_FUN_NAME; + if (!getVacantFunctionName(M, ServiceFunName)) + report_fatal_error( + "cannot allocate a name for the internal service function"); + LLVMContext &Ctx = M.getContext(); + Function *SF = + Function::Create(FunctionType::get(Type::getVoidTy(Ctx), {}, false), + GlobalValue::PrivateLinkage, ServiceFunName, M); + SF->addFnAttr(SPIRV_BACKEND_SERVICE_FUN_NAME, ""); + BasicBlock *BB = BasicBlock::Create(Ctx, "entry", SF); + IRBuilder<> IRB(BB); + + for (Function *F : Worklist) { + SmallVector Args; + for (const auto &Arg : F->args()) + Args.push_back(PoisonValue::get(Arg.getType())); + IRB.CreateCall(F, Args); + } + IRB.CreateRetVoid(); + + return true; +} + bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { if (Func.isDeclaration()) return false; @@ -1680,6 +1821,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std : SPIRV::InstructionSet::GLSL_std_450; + if (!F) + HaveFunPtrs = + ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers); + F = &Func; IRBuilder<> B(Func.getContext()); AggrConsts.clear(); @@ -1825,6 +1970,8 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) { } Changed |= postprocessTypes(); + if (HaveFunPtrs) + Changed |= processFunctionPointers(M); return Changed; } diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index d204a8ac7975d8..dff33b16b9cfcf 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -598,4 +598,18 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg) { return MaybeDef; } +bool getVacantFunctionName(Module &M, std::string &Name) { + // It's a bit of paranoia, but still we don't want to have even a chance that + // the loop will work for too long. + constexpr unsigned MaxIters = 1024; + for (unsigned I = 0; I < MaxIters; ++I) { + std::string OrdName = Name + Twine(I).str(); + if (!M.getFunction(OrdName)) { + Name = OrdName; + return true; + } + } + return false; +} + } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index f7e8a827c2767f..83e717e6ea58fd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -341,5 +341,8 @@ inline const Type *unifyPtrType(const Type *Ty) { MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg); +#define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun" +bool getVacantFunctionName(Module &M, std::string &Name); + } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll new file mode 100644 index 00000000000000..d5a8fb3e7baafa --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll @@ -0,0 +1,88 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s +; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[I9:.*]] "_ZN13BaseIncrement9incrementEPi" +; CHECK-DAG: OpName %[[I29:.*]] "_ZN12IncrementBy29incrementEPi" +; CHECK-DAG: OpName %[[I49:.*]] "_ZN12IncrementBy49incrementEPi" +; CHECK-DAG: OpName %[[I89:.*]] "_ZN12IncrementBy89incrementEPi" + +; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid +; CHECK-DAG: %[[TyArr:.*]] = OpTypeArray +; CHECK-DAG: %[[TyStruct1:.*]] = OpTypeStruct %[[TyArr]] +; CHECK-DAG: %[[TyStruct2:.*]] = OpTypeStruct %[[TyStruct1]] +; CHECK-DAG: %[[TyPtrStruct2:.*]] = OpTypePointer Generic %[[TyStruct2]] +; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrStruct2]] %[[#]] +; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Generic %[[TyFun]] +; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Generic %[[TyPtrFun]] + +; CHECK: %[[I9]] = OpFunction +; CHECK: %[[I29]] = OpFunction +; CHECK: %[[I49]] = OpFunction +; CHECK: %[[I89]] = OpFunction + +; CHECK: %[[Arg1:.*]] = OpPhi %[[TyPtrStruct2]] +; CHECK: %[[VTbl:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[#]] +; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[VTbl]] +; CHECK: %[[#]] = OpFunctionPointerCallINTEL %[[TyVoid]] %[[FP]] %[[Arg1]] %[[#]] + +%"cls::id" = type { %"cls::detail::array" } +%"cls::detail::array" = type { [1 x i64] } +%struct.obj_storage_t = type { %"struct.aligned_storage::type" } +%"struct.aligned_storage::type" = type { [8 x i8] } + +@_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8 +@_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8 +@_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8 +@_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8 +@__spirv_BuiltInWorkgroupId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 +@__spirv_BuiltInGlobalLinearId = external dso_local local_unnamed_addr addrspace(1) constant i64, align 8 +@__spirv_BuiltInWorkgroupSize = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 + +define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%"cls::id") align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) { +entry: + %r0 = load i64, ptr %_arg_StorageAcc3, align 8 + %add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %r0 + %arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4) + %cmp.i = icmp ugt i32 %_arg_TestCase, 3 + br i1 %cmp.i, label %entry.critedge, label %if.end.1 + +entry.critedge: ; preds = %entry + %vtable.i.pre = load ptr addrspace(4), ptr addrspace(4) null, align 8 + br label %exit + +if.end.1: ; preds = %entry + switch i32 %_arg_TestCase, label %if.end.5 [ + i32 0, label %if.end.2 + i32 1, label %if.end.3 + i32 2, label %if.end.4 + ] + +if.end.5: ; preds = %if.end.1 + store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16), ptr addrspace(1) %add.ptr.i, align 8 + br label %exit + +if.end.4: ; preds = %if.end.1 + store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16), ptr addrspace(1) %add.ptr.i, align 8 + br label %exit + +if.end.3: ; preds = %if.end.1 + store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16), ptr addrspace(1) %add.ptr.i, align 8 + br label %exit + +if.end.2: ; preds = %if.end.1 + store ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16), ptr addrspace(1) %add.ptr.i, align 8 + br label %exit + +exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge + %vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ] + %retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ] + %r1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4) + %r2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8 + tail call spir_func addrspace(4) void %r2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %r1) + ret void +} + +declare dso_local spir_func void @_ZN13BaseIncrement9incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef) +declare dso_local spir_func void @_ZN12IncrementBy29incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef) +declare dso_local spir_func void @_ZN12IncrementBy49incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef) +declare dso_local spir_func void @_ZN12IncrementBy89incrementEPi(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8), ptr addrspace(4) noundef) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll index 5f073e95cb68f2..b4faba9a4eb8e3 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll @@ -5,30 +5,39 @@ ; CHECK-DAG: OpCapability FunctionPointersINTEL ; CHECK-DAG: OpCapability Int64 ; CHECK: OpExtension "SPV_INTEL_function_pointers" -; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 + ; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0 -; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]] -; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42 -; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]] -; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]] -; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]] -; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]] -; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]] -; CHECK: OpReturn +; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyInt64]] %[[TyInt64]] +; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 +; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Function %[[TyFun]] +; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFun]] %[[DefFunFp:.*]] +; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Function %[[TyPtrFun]] +; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]] +; CHECK-DAG: %[[TyPtrPtrInt8:.*]] = OpTypePointer Function %[[TyPtrInt8]] +; CHECK: OpFunction +; CHECK: %[[Var:.*]] = OpVariable %[[TyPtrPtrInt8]] Function +; CHECK: %[[SAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]] +; CHECK: OpStore %[[SAddr]] %[[ConstFunFp]] +; CHECK: %[[LAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]] +; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[LAddr]] +; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FP]] %[[#]] ; CHECK: OpFunctionEnd -; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]] + +; CHECK: %[[DefFunFp]] = OpFunction %[[TyInt64]] None %[[TyFun]] target triple = "spir64-unknown-unknown" define spir_kernel void @test() { entry: - %0 = load ptr, ptr @foo - %1 = call i64 %0(i64 42) + %fp = alloca ptr + store ptr @foo, ptr %fp + %tocall = load ptr, ptr %fp + %res = call i64 %tocall(i64 42) ret void } -define void @foo(i64 %a) { +define i64 @foo(i64 %a) { entry: - ret void + ret i64 %a } diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll index c5a2918f92c29e..eb7b1dffaee501 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll @@ -5,27 +5,30 @@ ; CHECK-DAG: OpCapability FunctionPointersINTEL ; CHECK-DAG: OpCapability Int64 ; CHECK: OpExtension "SPV_INTEL_function_pointers" -; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 + +; CHECK-DAG: OpName %[[fp:.*]] "fp" +; CHECK-DAG: OpName %[[data:.*]] "data" +; CHECK-DAG: OpName %[[bar:.*]] "bar" +; CHECK-DAG: OpName %[[test:.*]] "test" ; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid ; CHECK-DAG: %[[TyFloat32:.*]] = OpTypeFloat 32 +; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0 ; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]] -; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]] -; CHECK-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrInt8]] %[[TyPtrInt8]] -; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]] -; CHECK-DAG: %[[TyPtrFunBar:.*]] = OpTypePointer Function %[[TyFunBar]] -; CHECK-DAG: %[[TyFunTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrInt8]] %[[TyPtrInt8]] %[[TyPtrInt8]] -; CHECK: %[[FunTest:.*]] = OpFunction %[[TyVoid]] None %[[TyFunTest]] -; CHECK: %[[ArgFp:.*]] = OpFunctionParameter %[[TyPtrInt8]] -; CHECK: %[[ArgData:.*]] = OpFunctionParameter %[[TyPtrInt8]] -; CHECK: %[[ArgBar:.*]] = OpFunctionParameter %[[TyPtrInt8]] -; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[ArgFp]] %[[ArgBar]] -; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[ArgBar]] %[[ArgFp]] %[[ArgData]] +; CHECK-DAG: %[[TyFp:.*]] = OpTypeFunction %[[TyFloat32]] %[[TyPtrInt8]] +; CHECK-DAG: %[[TyPtrFp:.*]] = OpTypePointer Function %[[TyFp]] +; CHECK-DAG: %[[TyBar:.*]] = OpTypeFunction %[[TyInt64]] %[[TyPtrFp]] %[[TyPtrInt8]] +; CHECK-DAG: %[[TyPtrBar:.*]] = OpTypePointer Function %[[TyBar]] +; CHECK-DAG: %[[TyTest:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrFp]] %[[TyPtrInt8]] %[[TyPtrBar]] +; CHECK: %[[test]] = OpFunction %[[TyVoid]] None %[[TyTest]] +; CHECK: %[[fp]] = OpFunctionParameter %[[TyPtrFp]] +; CHECK: %[[data]] = OpFunctionParameter %[[TyPtrInt8]] +; CHECK: %[[bar]] = OpFunctionParameter %[[TyPtrBar]] +; CHECK: OpFunctionPointerCallINTEL %[[TyFloat32]] %[[fp]] %[[bar]] +; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[bar]] %[[fp]] %[[data]] ; CHECK: OpReturn ; CHECK: OpFunctionEnd -target triple = "spir64-unknown-unknown" - define spir_kernel void @test(ptr %fp, ptr %data, ptr %bar) { entry: %0 = call spir_func float %fp(ptr %bar) diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll index 3828fe89e60aec..16be7cd3b8db62 100644 --- a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll +++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll @@ -1,3 +1,6 @@ +; This test case checks how phi-nodes with different operand types select +; a result type. Majority of operands makes it i8* in this case. + ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s @@ -15,14 +18,13 @@ ; CHECK: %[[Branch1:.*]] = OpLabel ; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function +; CHECK: %[[Res1Casted:.*]] = OpBitcast %[[CharPtr]] %[[Res1]] ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]] ; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]] -; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]] ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]] ; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]] -; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]] ; CHECK: OpLabel -; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]] +; CHECK: OpPhi %[[CharPtr]] %[[Res1Casted]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]] %struct = type { %array } %array = type { [1 x i64] }