Skip to content

Commit

Permalink
Fix with non jlvalue copy
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 27, 2023
1 parent 14e899d commit 5181993
Showing 1 changed file with 51 additions and 11 deletions.
62 changes: 51 additions & 11 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1587,7 +1587,6 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
#endif
IRBuilder<> EB(&NewF->getEntryBlock().front());
auto AL = EB.CreateAlloca(T, 0, "stack_roots");
EB.CreateStore(Constant::getNullValue(T), AL);
arg->replaceAllUsesWith(AL);
delete arg;
}
Expand All @@ -1607,7 +1606,6 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
Value *val = UndefValue::get(AT);
for (size_t j = 0; j < AT->getNumElements(); j++) {
auto AL = EB.CreateAlloca(T, 0, "stack_roots_v");
EB.CreateStore(Constant::getNullValue(T), AL);
val = EB.CreateInsertValue(val, AL, j);
}
arg->replaceAllUsesWith(val);
Expand All @@ -1626,7 +1624,6 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
Value *sret = nullptr;
if (sretTy) {
sret = EB.CreateAlloca(sretTy, 0, "stack_sret");
EB.CreateStore(Constant::getNullValue(sretTy), sret);
vals.push_back(sret);
NewAttrs = NewAttrs.addAttribute(
F->getContext(), AttributeList::FirstArgIndex + nexti,
Expand All @@ -1636,7 +1633,6 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
AllocaInst *roots = nullptr;
if (roots_AT) {
roots = EB.CreateAlloca(roots_AT, 0, "stack_roots_AT");
EB.CreateStore(Constant::getNullValue(roots_AT), roots);
vals.push_back(roots);
NewAttrs = NewAttrs.addAttribute(

Expand Down Expand Up @@ -1681,21 +1677,65 @@ void EnzymeFixupJuliaCallingConvention(LLVMValueRef F_C) {
sretCount++;
}

// Recursively copies the leaves of `curType` from `in` to `out`, skipping
// every pointer leaf that lives in address space 10 (presumably Julia's
// GC-tracked jlvalue pointers, going by the helper's name -- confirm).
//
//   curType   - type currently being visited (array/struct types recurse).
//   out       - base destination pointer; indexed through `sretTy`.
//   in        - base source pointer; indexed through `ptrTy`.
//   inds      - index path from the top-level copied type down to `curType`.
//   sretCount - field index inside `sretTy` that receives the copy, or a
//               negative value when `out` is not wrapped in a struct
//               (the visible callers pass -1 in that case).
//   ptrTy     - pointee type used to index `in`.
std::function<void(Type *, Value *, Value *, ArrayRef<int>, int, Type *)>
    copyNonJLValue = [&](Type *curType, Value *out, Value *in,
                         ArrayRef<int> inds, int sretCount, Type *ptrTy) {
      // Address space 10 pointers are deliberately left uncopied.
      if (auto PT = dyn_cast<PointerType>(curType)) {
        if (PT->getAddressSpace() == 10)
          return;
      }

      // Aggregates: descend into each element with an extended index path.
      if (auto AT = dyn_cast<ArrayType>(curType)) {
        for (size_t i = 0; i < AT->getNumElements(); i++) {
          SmallVector<int> next(inds.begin(), inds.end());
          next.push_back(i);
          copyNonJLValue(AT->getElementType(), out, in, next, sretCount,
                         ptrTy);
        }
        return;
      }
      if (auto ST = dyn_cast<StructType>(curType)) {
        for (size_t i = 0; i < ST->getNumElements(); i++) {
          SmallVector<int> next(inds.begin(), inds.end());
          next.push_back(i);
          copyNonJLValue(ST->getElementType(i), out, in, next, sretCount,
                         ptrTy);
        }
        return;
      }

      // Leaf: build matching GEP index lists for source and destination.
      SmallVector<Value *> ininds;
      SmallVector<Value *> outinds;
      // Leading index steps through the base pointer itself.
      auto c0 = ConstantInt::get(B.getInt64Ty(), 0);
      ininds.push_back(c0);
      outinds.push_back(c0);
      // Field 0 is a valid destination; only a negative count means "no
      // enclosing struct field". Struct indices must be i32 constants.
      if (sretCount >= 0)
        outinds.push_back(ConstantInt::get(B.getInt32Ty(), sretCount));
      for (auto v : inds) {
        // i32 is accepted for both array and struct indices, whereas struct
        // indices reject any other width.
        ininds.push_back(ConstantInt::get(B.getInt32Ty(), v));
        outinds.push_back(ConstantInt::get(B.getInt32Ty(), v));
      }
      // Both lists always hold at least the leading zero, so the GEPs are
      // emitted unconditionally.
      out = B.CreateInBoundsGEP(sretTy, out, outinds);
      in = B.CreateInBoundsGEP(ptrTy, in, ininds);

      auto ld = B.CreateLoad(curType, in);
      B.CreateStore(ld, out);
    };

for (Value *ptr : sret_vals) {
auto gep =
ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount) : sret;
auto ld = B.CreateLoad(Types[sretCount], ptr);
B.CreateStore(ld, gep);
copyNonJLValue(Types[sretCount], sret, ptr, {}, ST ? sretCount : -1,
Types[sretCount]);
sretCount++;
}
for (Value *ptr_v : sretv_vals) {
auto AT = cast<ArrayType>(ptr_v->getType());
for (size_t j = 0; j < AT->getNumElements(); j++) {
auto gep = ST ? B.CreateConstInBoundsGEP2_32(ST, sret, 0, sretCount + j)
: sret;
auto ptr = GradientUtils::extractMeta(B, ptr_v, j);
copyNonJLValue(Types[sretCount], sret, ptr, {},
ST ? (sretCount + j) : -1, Types[sretCount]);
auto ld = B.CreateLoad(Types[sretCount], ptr);
B.CreateStore(ld, gep);
}
sretCount += AT->getNumElements();
}
Expand Down

0 comments on commit 5181993

Please sign in to comment.