Skip to content

Commit

Permalink
Batch expand sret v (#1963)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 4, 2024
1 parent 316c064 commit fa33a6a
Showing 1 changed file with 56 additions and 28 deletions.
84 changes: 56 additions & 28 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,28 +1313,38 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) {

AttributeList NewAttrs;
SmallVector<Type *, 1> types;
bool legal = true;
SmallSet<size_t, 1> changed;
for (auto pair : llvm::enumerate(FT->params())) {
auto T = pair.value();
auto i = pair.index();
bool sretv = false;
for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i)) {
if (attr.isStringAttribute() &&
attr.getKindAsString() == "enzyme_sret_v") {
sretv = true;
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + types.size(),
Attribute::get(F->getContext(), "enzyme_sret"));
} else {
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + types.size(), attr);
}
}
if (auto AT = dyn_cast<ArrayType>(T)) {
if (auto PT = dyn_cast<PointerType>(AT->getElementType())) {
auto AS = PT->getAddressSpace();
if (AS == 11 || AS == 12 || AS == 13) {
legal = false;
if (AS == 11 || AS == 12 || AS == 13 || sretv) {
for (unsigned i = 0; i < AT->getNumElements(); i++) {
types.push_back(PT);
}
changed.insert(i);
continue;
}
}
}
auto i = pair.index();
for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + i))
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + types.size(), attr);
types.push_back(T);
}
if (legal)
if (changed.size() == 0)
return;

for (auto attr : Attrs.getAttributes(AttributeList::FunctionIndex))
Expand Down Expand Up @@ -1362,21 +1372,18 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) {
for (Argument &I : F->args()) {
auto T = I.getType();
if (auto AT = dyn_cast<ArrayType>(T)) {
if (auto PT = dyn_cast<PointerType>(AT->getElementType())) {
auto AS = PT->getAddressSpace();
if (AS == 11 || AS == 12 || AS == 13) {
Value *V = UndefValue::get(T);
for (unsigned i = 0; i < AT->getNumElements(); i++) {
DestI->setName(I.getName() + "." +
std::to_string(i)); // Copy the name over...
unsigned idx[1] = {i};
auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx);
toInsert.push_back(IV);
V = IV;
}
VMap[&I] = V;
continue;
if (changed.count(I.getArgNo())) {
Value *V = UndefValue::get(T);
for (unsigned i = 0; i < AT->getNumElements(); i++) {
DestI->setName(I.getName() + "." +
std::to_string(i)); // Copy the name over...
unsigned idx[1] = {i};
auto IV = InsertValueInst::Create(V, (llvm::Value *)&*DestI++, idx);
toInsert.push_back(IV);
V = IV;
}
VMap[&I] = V;
continue;
}
}
DestI->setName(I.getName()); // Copy the name over...
Expand Down Expand Up @@ -1429,10 +1436,21 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) {

auto T = CI->getArgOperand(j)->getType();
if (auto AT = dyn_cast<ArrayType>(T)) {
if (auto PT = dyn_cast<PointerType>(AT->getElementType())) {
auto AS = PT->getAddressSpace();
if (AS == 11 || AS == 12 || AS == 13) {
if (isa<PointerType>(AT->getElementType())) {
if (changed.count(j)) {
bool sretv = false;
for (auto attr :
Attrs.getAttributes(AttributeList::FirstArgIndex + j)) {
if (attr.isStringAttribute() &&
attr.getKindAsString() == "enzyme_sret_v") {
sretv = true;
}
}
for (unsigned i = 0; i < AT->getNumElements(); i++) {
if (sretv)
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + vals.size(),
Attribute::get(F->getContext(), "enzyme_sret"));
vals.push_back(
GradientUtils::extractMeta(B, CI->getArgOperand(j), i));
}
Expand All @@ -1441,9 +1459,19 @@ void EnzymeFixupBatchedJuliaCallingConvention(LLVMValueRef F_C) {
}
}

for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j))
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + vals.size(), attr);
for (auto attr : Attrs.getAttributes(AttributeList::FirstArgIndex + j)) {
if (attr.isStringAttribute() &&
attr.getKindAsString() == "enzyme_sret_v") {
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + vals.size(),
Attribute::get(F->getContext(), "enzyme_sret"));
} else {
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + vals.size(),
attr);
}
}

vals.push_back(CI->getArgOperand(j));
}

Expand Down

0 comments on commit fa33a6a

Please sign in to comment.