Skip to content

Commit

Permalink
Fix diff use in non-chosen remat (#1301)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jun 26, 2023
1 parent bccd9ff commit c21f77b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
8 changes: 8 additions & 0 deletions enzyme/Enzyme/DifferentialUseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ inline bool is_value_needed_in_reverse(
if (isa<StoreInst>(user) || isa<MemTransferInst>(user) ||
isa<MemSetInst>(user)) {
for (auto pair : gutils->rematerializableAllocations) {
// If caching the outer allocation and have already set that this is
// not needed return early. This is necessary to avoid unnecessarily
// deciding stored values are needed if we have already decided to
// cache the whole allocation.
auto found = seen.find(std::make_pair(pair.first, ValueType::Primal));
if (found != seen.end() && !found->second)
continue;

// Directly consider all the load uses to avoid an illegal inductive
// recurrence. Specifically if we're asking if the alloca is used,
// we'll set it to unused, then check the gep, then here we'll
Expand Down
64 changes: 64 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/rematint.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s

define void @f(i64* %arg, {} addrspace(10)* %arg1) {
bb:
%i7 = alloca i64, align 8
%i5 = call [1 x i64] @a2(i64* %arg, {} addrspace(10)* %arg1)
%i6 = extractvalue [1 x i64] %i5, 0
store i64 %i6, i64* %i7, align 8
call void @a3(i64* %arg, i64* %i7)
ret void
}

declare void @__enzyme_reverse(...)

define void @dsquare(double %arg) {
bb:
call void (...) @__enzyme_reverse(void (i64*, {} addrspace(10)*)* nonnull @f, metadata !"enzyme_dup", i64* undef, i64* undef, metadata !"enzyme_dup", {} addrspace(10)* undef, {} addrspace(10)* undef, i8* null)
ret void
}

define [1 x i64] @a2(i64* %arg, {} addrspace(10)* %arg1) {
bb:
%i5 = load i64, i64* %arg, align 8, !tbaa !5
%i30 = insertvalue [1 x i64] undef, i64 %i5, 0
ret [1 x i64] %i30
}

define void @a3(i64* %arg, i64* nocapture readonly %arg1) {
bb:
ret void
}

!5 = !{!6, !6, i64 0}
!6 = !{!"jtbaa_arraylen", !7, i64 0}
!7 = !{!"jtbaa_array", !8, i64 0}
!8 = !{!"jtbaa", !9, i64 0}
!9 = !{!"jtbaa"}

; CHECK: define internal i8* @augmented_f(i64* %arg, i64* %"arg'", {} addrspace(10)* %arg1, {} addrspace(10)* %"arg1'")
; CHECK-NEXT: bb:
; CHECK-NEXT: %malloccall1 = tail call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8)
; CHECK-NEXT: %tapemem = bitcast i8* %malloccall1 to i8**
; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8)
; CHECK-NEXT: store i8* %malloccall, i8** %tapemem
; CHECK-NEXT: %i7 = bitcast i8* %malloccall to i64*
; CHECK-NEXT: %i5 = call [1 x i64] @augmented_a2(i64* %arg, i64* %"arg'", {} addrspace(10)* %arg1, {} addrspace(10)* %"arg1'")
; CHECK-NEXT: %i6 = extractvalue [1 x i64] %i5, 0
; CHECK-NEXT: store i64 %i6, i64* %i7, align 8
; CHECK-NEXT: call void @augmented_a3(i64* %arg, i64* %"arg'", i64* %i7)
; CHECK-NEXT: ret i8* %malloccall1
; CHECK-NEXT: }

; CHECK: define internal void @diffef(i64* %arg, i64* %"arg'", {} addrspace(10)* %arg1, {} addrspace(10)* %"arg1'", i8* %tapeArg)
; CHECK-NEXT: bb:
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to i8**
; CHECK-NEXT: %malloccall = load i8*, i8** %0
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
; CHECK-NEXT: %i7 = bitcast i8* %malloccall to i64*
; CHECK-NEXT: call void @diffea3(i64* %arg, i64* %"arg'", i64* %i7)
; CHECK-NEXT: call void @diffea2(i64* %arg, i64* %"arg'", {} addrspace(10)* %arg1, {} addrspace(10)* %"arg1'")
; CHECK-NEXT: call void @free(i8* %malloccall)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

0 comments on commit c21f77b

Please sign in to comment.