diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index fafabd00..47ef8549 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -21,6 +21,7 @@ jobs: arch: - x64 AD: + - Enzyme - ForwardDiff - Tapir - Tracker @@ -29,6 +30,10 @@ jobs: exclude: - version: 1.6 AD: Tapir + # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see + # discussion in https://github.com/TuringLang/Bijectors.jl/pull. + - version: 1.6 + AD: Enzyme steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index d5095408..c53ca2a9 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -34,6 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BijectorsDistributionsADExt = "DistributionsAD" +BijectorsEnzymeExt = "Enzyme" BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" @@ -50,6 +52,7 @@ Compat = "3.46, 4.2" Distributions = "0.25.33" DistributionsAD = "0.6" DocStringExtensions = "0.9" +Enzyme = "0.12.22" ForwardDiff = "0.10" Functors = "0.1, 0.2, 0.3, 0.4" InverseFunctions = "0.1" @@ -69,6 +72,7 @@ julia = "1.6" [extras] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl new file mode 100644 index 00000000..1e8d8aa3 --- /dev/null +++ b/ext/BijectorsEnzymeExt.jl @@ -0,0 +1,14 @@ +module BijectorsEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme: @import_frule, @import_rrule + using Bijectors: find_alpha +else + using ..Enzyme: @import_frule, @import_rrule + using ..Bijectors: find_alpha +end + +@import_rrule typeof(find_alpha) Real Real Real +@import_frule typeof(find_alpha) Real Real Real + +end diff --git a/src/interface.jl b/src/interface.jl index fa4363ff..d600cd6c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -93,7 +93,7 @@ function transform(t::Transform, x) res = with_logabsdet_jacobian(t, x) if res isa ChangesOfVariables.NoLogAbsDetJacobian error( - "`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.", + "`transform` not implemented for $(typeof(t)); implement `transform` and/or `with_logabsdet_jacobian`.", ) end diff --git a/test/Project.toml b/test/Project.toml index 22c517f9..d4d9a3df 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -15,8 +16,8 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -31,6 +32,7 @@ ChangesOfVariables = "0.1" Combinatorics = "1.0.2" Compat = "3.46, 4.2" DistributionsAD = "0.6.3" +Enzyme = "0.12.22" FillArrays = "1" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" diff --git a/test/ad/utils.jl b/test/ad/utils.jl index d88d2cd9..3e21e693 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -2,7 +2,24 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) + for b in broken + if !( + b in ( + :ForwardDiff, + :Zygote, + :Tapir, + :ReverseDiff, + :Enzyme, + :EnzymeForward, + :EnzymeReverse, + ) + ) + error("Unknown broken AD backend: $b") + end + end + finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] + et = eltype(finitediff) if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken @@ -30,6 +47,37 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end + # TODO(mhauru) The version bound should be relaxed once some Enzyme issues get + # sorted out. I think forward mode will remain broken for versions <= 1.6 due to + # some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and + # discussion in https://github.com/TuringLang/Bijectors.jl/pull/318. + if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10" + forward_broken = :EnzymeForward in broken || :Enzyme in broken + reverse_broken = :EnzymeReverse in broken || :Enzyme in broken + if forward_broken + @test_broken( + collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, + rtol = rtol, + atol = atol + ) + else + @test( + collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, + rtol = rtol, + atol = atol + ) + end + if reverse_broken + @test_broken( + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol + ) + else + @test( + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol + ) + end + end + if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10" rule = Tapir.build_rrule(f, x; safety_on=false) if :tapir in broken diff --git a/test/runtests.jl b/test/runtests.jl index eafc34bc..914c0e32 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Bijectors using ChainRulesTestUtils using Combinatorics using DistributionsAD +using Enzyme using FiniteDifferences using ForwardDiff using Functors