Skip to content

Commit

Permalink
Fix reverse mode complex error function (#1770)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Feb 28, 2024
1 parent f7a46fd commit 070601e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 21 deletions.
12 changes: 6 additions & 6 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -655,30 +655,30 @@ def ToStruct2 : SubRoutine<(Op (Op $re, $im):$z),
def : CallPattern<(Op $x, $tbd),
["Faddeeva_erf"],
[
(ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))),
(ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x))))))),
(InactiveArg) // relerr
],
(ForwardFromSummedReverse),
(ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))),
[ReadNone, NoUnwind]
>;

def : CallPattern<(Op $x, $tbd),
["Faddeeva_erfi"],
[
(ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))),
(ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x)))))),
(InactiveArg) // relerr
],
(ForwardFromSummedReverse),
(ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))),
[ReadNone, NoUnwind]
>;

def : CallPattern<(Op $x, $tbd),
["Faddeeva_erfc"],
[
(ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))),
(ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x))))))),
(InactiveArg) // relerr
],
(ForwardFromSummedReverse),
(ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))),
[ReadNone, NoUnwind]
>;

Expand Down
12 changes: 7 additions & 5 deletions enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub

; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0
; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1
; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a17]]
; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0
; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1
; CHECK-DAG: %[[a2:.+]] = fmul fast double %[[a1]], %[[a1]]
Expand All @@ -36,16 +39,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub
; CHECK-NEXT: %[[a13:.+]] = fmul fast double %[[a9]], %[[a12]]
; CHECK-NEXT: %[[a14:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[a11]]
; CHECK-NEXT: %[[a15:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[a13]]
; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0
; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1
; CHECK-DAG: %[[a19:.+]] = fmul fast double %[[a16]], %[[a14]]
; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[a17]], %[[a15]]
; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[conj]], %[[a15]]
; CHECK-NEXT: %[[a20:.+]] = fsub fast double %[[a19]], %[[a18]]
; CHECK-DAG: %[[a22:.+]] = fmul fast double %[[a16]], %[[a15]]
; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[a17]]
; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[conj]]
; CHECK-NEXT: %[[a23:.+]] = fadd fast double %[[a22]], %[[a21]]
; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a23]]
; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[a20]], 0
; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[a23]], 1
; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1
; CHECK-NEXT: %[[a24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0
; CHECK-NEXT: ret { { double, double } } %[[a24]]
; CHECK-NEXT: }
12 changes: 7 additions & 5 deletions enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub

; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0
; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1
; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a17]]
; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0
; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1
; CHECK-DAG: %[[a2:.+]] = fmul fast double %[[a1]], %[[a1]]
Expand All @@ -36,16 +39,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub
; CHECK-NEXT: %[[a13:.+]] = fmul fast double %[[a9]], %[[a12]]
; CHECK-NEXT: %[[a14:.+]] = fmul fast double 0xBFF20DD750429B6D, %[[a11]]
; CHECK-NEXT: %[[a15:.+]] = fmul fast double 0xBFF20DD750429B6D, %[[a13]]
; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0
; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1
; CHECK-DAG: %[[a19:.+]] = fmul fast double %[[a16]], %[[a14]]
; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[a17]], %[[a15]]
; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[conj]], %[[a15]]
; CHECK-NEXT: %[[a20:.+]] = fsub fast double %[[a19]], %[[a18]]
; CHECK-DAG: %[[a22:.+]] = fmul fast double %[[a16]], %[[a15]]
; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[a17]]
; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[conj]]
; CHECK-NEXT: %[[a23:.+]] = fadd fast double %[[a22]], %[[a21]]
; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a23]]
; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[a20]], 0
; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[a23]], 1
; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1
; CHECK-NEXT: %[[a24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0
; CHECK-NEXT: ret { { double, double } } %[[a24]]
; CHECK-NEXT: }
12 changes: 7 additions & 5 deletions enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub

; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[i16:.+]] = extractvalue { double, double } %differeturn, 0
; CHECK-NEXT: %[[i17:.+]] = extractvalue { double, double } %differeturn, 1
; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[i17]]
; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0
; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1
; CHECK-NEXT: %[[a3:.+]] = fmul fast double %[[a0]], %[[a0]]
Expand All @@ -34,16 +37,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub
; CHECK-NEXT: %[[i13:.+]] = fmul fast double %[[i9]], %[[i12]]
; CHECK-NEXT: %[[i14:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[i11]]
; CHECK-NEXT: %[[i15:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[i13]]
; CHECK-NEXT: %[[i16:.+]] = extractvalue { double, double } %differeturn, 0
; CHECK-NEXT: %[[i17:.+]] = extractvalue { double, double } %differeturn, 1
; CHECK-NEXT: %[[i19:.+]] = fmul fast double %[[i16]], %[[i14]]
; CHECK-NEXT: %[[i18:.+]] = fmul fast double %[[i17]], %[[i15]]
; CHECK-NEXT: %[[i18:.+]] = fmul fast double %[[conj]], %[[i15]]
; CHECK-NEXT: %[[i20:.+]] = fsub fast double %[[i19]], %[[i18]]
; CHECK-NEXT: %[[i22:.+]] = fmul fast double %[[i16]], %[[i15]]
; CHECK-NEXT: %[[i21:.+]] = fmul fast double %[[i14]], %[[i17]]
; CHECK-NEXT: %[[i21:.+]] = fmul fast double %[[i14]], %[[conj]]
; CHECK-NEXT: %[[i23:.+]] = fadd fast double %[[i22]], %[[i21]]
; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[i23]]
; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[i20]], 0
; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[i23]], 1
; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1
; CHECK-NEXT: %[[i24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0
; CHECK-NEXT: ret { { double, double } } %[[i24]]
; CHECK-NEXT: }

0 comments on commit 070601e

Please sign in to comment.