diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 32b30bd8f9b4..6491fa032798 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -162,12 +162,20 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z), (FNeg $re), (FNeg $im) )>; + +def Conj : SubRoutine<(Op (Op $re, $im):$z), + (ArrayRet + $re, + (FNeg $im) + )>; + def CFExp : SubRoutine<(Op (Op $re, $im):$z), (ArrayRet (FMul (FExp $re):$exp, (FCos $im)), (FMul $exp, (FSin $im)) )>; + // Same function as the one being called def SameFunc { } @@ -826,9 +834,10 @@ def : CallPattern<(Op (Op $x, $y):$z), def : CallPattern<(Op (Op $x, $y):$z), ["cmplx_inv"], [ - (CFDiv (CFNeg (DiffeRet)), (CFMul $z, $z)), + // Reverse mode needs to return the conjugate + (Conj (CFDiv (CFNeg (Conj (DiffeRet))), (CFMul $z, $z))), ], - (ForwardFromSummedReverse), + (CFDiv (CFNeg (Shadow $z)), (CFMul $z, $z)), [ReadNone, NoUnwind] >;