Skip to content

Commit

Permalink
Fix cuda realloc
Browse files — browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 29, 2024
1 parent 5077629 commit 6e72222
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
13 changes: 13 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,18 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
}
if (success)
continue;

auto v2 = simplifyLoad(LI);
if (v2) {
todo.push_back({v2, next.second});
continue;
}
}

EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc,
"could not statically determine size of realloc ", *Loc,
" - because of - ", *next.first);
return AI;

std::string allocName;
switch (llvm::Triple(NewF->getParent()->getTargetTriple()).getOS()) {
Expand Down Expand Up @@ -833,6 +840,12 @@ void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
if (mem2reg) {
auto PA = PromotePass().run(*NewF, FAM);
FAM.invalidate(*NewF, PA);
#if !defined(FLANG)
PA = GVNPass().run(*NewF, FAM);
#else
PA = GVN().run(*NewF, FAM);
#endif
FAM.invalidate(*NewF, PA);
}

SmallVector<CallInst *, 4> ToConvert;
Expand Down
11 changes: 10 additions & 1 deletion enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"

#if LLVM_VERSION_MAJOR >= 16
#include "llvm/TargetParser/Triple.h"
#else
#include "llvm/ADT/Triple.h"
#endif

#include "llvm-c/Core.h"

#include "LibraryFuncs.h"
Expand Down Expand Up @@ -203,7 +209,10 @@ Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
ConstantInt::get(next->getType(), 0),
B.CreateLShr(next, ConstantInt::get(next->getType(), 1)));

if (!custom) {
auto Arch = llvm::Triple(M.getTargetTriple()).getArch();
bool forceMalloc = Arch == Triple::nvptx || Arch == Triple::nvptx64;

if (!custom && !forceMalloc) {
auto reallocF = M.getOrInsertFunction("realloc", allocType, allocType,
Type::getInt64Ty(M.getContext()));

Expand Down

0 comments on commit 6e72222

Please sign in to comment.