Skip to content

Commit

Permalink
add tanh support for llvm19+
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Oct 2, 2024
1 parent 7425c37 commit a1d7334
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 6 deletions.
5 changes: 5 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2812,6 +2812,11 @@ bool guaranteedDataDependent(Value *z) {
case Intrinsic::sqrt:
case Intrinsic::sin:
case Intrinsic::cos:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::sinh:
case Intrinsic::cosh:
case Intrinsic::tanh:
#endif
return guaranteedDataDependent(II->getArgOperand(0));
default:
break;
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4180,6 +4180,11 @@ bool GradientUtils::shouldRecompute(const Value *val,
case Intrinsic::sin:
case Intrinsic::cos:
case Intrinsic::exp:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::tanh:
case Intrinsic::cosh:
case Intrinsic::sinh:
#endif
case Intrinsic::log:
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_p:
Expand Down
12 changes: 6 additions & 6 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,6 @@ def : CallPattern<(Op $x, $y),
[ReadNone, NoUnwind]
>;

def : CallPattern<(Op $x),
["tanh"],
[(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"cosh">), [ReadNone,NoUnwind]> $x):$c, $c))],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
>;
def : CallPattern<(Op $x),
["tanhf"],
[(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshf">), [ReadNone,NoUnwind]> $x):$c, $c))],
Expand Down Expand Up @@ -872,6 +866,12 @@ def : CallPattern<(Op (Op $x, $y):$z),
[ReadNone, NoUnwind]
>;

def : IntrPattern<(Op $x),
[["tanh"]],
[(FDiv (DiffeRet), (FMul(Intrinsic<"cosh"> $x):$c, $c))],
(ForwardFromSummedReverse)
>;

def : IntrPattern<(Op $x),
[["sin"]],
[(FMul (DiffeRet), (Intrinsic<"cos"> $x))] ,
Expand Down
11 changes: 11 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ const llvm::StringMap<llvm::Intrinsic::ID> LIBM_FUNCTIONS = {
{"atan", Intrinsic::not_intrinsic},
{"atan2", Intrinsic::not_intrinsic},
{"__nv_atan2", Intrinsic::not_intrinsic},
#if LLVM_VERSION_MAJOR >= 19
{"cosh", Intrinsic::cosh},
{"sinh", Intrinsic::sinh},
{"tanh", Intrinsic::tanh},
#else
{"cosh", Intrinsic::not_intrinsic},
{"sinh", Intrinsic::not_intrinsic},
{"tanh", Intrinsic::not_intrinsic},
#endif
{"acosh", Intrinsic::not_intrinsic},
{"asinh", Intrinsic::not_intrinsic},
{"atanh", Intrinsic::not_intrinsic},
Expand Down Expand Up @@ -3849,6 +3855,11 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
case Intrinsic::exp2:
case Intrinsic::sin:
case Intrinsic::cos:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::sinh:
case Intrinsic::cosh:
case Intrinsic::tanh:
#endif
case Intrinsic::floor:
case Intrinsic::ceil:
case Intrinsic::trunc:
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,11 @@ static inline bool isNoEscapingAllocation(const llvm::Function *F) {
case Intrinsic::exp:
case Intrinsic::cos:
case Intrinsic::sin:
#if LLVM_VERSION_MAJOR >= 19
case Intrinsic::tanh:
case Intrinsic::cosh:
case Intrinsic::sinh:
#endif
case Intrinsic::copysign:
case Intrinsic::fabs:
return true;
Expand Down
29 changes: 29 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/tanh19.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare double @llvm.tanh.f64(double) #14

; Function Attrs: nounwind readnone uwtable
define double @tester(double %x) {
entry:
%0 = call double @llvm.tanh.f64(double %x)
ret double %0
}

define double @test_derivative(double %x) {
entry:
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
ret double %0
}

; Function Attrs: nounwind
declare double @__enzyme_autodiff(double (double)*, ...)

; CHECK: define internal { double } @diffetester(double %x, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call fast double @llvm.cosh.f64(double %x)
; CHECK-NEXT: %1 = fmul fast double %0, %0
; CHECK-NEXT: %2 = fdiv fast double %differeturn, %1
; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0
; CHECK-NEXT: ret { double } %3
; CHECK-NEXT: }

0 comments on commit a1d7334

Please sign in to comment.