diff --git a/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp b/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp index 9f5cb04725398..3d306d9091c55 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVUtil.cpp @@ -2229,8 +2229,9 @@ bool postProcessBuiltinReturningStruct(Function *F) { Builder.SetInsertPoint(CI); SmallVector Users(CI->users()); Value *A = nullptr; + StoreInst *SI = nullptr; for (auto *U : Users) { - if (auto *SI = dyn_cast(U)) { + if ((SI = dyn_cast(U)) != nullptr) { A = SI->getPointerOperand(); InstToRemove.push_back(SI); break; @@ -2253,9 +2254,14 @@ bool postProcessBuiltinReturningStruct(Function *F) { CallInst *NewCI = Builder.CreateCall(NewF, Args, CI->getName()); NewCI->addParamAttr(0, SretAttr); NewCI->setCallingConv(CI->getCallingConv()); - SmallVector CIUsers(CI->users()); - for (auto *CIUser : CIUsers) { - CIUser->replaceUsesOfWith(CI, A); + SmallVector UsersToReplace; + for (auto *U : Users) + if (U != SI) + UsersToReplace.push_back(U); + if (UsersToReplace.size() > 0) { + auto *LI = Builder.CreateLoad(F->getReturnType(), A); + for (auto *U : UsersToReplace) + U->replaceUsesOfWith(CI, LI); } InstToRemove.push_back(CI); } diff --git a/llvm-spirv/test/builtin_returns_struct.spvasm b/llvm-spirv/test/builtin_returns_struct.spvasm new file mode 100644 index 0000000000000..48e4973b732ed --- /dev/null +++ b/llvm-spirv/test/builtin_returns_struct.spvasm @@ -0,0 +1,52 @@ +; REQUIRES: spirv-as +; RUN: spirv-as --target-env spv1.0 -o %t.spv %s +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s + + OpCapability Kernel + OpCapability Addresses + OpCapability Int8 + OpCapability GenericPointer + OpCapability Linkage + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpSource OpenCL_CPP 100000 + OpName %a "a" + OpName %p "p" + OpName %foo "foo" + OpName %e "e" + OpName %math "math" + OpName %ov "ov" + OpDecorate %foo LinkageAttributes "foo" Export + %uint = OpTypeInt 32 0 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar + %5 = OpTypeFunction %uint %uint %_ptr_Generic_uchar + %bool = OpTypeBool + %_struct_7 = OpTypeStruct %uint %uint + %uint_1 = OpConstant %uint 1 + %uchar_42 = OpConstant %uchar 42 + %10 = OpConstantNull %uint + %foo = OpFunction %uint None %5 + %a = OpFunctionParameter %uint + %p = OpFunctionParameter %_ptr_Generic_uchar + %19 = OpLabel + OpBranch %20 + %20 = OpLabel + %e = OpPhi %uint %a %19 %math %21 + %16 = OpIAddCarry %_struct_7 %e %uint_1 + %math = OpCompositeExtract %uint %16 0 + %17 = OpCompositeExtract %uint %16 1 + %ov = OpINotEqual %bool %17 %10 + OpBranchConditional %ov %22 %21 + %21 = OpLabel + OpStore %p %uchar_42 Aligned 1 + OpBranch %20 + %22 = OpLabel + OpReturnValue %math + OpFunctionEnd + +; CHECK: %[[#Var:]] = alloca %structtype, align 8 +; CHECK: call spir_func void @_Z17__spirv_IAddCarryii(ptr sret(%structtype) %[[#Var:]] +; CHECK: %[[#Load:]] = load %structtype, ptr %[[#Var]], align 4 +; CHECK-2: extractvalue %structtype %[[#Load:]]