Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test against Enzyme #318

Merged
merged 22 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
arch:
- x64
AD:
- Enzyme
- ForwardDiff
- Tapir
- Tracker
Expand All @@ -29,6 +30,10 @@ jobs:
exclude:
- version: 1.6
AD: Tapir
# TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see
# discussion in https://github.com/TuringLang/Bijectors.jl/pull.
- version: 1.6
AD: Enzyme
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -34,6 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsEnzymeExt = "Enzyme"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
Expand All @@ -50,6 +52,7 @@ Compat = "3.46, 4.2"
Distributions = "0.25.33"
DistributionsAD = "0.6"
DocStringExtensions = "0.9"
Enzyme = "0.12.22"
ForwardDiff = "0.10"
Functors = "0.1, 0.2, 0.3, 0.4"
InverseFunctions = "0.1"
Expand All @@ -69,6 +72,7 @@ julia = "1.6"

[extras]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
14 changes: 14 additions & 0 deletions ext/BijectorsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module BijectorsEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme: @import_frule, @import_rrule
using Bijectors: find_alpha
else
using ..Enzyme: @import_frule, @import_rrule
using ..Bijectors: find_alpha
end

@import_rrule typeof(find_alpha) Real Real Real
@import_frule typeof(find_alpha) Real Real Real

end
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function transform(t::Transform, x)
res = with_logabsdet_jacobian(t, x)
if res isa ChangesOfVariables.NoLogAbsDetJacobian
error(
"`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.",
"`transform` not implemented for $(typeof(t)); implement `transform` and/or `with_logabsdet_jacobian`.",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has nothing to do with this PR, I just spotted the typo while working on this PR and didn't feel like making a separate one character PR.

)
end

Expand Down
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -15,8 +16,8 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -31,6 +32,7 @@ ChangesOfVariables = "0.1"
Combinatorics = "1.0.2"
Compat = "3.46, 4.2"
DistributionsAD = "0.6.3"
Enzyme = "0.12.22"
FillArrays = "1"
FiniteDifferences = "0.11, 0.12"
ForwardDiff = "0.10.12"
Expand Down
48 changes: 48 additions & 0 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@
const AD = get(ENV, "AD", "All")

function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
for b in broken
if !(
b in (
:ForwardDiff,
:Zygote,
:Tapir,
:ReverseDiff,
:Enzyme,
:EnzymeForward,
:EnzymeReverse,
)
)
error("Unknown broken AD backend: $b")
end
end

finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
et = eltype(finitediff)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Shouldn't Enzyme return the correct types automatically?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In forward mode it returns tuples, and if the gradient is empty, the result is a Tuple{}. This resulted in comparing an empty Float64[] to an empty Union{}[], which failed. See EnzymeAD/Enzyme.jl#1584

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the gradient should never be empty? Such a test would be quite useless, so I assume we don't run into this special case here? So maybe a simple collect (without specifying the element type) would be sufficient?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a corner case, but we ran into it here, when d==1:

for d in [1, 2, 5]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest not testing AD in the case d = 1 - or just checking that the gradient is empty. We already handle this case in a special way in e.g.

test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)
.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It passes now though, and I don't really see a downside to testing it? Good to know for instance that nothing crashes even if you hit this corner case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I also would prefer to not remove the test completely (even though I think it's of very limited use)

just checking that the gradient is empty

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what the current test is effectively doing, because when d = 1 finitediff returns an empty array. The eltype thing just makes sure that the check becomes Float64[] == Float64[], rather than Union{}[] == Float64[]. We could put in a specific case for d == 1 in the test file, but this seems like more work to me, because you need to make it cater to different AD backends and specify manually that the result should be empty.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant something differently - removing the eltype/collect completely and only add a special case to this weird test of the CorrBijector since we already have special cases for d = 1 there anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's how I understood you, but that would require something like adding another argument to test_ad called expect_empty that would skip comparing to finitediff and instead check that the gradient has length 1 (for all AD backends) and setting that argument to d == 1, which to me seems more complicated, with a bunch of if statements, compared adding the one-liner enforcing eltype. I can do it if you prefer it, I just don't see the benefit.


if AD == "All" || AD == "ForwardDiff"
if :ForwardDiff in broken
Expand Down Expand Up @@ -30,6 +47,37 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

# TODO(mhauru) The version bound should be relaxed once some Enzyme issues get
# sorted out. I think forward mode will remain broken for versions <= 1.6 due to
# some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and
# discussion in https://github.com/TuringLang/Bijectors.jl/pull/318.
if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10"
forward_broken = :EnzymeForward in broken || :Enzyme in broken
reverse_broken = :EnzymeReverse in broken || :Enzyme in broken
if forward_broken
@test_broken(
collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff,
rtol = rtol,
atol = atol
)
else
@test(
collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff,
rtol = rtol,
atol = atol
)
end
if reverse_broken
@test_broken(
Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol
)
else
@test(
Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol
)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Bijectors
using ChainRulesTestUtils
using Combinatorics
using DistributionsAD
using Enzyme
using FiniteDifferences
using ForwardDiff
using Functors
Expand Down
Loading