Skip to content

Commit

Permalink
add sinh/cosh/tanh support and tests for LLVM19 and newer
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Oct 2, 2024
1 parent 7425c37 commit a8a33d5
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 3 deletions.
5 changes: 5 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,11 @@ bool guaranteedDataDependent(Value *z) {
case Intrinsic::sqrt:
case Intrinsic::sin:
case Intrinsic::cos:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::sinh:
case Intrinsic::cosh:
case Intrinsic::tanh:
#endif
return guaranteedDataDependent(II->getArgOperand(0));
default:
break;
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4180,6 +4180,11 @@ bool GradientUtils::shouldRecompute(const Value *val,
case Intrinsic::sin:
case Intrinsic::cos:
case Intrinsic::exp:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::tanh:
case Intrinsic::cosh:
case Intrinsic::sinh:
#endif
case Intrinsic::log:
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
Expand Down
24 changes: 21 additions & 3 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def : CallPattern<(Op $x, $y),
>;

def : CallPattern<(Op $x),
["tanh"],
["tanh", "", "18"],
[(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"cosh">), [ReadNone,NoUnwind]> $x):$c, $c))],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
Expand All @@ -355,7 +355,7 @@ def : CallPattern<(Op $x),
>;

def : CallPattern<(Op $x),
["cosh"],
["cosh", "", "18"],
[(FMul (DiffeRet), (Call<(SameTypesFunc<"sinh">), [ReadNone,NoUnwind]> $x))],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
Expand All @@ -374,7 +374,7 @@ def : CallPattern<(Op $x),
>;

def : CallPattern<(Op $x),
["sinh"],
["sinh", "", "18"],
[(FMul (DiffeRet), (Call<(SameTypesFunc<"cosh">), [ReadNone,NoUnwind]> $x))],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
Expand Down Expand Up @@ -872,6 +872,24 @@ def : CallPattern<(Op (Op $x, $y):$z),
[ReadNone, NoUnwind]
>;

def : IntrPattern<(Op $x),
[["tanh", "19", ""]],
[(FDiv (DiffeRet), (FMul(Intrinsic<"cosh"> $x):$c, $c))],
(ForwardFromSummedReverse)
>;

def : IntrPattern<(Op $x),
[["sinh", "19", ""]],
[(FMul (DiffeRet), (Intrinsic<"cosh"> $x))],
(ForwardFromSummedReverse)
>;

def : IntrPattern<(Op $x),
[["cosh", "19", ""]],
[(FMul (DiffeRet), (Intrinsic<"sinh"> $x))],
(ForwardFromSummedReverse)
>;

def : IntrPattern<(Op $x),
[["sin"]],
[(FMul (DiffeRet), (Intrinsic<"cos"> $x))] ,
Expand Down
11 changes: 11 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
{"atan", Intrinsic::not_intrinsic},
{"atan2", Intrinsic::not_intrinsic},
{"__nv_atan2", Intrinsic::not_intrinsic},
#if LLVM_VERSION_MAJOR >= 19
{"cosh", Intrinsic::cosh},
{"sinh", Intrinsic::sinh},
{"tanh", Intrinsic::tanh},
#else
{"cosh", Intrinsic::not_intrinsic},
{"sinh", Intrinsic::not_intrinsic},
{"tanh", Intrinsic::not_intrinsic},
#endif
{"acosh", Intrinsic::not_intrinsic},
{"asinh", Intrinsic::not_intrinsic},
{"atanh", Intrinsic::not_intrinsic},
Expand Down Expand Up @@ -3849,6 +3855,11 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
case Intrinsic::exp2:
case Intrinsic::sin:
case Intrinsic::cos:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::sinh:
case Intrinsic::cosh:
case Intrinsic::tanh:
#endif
case Intrinsic::floor:
case Intrinsic::ceil:
case Intrinsic::trunc:
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,11 @@ static inline bool isNoEscapingAllocation(const llvm::Function *F) {
case Intrinsic::exp:
case Intrinsic::cos:
case Intrinsic::sin:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::tanh:
case Intrinsic::cosh:
case Intrinsic::sinh:
#endif
case Intrinsic::copysign:
case Intrinsic::fabs:
return true;
Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cosh19.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare double @llvm.cosh.f64(double) #14

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = call double @llvm.cosh.f64(double %x)
ret double %0
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
ret double %0
}

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double)*, ...)

; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @llvm.sinh.f64(double %x)
; CHECK-NEXT: %1 = fmul fast double %differeturn, %0
; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0
; CHECK-NEXT: ret { double } %2
; CHECK-NEXT: }
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/sinh19.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare double @llvm.sinh.f64(double) #14

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = call double @llvm.sinh.f64(double %x)
ret double %0
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
ret double %0
}

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double)*, ...)

; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @llvm.cosh.f64(double %x)
; CHECK-NEXT: %1 = fmul fast double %differeturn, %0
; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0
; CHECK-NEXT: ret { double } %2
; CHECK-NEXT: }
29 changes: 29 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/tanh19.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare double @llvm.tanh.f64(double) #14

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = call double @llvm.tanh.f64(double %x)
ret double %0
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
ret double %0
}

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double)*, ...)

; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @llvm.cosh.f64(double %x)
; CHECK-NEXT: %1 = fmul fast double %0, %0
; CHECK-NEXT: %2 = fdiv fast double %differeturn, %1
; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0
; CHECK-NEXT: ret { double } %3
; CHECK-NEXT: }

0 comments on commit a8a33d5

Please sign in to comment.