Skip to content

Commit

Permalink
Fix cuda realloc (#2091)
Browse files: browse the repository at this point in the history
* Fix cuda realloc

* fix
  • Loading branch information
wsmoses authored Sep 29, 2024
1 parent cc98d92 commit 680d8bc
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 46 deletions.
13 changes: 13 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,18 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
}
if (success)
continue;

auto v2 = simplifyLoad(LI);
if (v2) {
todo.push_back({v2, next.second});
continue;
}
}

EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc,
"could not statically determine size of realloc ", *Loc,
" - because of - ", *next.first);
return AI;

std::string allocName;
switch (llvm::Triple(NewF->getParent()->getTargetTriple()).getOS()) {
Expand Down Expand Up @@ -833,6 +840,12 @@ void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
if (mem2reg) {
auto PA = PromotePass().run(*NewF, FAM);
FAM.invalidate(*NewF, PA);
#if !defined(FLANG)
PA = GVNPass().run(*NewF, FAM);
#else
PA = GVN().run(*NewF, FAM);
#endif
FAM.invalidate(*NewF, PA);
}

SmallVector<CallInst *, 4> ToConvert;
Expand Down
11 changes: 10 additions & 1 deletion enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"

#if LLVM_VERSION_MAJOR >= 16
#include "llvm/TargetParser/Triple.h"
#else
#include "llvm/ADT/Triple.h"
#endif

#include "llvm-c/Core.h"

#include "LibraryFuncs.h"
Expand Down Expand Up @@ -203,7 +209,10 @@ Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
ConstantInt::get(next->getType(), 0),
B.CreateLShr(next, ConstantInt::get(next->getType(), 1)));

if (!custom) {
auto Arch = llvm::Triple(M.getTargetTriple()).getArch();
bool forceMalloc = Arch == Triple::nvptx || Arch == Triple::nvptx64;

if (!custom && !forceMalloc) {
auto reallocF = M.getOrInsertFunction("realloc", allocType, allocType,
Type::getInt64Ty(M.getContext()));

Expand Down
14 changes: 6 additions & 8 deletions enzyme/test/Enzyme/ReverseMode/sharedcachefwd.ll
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,11 @@ attributes #6 = { nounwind }
; CHECK-NEXT: call void @llvm.nvvm.barrier0()
; CHECK-NEXT: %[[v9:.+]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
; CHECK-NEXT: %mul = shl i32 %[[v9]], 4
; CHECK-NEXT: %[[v10:.+]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
; CHECK-NEXT: %add = add i32 %mul, %[[v10]]
; CHECK-NEXT: %add = add i32 %mul, %1
; CHECK-NEXT: %conv = zext i32 %add to i64
; CHECK-NEXT: %[[v11:.+]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
; CHECK-NEXT: %mul3 = shl i32 %[[v11]], 4
; CHECK-NEXT: %[[v12:.+]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
; CHECK-NEXT: %add5 = add i32 %mul3, %[[v12]]
; CHECK-NEXT: %add5 = add i32 %mul3, %0
; CHECK-NEXT: %conv6 = zext i32 %add5 to i64
; CHECK-NEXT: %[[v13:.+]] = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
; CHECK-NEXT: %conv8 = zext i32 %[[v13]] to i64
Expand All @@ -212,10 +210,10 @@ attributes #6 = { nounwind }

; CHECK: for.body.lr.ph:
; CHECK-NEXT: %mul9 = mul i64 %conv, %n
; CHECK-NEXT: %conv13 = zext i32 %[[v12]] to i64
; CHECK-NEXT: %conv13 = zext i32 %0 to i64
; CHECK-NEXT: %add11 = add i64 %mul9, %conv13
; CHECK-NEXT: %mul15 = mul i64 %n, %n
; CHECK-NEXT: %idxprom = zext i32 %[[v10]] to i64
; CHECK-NEXT: %idxprom = zext i32 %1 to i64
; CHECK-NEXT: %arrayidx2195 = getelementptr inbounds [16 x [16 x float]], [16 x [16 x float]] addrspace(3)* @_ZZ22gpu_square_matrix_multPfS_S_mE6tile_a, i64 0, i64 %idxprom, i64 %conv13
; CHECK-NEXT: %arrayidx21 = addrspacecast float addrspace(3)* %arrayidx2195 to float*
; CHECK-NEXT: %arrayidx4097 = getelementptr inbounds [16 x [16 x float]], [16 x [16 x float]] addrspace(3)* @_ZZ22gpu_square_matrix_multPfS_S_mE6tile_b, i64 0, i64 %idxprom, i64 %conv13
Expand Down Expand Up @@ -403,11 +401,11 @@ attributes #6 = { nounwind }
; CHECK: invertfor.body44_phimerge_phimerge:
; CHECK-NEXT: %[[v33:.+]] = phi {{(fast )?}}float [ %[[_unwrap51]], %invertfor.body44_phimerge_phirc ], [ 0.000000e+00, %invertfor.body44_phimerge ]
; CHECK-NEXT: %[[m1diffe:.+]] = fmul fast float %"add56'de.1", %[[v33]]
; CHECK-NEXT: %[[conv13_unwrap53]] = zext i32 %[[v12]] to i64
; CHECK-NEXT: %[[conv13_unwrap53]] = zext i32 %0 to i64
; CHECK-NEXT: %"arrayidx54101'ipg_unwrap" = getelementptr inbounds [16 x [16 x float]], [16 x [16 x float]] addrspace(3)* @_ZZ22gpu_square_matrix_multPfS_S_mE6tile_b_shadow, i64 0, i64 %[[idxprom_unwrap20]], i64 %[[conv13_unwrap53]]
; CHECK-NEXT: %"arrayidx54'ipc_unwrap" = addrspacecast float addrspace(3)* %"arrayidx54101'ipg_unwrap" to float*
; CHECK-NEXT: %{{.+}} = atomicrmw fadd float* %"arrayidx54'ipc_unwrap", float %[[m1diffe]] monotonic
; CHECK-NEXT: %[[idxprom_unwrap54]] = zext i32 %[[v10]] to i64
; CHECK-NEXT: %[[idxprom_unwrap54]] = zext i32 %1 to i64
; CHECK-NEXT: %"arrayidx4999'ipg_unwrap" = getelementptr inbounds [16 x [16 x float]], [16 x [16 x float]] addrspace(3)* @_ZZ22gpu_square_matrix_multPfS_S_mE6tile_a_shadow, i64 0, i64 %[[idxprom_unwrap54]], i64 %[[idxprom_unwrap20]]
; CHECK-NEXT: %"arrayidx49'ipc_unwrap" = addrspacecast float addrspace(3)* %"arrayidx4999'ipg_unwrap" to float*
; CHECK-NEXT: %{{.+}} = atomicrmw fadd float* %"arrayidx49'ipc_unwrap", float %[[m0diffe]] monotonic
Expand Down
6 changes: 2 additions & 4 deletions enzyme/test/Enzyme/ReverseMode/sharedmem.ll
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,10 @@ attributes #4 = { nounwind }
; CHECK: invertbb: ; preds = %shblock, %bb
; CHECK-NEXT: call void @llvm.nvvm.barrier0()
; CHECK-NEXT: %tmp = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #{{.*}}, !range !12
; CHECK-NEXT: %tmp4 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.y() #{{.*}}, !range !13
; CHECK-NEXT: %tmp5 = add nuw nsw i32 %tmp4, %tmp
; CHECK-NEXT: %tmp5 = add nuw nsw i32 %1, %tmp
; CHECK-NEXT: %tmp6 = zext i32 %tmp5 to i64
; CHECK-NEXT: %tmp7 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #{{.*}}, !range !14
; CHECK-NEXT: %tmp8 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() #{{.*}}, !range !13
; CHECK-NEXT: %tmp9 = add nuw i32 %tmp8, %tmp7
; CHECK-NEXT: %tmp9 = add nuw i32 %0, %tmp7
; CHECK-NEXT: %tmp10 = zext i32 %tmp9 to i64
; CHECK-NEXT: %tmp11 = mul i64 %tmp6, %arg3
; CHECK-NEXT: %tmp12 = add i64 %tmp11, %tmp10
Expand Down
64 changes: 31 additions & 33 deletions enzyme/test/Enzyme/ReverseModeVector/vecloadatomic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -71,39 +71,37 @@ entry:
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
; CHECK-NEXT: %"i3'de" = alloca [2 x i64]
; CHECK-NEXT: store [2 x i64] zeroinitializer, [2 x i64]* %"i3'de"
; CHECK-NEXT: %0 = extractvalue [2 x i64 addrspace(13)*] %"i7'", 0
; CHECK-NEXT: %1 = load i64, i64 addrspace(13)* %0
; CHECK-NEXT: %2 = extractvalue [2 x i64 addrspace(13)*] %"i7'", 1
; CHECK-NEXT: %3 = load i64, i64 addrspace(13)* %2
; CHECK-NEXT: %4 = extractvalue [2 x i64 addrspace(13)*] %"i7'", 0
; CHECK-NEXT: store i64 0, i64 addrspace(13)* %4
; CHECK-NEXT: %5 = extractvalue [2 x i64 addrspace(13)*] %"i7'", 1
; CHECK-NEXT: store i64 0, i64 addrspace(13)* %5
; CHECK-NEXT: %6 = getelementptr inbounds [2 x i64], [2 x i64]* %"i3'de", i32 0, i32 0
; CHECK-NEXT: %7 = load i64, i64* %6
; CHECK-NEXT: %8 = bitcast i64 %7 to double
; CHECK-NEXT: %9 = bitcast i64 %1 to double
; CHECK-NEXT: %10 = fadd fast double %8, %9
; CHECK-NEXT: %11 = bitcast double %10 to i64
; CHECK-NEXT: store i64 %11, i64* %6
; CHECK-NEXT: %12 = getelementptr inbounds [2 x i64], [2 x i64]* %"i3'de", i32 0, i32 1
; CHECK-NEXT: %13 = load i64, i64* %12
; CHECK-NEXT: %14 = bitcast i64 %13 to double
; CHECK-NEXT: %15 = bitcast i64 %3 to double
; CHECK-NEXT: %16 = fadd fast double %14, %15
; CHECK-NEXT: %17 = bitcast double %16 to i64
; CHECK-NEXT: store i64 %17, i64* %12
; CHECK-NEXT: %18 = load [2 x i64], [2 x i64]* %"i3'de"
; CHECK-NEXT: %[[i0:.+]] = extractvalue [2 x i64 addrspace(13)*] %"i7'", 0
; CHECK-NEXT: %[[i1:.+]] = load i64, i64 addrspace(13)* %[[i0]]
; CHECK-NEXT: %[[i2:.+]] = extractvalue [2 x i64 addrspace(13)*] %"i7'", 1
; CHECK-NEXT: %[[i3:.+]] = load i64, i64 addrspace(13)* %[[i2]]
; CHECK-NEXT: store i64 0, i64 addrspace(13)* %[[i0]]
; CHECK-NEXT: store i64 0, i64 addrspace(13)* %[[i2]]
; CHECK-NEXT: %[[i6:.+]] = getelementptr inbounds [2 x i64], [2 x i64]* %"i3'de", i32 0, i32 0
; CHECK-NEXT: %[[i7:.+]] = load i64, i64* %[[i6]]
; CHECK-NEXT: %[[i8:.+]] = bitcast i64 %[[i7]] to double
; CHECK-NEXT: %[[i9:.+]] = bitcast i64 %[[i1]] to double
; CHECK-NEXT: %[[i10:.+]] = fadd fast double %[[i8]], %[[i9]]
; CHECK-NEXT: %[[i11:.+]] = bitcast double %[[i10]] to i64
; CHECK-NEXT: store i64 %[[i11]], i64* %[[i6]]
; CHECK-NEXT: %[[i12:.+]] = getelementptr inbounds [2 x i64], [2 x i64]* %"i3'de", i32 0, i32 1
; CHECK-NEXT: %[[i13:.+]] = load i64, i64* %[[i12]]
; CHECK-NEXT: %[[i14:.+]] = bitcast i64 %[[i13]] to double
; CHECK-NEXT: %[[i15:.+]] = bitcast i64 %[[i3]] to double
; CHECK-NEXT: %[[i16:.+]] = fadd fast double %[[i14]], %[[i15]]
; CHECK-NEXT: %[[i17:.+]] = bitcast double %[[i16]] to i64
; CHECK-NEXT: store i64 %[[i17]], i64* %[[i12]]
; CHECK-NEXT: %[[i18:.+]] = load [2 x i64], [2 x i64]* %"i3'de"
; CHECK-NEXT: store [2 x i64] zeroinitializer, [2 x i64]* %"i3'de"
; CHECK-NEXT: %19 = extractvalue [2 x i64 addrspace(12)*] %"i2'", 0
; CHECK-NEXT: %20 = bitcast i64 addrspace(12)* %19 to double addrspace(12)*
; CHECK-NEXT: %21 = extractvalue [2 x i64 addrspace(12)*] %"i2'", 1
; CHECK-NEXT: %22 = bitcast i64 addrspace(12)* %21 to double addrspace(12)*
; CHECK-NEXT: %23 = extractvalue [2 x i64] %18, 0
; CHECK-NEXT: %24 = bitcast i64 %23 to double
; CHECK-NEXT: %25 = extractvalue [2 x i64] %18, 1
; CHECK-NEXT: %26 = bitcast i64 %25 to double
; CHECK-NEXT: %27 = atomicrmw fadd double addrspace(12)* %20, double %24 monotonic
; CHECK-NEXT: %28 = atomicrmw fadd double addrspace(12)* %22, double %26 monotonic
; CHECK-NEXT: %[[i19:.+]] = extractvalue [2 x i64 addrspace(12)*] %"i2'", 0
; CHECK-NEXT: %[[i20:.+]] = bitcast i64 addrspace(12)* %[[i19]] to double addrspace(12)*
; CHECK-NEXT: %[[i21:.+]] = extractvalue [2 x i64 addrspace(12)*] %"i2'", 1
; CHECK-NEXT: %[[i22:.+]] = bitcast i64 addrspace(12)* %[[i21]] to double addrspace(12)*
; CHECK-NEXT: %[[i23:.+]] = extractvalue [2 x i64] %[[i18]], 0
; CHECK-NEXT: %[[i24:.+]] = bitcast i64 %[[i23]] to double
; CHECK-NEXT: %[[i25:.+]] = extractvalue [2 x i64] %[[i18]], 1
; CHECK-NEXT: %[[i26:.+]] = bitcast i64 %[[i25]] to double
; CHECK-NEXT: %[[i27:.+]] = atomicrmw fadd double addrspace(12)* %[[i20]], double %[[i24]] monotonic
; CHECK-NEXT: %[[i28:.+]] = atomicrmw fadd double addrspace(12)* %[[i22]], double %[[i26]] monotonic
; CHECK-NEXT: ret void
; CHECK-NEXT: }

0 comments on commit 680d8bc

Please sign in to comment.