diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 8708d0d86aa3..1bfc6d34a37c 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -1313,28 +1313,38 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) { AttributeList NewAttrs; SmallVector types; - bool legal = true; + SmallSet changed; for (auto pair : llvm::enumerate(FT->params())) { auto T = pair.value(); + auto i = pair.index(); + bool sretv = false; + for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) { + if (attr.isStringAttribute() && + attr.getKindAsString() == "enzyme_sret_v") { + sretv = true; + NewAttrs = NewAttrs.addAttribute( + F->getContext(), AttributeList::FirstArgIndex + types.size(), + Attribute::get(F->getContext(), "enzyme_sret")); + } else { + NewAttrs = NewAttrs.addAttribute( + F->getContext(), AttributeList::FirstArgIndex + types.size(), attr); + } + } if (auto AT = dyn_cast(T)) { if (auto PT = dyn_cast(AT->getElementType())) { auto AS = PT->getAddressSpace(); - if (AS == 11 || AS == 12 || AS == 13) { - legal = false; + if (AS == 11 || AS == 12 || AS == 13 || sretv) { for (unsigned i = 0; i < AT->getNumElements(); i++) { types.push_back(PT); } + changed.insert(i); continue; } } } - auto i = pair.index(); - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + types.size(), attr); types.push_back(T); } - if (legal) + if (changed.size() == 0) return; for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex)) @@ -1362,21 +1372,18 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) { for (Argument &I : F->args()) { auto T = I.getType(); if (auto AT = dyn_cast(T)) { - if (auto PT = dyn_cast(AT->getElementType())) { - auto AS = PT->getAddressSpace(); - if (AS == 11 || AS == 12 || AS == 13) { - Value *V = UndefValue::get(T); - for (unsigned i = 0; i < AT->getNumElements(); i++) { - DestI->setName(I.getName() + "." + - std::to_string(i)); // Copy the name over... - unsigned idx[1] = {i}; - auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx); - toInsert.push_back(IV); - V = IV; - } - VMap[&I] = V; - continue; + if (changed.count(I.getArgNo())) { + Value *V = UndefValue::get(T); + for (unsigned i = 0; i < AT->getNumElements(); i++) { + DestI->setName(I.getName() + "." + + std::to_string(i)); // Copy the name over... + unsigned idx[1] = {i}; + auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx); + toInsert.push_back(IV); + V = IV; } + VMap[&I] = V; + continue; } } DestI->setName(I.getName()); // Copy the name over... @@ -1429,10 +1436,21 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) { auto T = CI->getArgOperand(j)->getType(); if (auto AT = dyn_cast(T)) { - if (auto PT = dyn_cast(AT->getElementType())) { - auto AS = PT->getAddressSpace(); - if (AS == 11 || AS == 12 || AS == 13) { + if (isa(AT->getElementType())) { + if (changed.count(j)) { + bool sretv = false; + for (auto attr : + Attrs.getAttributes(AttributeList::FirstArgIndex + j)) { + if (attr.isStringAttribute() && + attr.getKindAsString() == "enzyme_sret_v") { + sretv = true; + } + } for (unsigned i = 0; i < AT->getNumElements(); i++) { + if (sretv) + NewAttrs = NewAttrs.addAttribute( + F->getContext(), AttributeList::FirstArgIndex + vals.size(), + Attribute::get(F->getContext(), "enzyme_sret")); vals.push_back( GradientUtils::extractMeta(B, CI->getArgOperand(j), i)); } @@ -1441,9 +1459,19 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) { } } - for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) - NewAttrs = NewAttrs.addAttribute( - F->getContext(), AttributeList::FirstArgIndex + vals.size(), attr); + for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) { + if (attr.isStringAttribute() && + attr.getKindAsString() == "enzyme_sret_v") { + NewAttrs = NewAttrs.addAttribute( + F->getContext(), AttributeList::FirstArgIndex + vals.size(), + Attribute::get(F->getContext(), "enzyme_sret")); + } else { + NewAttrs = NewAttrs.addAttribute( + F->getContext(), AttributeList::FirstArgIndex + vals.size(), + attr); + } + } + vals.push_back(CI->getArgOperand(j)); }