Skip to content

Commit

Permalink
Zero alloca in Julia Calling convention fix (#1506)
Browse files Browse the repository at this point in the history
* Zero alloca in Julia Calling convention fix

* Add names

* Even more names

* Fix with non jlvalue copy
  • Loading branch information
wsmoses authored Oct 27, 2023
1 parent d0c9b9b commit f48835d
Showing 1 changed file with 59 additions and 12 deletions.
71 changes: 59 additions & 12 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1586,7 +1586,8 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
T = FT->getParamType(i)->getPointerElementType();
#endif
IRBuilder<> EB(&NewF->getEntryBlock().front());
arg->replaceAllUsesWith(EB.CreateAlloca(T));
auto AL = EB.CreateAlloca(T, 0, "stack_roots");
arg->replaceAllUsesWith(AL);
delete arg;
}
for (auto i : rroots_v) {
Expand All @@ -1604,7 +1605,8 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
IRBuilder<> EB(&NewF->getEntryBlock().front());
Value *val = UndefValue::get(AT);
for (size_t j = 0; j < AT->getNumElements(); j++) {
val = EB.CreateInsertValue(val, EB.CreateAlloca(T), j);
auto AL = EB.CreateAlloca(T, 0, "stack_roots_v");
val = EB.CreateInsertValue(val, AL, j);
}
arg->replaceAllUsesWith(val);
delete arg;
Expand All @@ -1621,7 +1623,7 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
size_t nexti = 0;
Value *sret = nullptr;
if (sretTy) {
sret = EB.CreateAlloca(sretTy);
sret = EB.CreateAlloca(sretTy, 0, "stack_sret");
vals.push_back(sret);
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + nexti,
Expand All @@ -1630,7 +1632,7 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
}
AllocaInst *roots = nullptr;
if (roots_AT) {
roots = EB.CreateAlloca(roots_AT);
roots = EB.CreateAlloca(roots_AT, 0, "stack_roots_AT");
vals.push_back(roots);
NewAttrs = NewAttrs.addAttribute(

Expand Down Expand Up @@ -1675,21 +1677,66 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
sretCount++;
}

std::function<void(Type *, Value *, Value *, ArrayRef<int>, int, Type *)>
copyNonJLValue = [&](Type *curType, Value *out, Value *in,
ArrayRef<int> inds, int sretCount, Type *ptrTy) {
if (auto PT = dyn_cast<PointerType>(curType)) {
if (PT->getAddressSpace() == 10) {
return;
}
}

if (auto AT = dyn_cast<ArrayType>(curType)) {
for (size_t i = 0; i < AT->getNumElements(); i++) {
SmallVector<int, 1> next(inds.begin(), inds.end());
next.push_back(i);
copyNonJLValue(AT->getElementType(), out, in, next, sretCount,
ptrTy);
}
return;
}
if (auto ST = dyn_cast<StructType>(curType)) {
for (size_t i = 0; i < ST->getNumElements(); i++) {
SmallVector<int, 1> next(inds.begin(), inds.end());
next.push_back(i);
copyNonJLValue(ST->getElementType(i), out, in, next, sretCount,
ptrTy);
}
return;
}

SmallVector<Value *, 1> ininds;
SmallVector<Value *, 1> outinds;
auto c0 = ConstantInt::get(B.getInt64Ty(), 0);
ininds.push_back(c0);
outinds.push_back(c0);
if (sretCount >= 0)
outinds.push_back(ConstantInt::get(B.getInt32Ty(), sretCount));
for (auto v : inds) {
ininds.push_back(ConstantInt::get(B.getInt32Ty(), v));
outinds.push_back(ConstantInt::get(B.getInt32Ty(), v));
}

if (outinds.size() > 1)
out = B.CreateInBoundsGEP(sretTy, out, outinds);
if (ininds.size() > 1)
in = B.CreateInBoundsGEP(ptrTy, in, ininds);

auto ld = B.CreateLoad(curType, in);
B.CreateStore(ld, out);
};

for (Value *ptr : sret_vals) {
auto gep =
ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret;
auto ld = B.CreateLoad(Types[sretCount], ptr);
B.CreateStore(ld, gep);
copyNonJLValue(Types[sretCount], sret, ptr, {}, ST ? sretCount : -1,
Types[sretCount]);
sretCount++;
}
for (Value *ptr_v : sretv_vals) {
auto AT = cast<ArrayType>(ptr_v->getType());
for (size_t j = 0; j < AT->getNumElements(); j++) {
auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j)
: sret;
auto ptr = GradientUtils::extractMeta(B, ptr_v, j);
auto ld = B.CreateLoad(Types[sretCount], ptr);
B.CreateStore(ld, gep);
copyNonJLValue(Types[sretCount], sret, ptr, {},
ST ? (sretCount + j) : -1, Types[sretCount]);
}
sretCount += AT->getNumElements();
}
Expand Down

0 comments on commit f48835d

Please sign in to comment.