From a8a33d520fc99d71c8493320b406acb179749408 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 2 Oct 2024 11:16:21 -0400 Subject: [PATCH] add sinh/cosh/tanh support and tests for LLVM19 and newer --- enzyme/Enzyme/FunctionUtils.cpp | 5 ++++ enzyme/Enzyme/GradientUtils.cpp | 5 ++++ enzyme/Enzyme/InstructionDerivatives.td | 24 ++++++++++++++--- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 11 ++++++++ enzyme/Enzyme/Utils.h | 5 ++++ enzyme/test/Enzyme/ReverseMode/cosh19.ll | 28 ++++++++++++++++++++ enzyme/test/Enzyme/ReverseMode/sinh19.ll | 28 ++++++++++++++++++++ enzyme/test/Enzyme/ReverseMode/tanh19.ll | 29 +++++++++++++++++++++ 8 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/cosh19.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/sinh19.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/tanh19.ll diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 2839caffae1..d70bb775ea6 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2812,6 +2812,11 @@ bool guaranteedDataDependent(Value *z) { case Intrinsic::sqrt: case Intrinsic::sin: case Intrinsic::cos: +#if LLVM_VERSION_MAJOR >= 19 + case Intrinsic::sinh: + case Intrinsic::cosh: + case Intrinsic::tanh: +#endif return guaranteedDataDependent(II->getArgOperand(0)); default: break; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index a9b2fbb8af1..23a95f8925a 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -4180,6 +4180,11 @@ bool GradientUtils::shouldRecompute(const Value *val, case Intrinsic::sin: case Intrinsic::cos: case Intrinsic::exp: +#if LLVM_VERSION_MAJOR >= 19 + case Intrinsic::tanh: + case Intrinsic::cosh: + case Intrinsic::sinh: +#endif case Intrinsic::log: case Intrinsic::nvvm_ldu_global_i: case Intrinsic::nvvm_ldu_global_p: diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 9f765f88826..8ed27c15b29 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -336,7 +336,7 @@ def : CallPattern<(Op $x, $y), >; def : CallPattern<(Op $x), - ["tanh"], + ["tanh", "", "18"], [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"cosh">), [ReadNone,NoUnwind]> $x):$c, $c))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] @@ -355,7 +355,7 @@ def : CallPattern<(Op $x), >; def : CallPattern<(Op $x), - ["cosh"], + ["cosh", "", "18"], [(FMul (DiffeRet), (Call<(SameTypesFunc<"sinh">), [ReadNone,NoUnwind]> $x))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] @@ -374,7 +374,7 @@ def : CallPattern<(Op $x), >; def : CallPattern<(Op $x), - ["sinh"], + ["sinh", "", "18"], [(FMul (DiffeRet), (Call<(SameTypesFunc<"cosh">), [ReadNone,NoUnwind]> $x))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] @@ -872,6 +872,24 @@ def : CallPattern<(Op (Op $x, $y):$z), [ReadNone, NoUnwind] >; +def : IntrPattern<(Op $x), + [["tanh", "19", ""]], + [(FDiv (DiffeRet), (FMul(Intrinsic<"cosh"> $x):$c, $c))], + (ForwardFromSummedReverse) + >; + +def : IntrPattern<(Op $x), + [["sinh", "19", ""]], + [(FMul (DiffeRet), (Intrinsic<"cosh"> $x))], + (ForwardFromSummedReverse) + >; + +def : IntrPattern<(Op $x), + [["cosh", "19", ""]], + [(FMul (DiffeRet), (Intrinsic<"sinh"> $x))], + (ForwardFromSummedReverse) + >; + def : IntrPattern<(Op $x), [["sin"]], [(FMul (DiffeRet), (Intrinsic<"cos"> $x))] , diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index aed7767652f..c59f712b012 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -118,9 +118,15 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"atan", Intrinsic::not_intrinsic}, {"atan2", Intrinsic::not_intrinsic}, {"__nv_atan2", Intrinsic::not_intrinsic}, +#if LLVM_VERSION_MAJOR >= 19 + {"cosh", Intrinsic::cosh}, + {"sinh", Intrinsic::sinh}, + {"tanh", Intrinsic::tanh}, +#else {"cosh", Intrinsic::not_intrinsic}, {"sinh", Intrinsic::not_intrinsic}, {"tanh", Intrinsic::not_intrinsic}, +#endif {"acosh", Intrinsic::not_intrinsic}, {"asinh", Intrinsic::not_intrinsic}, {"atanh", Intrinsic::not_intrinsic}, @@ -3849,6 +3855,11 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { case Intrinsic::exp2: case Intrinsic::sin: case Intrinsic::cos: +#if LLVM_VERSION_MAJOR >= 19 + case Intrinsic::sinh: + case Intrinsic::cosh: + case Intrinsic::tanh: +#endif case Intrinsic::floor: case Intrinsic::ceil: case Intrinsic::trunc: diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 9b66730d14d..a8ce244caad 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1693,6 +1693,11 @@ static inline bool isNoEscapingAllocation(const llvm::Function *F) { case Intrinsic::exp: case Intrinsic::cos: case Intrinsic::sin: +#if LLVM_VERSION_MAJOR >= 19 + case Intrinsic::tanh: + case Intrinsic::cosh: + case Intrinsic::sinh: +#endif case Intrinsic::copysign: case Intrinsic::fabs: return true; diff --git a/enzyme/test/Enzyme/ReverseMode/cosh19.ll b/enzyme/test/Enzyme/ReverseMode/cosh19.ll new file mode 100644 index 00000000000..fa147e22dbe --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/cosh19.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare double @llvm.cosh.f64(double) #14 + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = call double @llvm.cosh.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @llvm.sinh.f64(double %x) +; CHECK-NEXT: %1 = fmul fast double %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 +; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/sinh19.ll b/enzyme/test/Enzyme/ReverseMode/sinh19.ll new file mode 100644 index 00000000000..81d353b5a16 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/sinh19.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare double @llvm.sinh.f64(double) #14 + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = call double @llvm.sinh.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @llvm.cosh.f64(double %x) +; CHECK-NEXT: %1 = fmul fast double %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 +; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/tanh19.ll b/enzyme/test/Enzyme/ReverseMode/tanh19.ll new file mode 100644 index 00000000000..2d22ab6b632 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/tanh19.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -ge 19 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi + +; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare double @llvm.tanh.f64(double) #14 + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = call double @llvm.tanh.f64(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @llvm.cosh.f64(double %x) +; CHECK-NEXT: %1 = fmul fast double %0, %0 +; CHECK-NEXT: %2 = fdiv fast double %differeturn, %1 +; CHECK-NEXT: %3 = insertvalue { double } undef, double %2, 0 +; CHECK-NEXT: ret { double } %3 +; CHECK-NEXT: }