Skip to content

Commit

Permalink
Fix free fn with additional args (#2055)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Aug 28, 2024
1 parent 6dc74f0 commit f514f82
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
9 changes: 9 additions & 0 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4140,6 +4140,15 @@ bool AdjointGenerator::handleKnownCallDerivatives(
auto rule = [&args](Value *tofree) { args.push_back(tofree); };
applyChainRule(Builder2, rule, tofree);

#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 1; i < call.arg_size(); i++)
#else
for (size_t i = 1; i < call.getNumArgOperands(); i++)
#endif
{
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
}

auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
frees->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));

Expand Down
27 changes: 25 additions & 2 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1519,11 +1519,23 @@ Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty,

std::string name = "__enzyme_checked_free_" + std::to_string(width);

auto callname = getFuncNameFromCall(call);
if (callname != "free")
name += "_" + callname.str();

SmallVector<Type *, 3> types;
types.push_back(Ty);
for (unsigned i = 0; i < width; i++) {
types.push_back(Ty);
}
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 1; i < call->arg_size(); i++)
#else
for (size_t i = 1; i < call->getNumArgOperands(); i++)
#endif
{
types.push_back(call->getArgOperand(i)->getType());
}

FunctionType *FT =
FunctionType::get(Type::getVoidTy(M.getContext()), types, false);
Expand Down Expand Up @@ -1558,7 +1570,17 @@ Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty,
Value *isNotEqual = EntryBuilder.CreateICmpNE(primal, first_shadow);
EntryBuilder.CreateCondBr(isNotEqual, free0, end);

CallInst *CI = Free0Builder.CreateCall(FreeTy, Free, {first_shadow});
SmallVector<Value *, 1> args = {first_shadow};
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 1; i < call->arg_size(); i++)
#else
for (size_t i = 1; i < call->getNumArgOperands(); i++)
#endif
{
args.push_back(F->arg_begin() + width + i);
}

CallInst *CI = Free0Builder.CreateCall(FreeTy, Free, args);
CI->setAttributes(FreeAttributes);
CI->setCallingConv(CallingConvention);

Expand All @@ -1578,7 +1600,8 @@ Function *getOrInsertCheckedFree(Module &M, CallInst *call, Type *Ty,
? Free0Builder.CreateAnd(isNotEqual, checkResult)
: isNotEqual;

CallInst *CI = Free1Builder.CreateCall(FreeTy, Free, {nextShadow});
args[0] = nextShadow;
CallInst *CI = Free1Builder.CreateCall(FreeTy, Free, args);
CI->setAttributes(FreeAttributes);
CI->setCallingConv(CallingConvention);
}
Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/freefn.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,sroa,instsimplify)" -S | FileCheck %s; fi

define double @_f2(ptr %0, i64 %1) personality ptr null {
call void @__rust_dealloc(ptr %0, i64 %1)
ret double 0.000000e+00
}

declare void @__rust_dealloc(ptr, i64)

declare double @__enzyme_fwddiff(...)

define double @enzyme_opt_helper_0() {
%1 = call double (...) @__enzyme_fwddiff(ptr @_f2, metadata !"enzyme_dup", ptr null, ptr null, metadata !"enzyme_const", i64 0)
ret double 0.000000e+00
}

; CHECK: define internal double @fwddiffe_f2(ptr %0, ptr %"'", i64 %1)
; CHECK-NEXT: call void @__rust_dealloc(ptr %0, i64 %1)
; CHECK-NEXT: %3 = icmp ne ptr %0, %"'"
; CHECK-NEXT: br i1 %3, label %free0.i, label %__enzyme_checked_free_1___rust_dealloc.exit

; CHECK: free0.i:
; CHECK-NEXT: call void @__rust_dealloc(ptr %"'", i64 %1)
; CHECK-NEXT: br label %__enzyme_checked_free_1___rust_dealloc.exit

; CHECK: __enzyme_checked_free_1___rust_dealloc.exit:
; CHECK-NEXT: ret double 0.000000e+00
; CHECK-NEXT: }

0 comments on commit f514f82

Please sign in to comment.