From 175f382a79dda37a5687ecb7822754f98191a621 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 26 Jun 2024 12:40:45 +0100 Subject: [PATCH 01/20] Test against Enzyme --- .github/workflows/AD.yml | 1 + test/Project.toml | 3 +++ test/ad/utils.jl | 16 ++++++++++++++++ test/runtests.jl | 6 ++++++ 4 files changed, 26 insertions(+) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 541ccdd9..070ede1a 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -25,6 +25,7 @@ jobs: - Tracker - ReverseDiff - Zygote + - Enzyme steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/test/Project.toml b/test/Project.toml index 6f156b8b..cfa3c537 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -24,6 +25,8 @@ ChangesOfVariables = "0.1" Combinatorics = "1.0.2" Compat = "3.46, 4.2" DistributionsAD = "0.6.3" +# TODO(mhauru) Enzyme needs a compat bound, but only once the rapid iteration to fix issues is done. +#Enzyme = "0.12.19" FillArrays = "1" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" diff --git a/test/ad/utils.jl b/test/ad/utils.jl index bea823fd..c9f753d7 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -30,5 +30,21 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end + if AD == "All" || AD == "Enzyme" + if :EnzymeReverse in broken + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + elseif :EnzymeForward in broken + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol + @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + elseif :Enzyme in broken + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + else + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol + @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + end + end + return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 052b6056..86ffbf30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Bijectors using ChainRulesTestUtils using Combinatorics using DistributionsAD +using Enzyme using FiniteDifferences using ForwardDiff using Functors @@ -29,6 +30,11 @@ using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions using LazyArrays: LazyArrays +# TODO(mhauru) Remove this once Enzyme is fixed? +Enzyme.API.typeWarning!(false) +# Enable runtime activity (workaround) +Enzyme.API.runtimeActivity!(true) + if VERSION < v"1.9" using Compat: stack end From 89a3267b4ea2105e2d5ca89c17bfc334d3f76765 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 26 Jun 2024 12:44:31 +0100 Subject: [PATCH 02/20] Run JuliaFormatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/ad/utils.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index c9f753d7..4bc18f49 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -32,16 +32,22 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if AD == "All" || AD == "Enzyme" if :EnzymeReverse in broken - @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol - @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = + atol + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = + atol elseif :EnzymeForward in broken - @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = + rtol atol = atol @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol elseif :Enzyme in broken - @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol - @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = + rtol atol = atol + @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = + atol else - @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = atol + @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = + atol @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol end end From 961aeb6a6632f4f9404e25231094a5a09b4c3eac Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 26 Jun 2024 14:29:53 +0100 Subject: [PATCH 03/20] Disable some CI tests for Enzyme testing purposes --- .github/workflows/AD.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 070ede1a..a82aa993 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - - '1.6' + #- '1.6' # TODO(mhauru) Disabled temporarily for Enzyme testing - '1' os: - ubuntu-latest @@ -21,10 +21,11 @@ jobs: arch: - x64 AD: - - ForwardDiff - - Tracker + # TODO(mhauru) Disabled temporarily for Enzyme testing + #- ForwardDiff + #- Tracker - ReverseDiff - - Zygote + #- Zygote - Enzyme steps: - uses: actions/checkout@v2 From f33da48ab9da3305f5c33b17089f483778c75d1c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Jun 2024 16:44:42 +0100 Subject: [PATCH 04/20] Import ChainRule for find_alpha for Enzyme --- ext/BijectorsEnzymeExt.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 ext/BijectorsEnzymeExt.jl diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl new file mode 100644 index 00000000..b1079789 --- /dev/null +++ b/ext/BijectorsEnzymeExt.jl @@ -0,0 +1,15 @@ +module BijectorsEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme: @import_frule, @import_rrule + using Bijectors: find_alpha +else + using ..Tapir: @import_frule, @import_rrule + using ..Bijectors: find_alpha +end + +@import_rrule typeof(find_alpha) Float64 Float64 Float64 +@import_frule typeof(find_alpha) Float64 Float64 Float64 + +end + From 28beac58c137caf79ae256df8d16c059f6446172 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 27 Jun 2024 16:46:43 +0100 Subject: [PATCH 05/20] Remove unnecessary whitespace Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/BijectorsEnzymeExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index b1079789..7be8bab4 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -12,4 +12,3 @@ end @import_frule typeof(find_alpha) Float64 Float64 Float64 end - From 1239040f0533dd6277c0fffd007fec1e34d61834 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Jun 2024 09:28:43 +0100 Subject: [PATCH 06/20] Fixes to Enzyme extension --- ext/BijectorsEnzymeExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/BijectorsEnzymeExt.jl b/ext/BijectorsEnzymeExt.jl index 7be8bab4..1e8d8aa3 100644 --- a/ext/BijectorsEnzymeExt.jl +++ b/ext/BijectorsEnzymeExt.jl @@ -4,11 +4,11 @@ if isdefined(Base, :get_extension) using Enzyme: @import_frule, @import_rrule using Bijectors: find_alpha else - using ..Tapir: @import_frule, @import_rrule + using ..Enzyme: @import_frule, @import_rrule using ..Bijectors: find_alpha end -@import_rrule typeof(find_alpha) Float64 Float64 Float64 -@import_frule typeof(find_alpha) Float64 Float64 Float64 +@import_rrule typeof(find_alpha) Real Real Real +@import_frule typeof(find_alpha) Real Real Real end From 94023a1c565e8986336975021862c8510dae3d1a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Jun 2024 11:43:43 +0100 Subject: [PATCH 07/20] Enzyme test fixes --- Project.toml | 6 +++++- test/ad/utils.jl | 45 ++++++++++++++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index be360aa6..d85b8f2e 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -32,6 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BijectorsDistributionsADExt = "DistributionsAD" +BijectorsEnzymeExt = "Enzyme" BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsReverseDiffExt = "ReverseDiff" @@ -45,8 +47,10 @@ ChainRulesCore = "0.10.11, 1" ChangesOfVariables = "0.1" Compat = "3.46, 4.2" Distributions = "0.25.33" -ForwardDiff = "0.10" DistributionsAD = "0.6" +# TODO(mhauru) Set Enzyme compat once rapid iteration is done. +#Enzyme = "0.12.20" +ForwardDiff = "0.10" Functors = "0.1, 0.2, 0.3, 0.4" InverseFunctions = "0.1" IrrationalConstants = "0.1, 0.2" diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 4bc18f49..ee82fdf4 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -3,6 +3,7 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] + et = eltype(finitediff) if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken @@ -18,7 +19,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) else ∇zygote = Zygote.gradient(f, x)[1] @test (all(finitediff .== 0) && ∇zygote === nothing) || - isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) + isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) end end @@ -32,23 +33,37 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if AD == "All" || AD == "Enzyme" if :EnzymeReverse in broken - @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = - atol - @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = - atol + @test( + collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, + rtol=rtol, atol=atol + ) + @test_broken( + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + ) elseif :EnzymeForward in broken - @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = - rtol atol = atol - @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + @test_broken( + collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, + rtol=rtol, atol=atol + ) + @test( + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + ) elseif :Enzyme in broken - @test_broken collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = - rtol atol = atol - @test_broken Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = - atol + @test_broken( + collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, + rtol=rtol, atol=atol + ) + @test_broken( + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + ) else - @test collect(Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff rtol = rtol atol = - atol - @test Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff rtol = rtol atol = atol + @test( + collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, + rtol=rtol, atol=atol + ) + @test( + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + ) end end From 7860701ec76de5be2d2761ecd41499661a55399c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Jun 2024 11:44:18 +0100 Subject: [PATCH 08/20] Remove unnecessary Enzyme settings --- test/runtests.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 86ffbf30..be029745 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,11 +30,6 @@ using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions using LazyArrays: LazyArrays -# TODO(mhauru) Remove this once Enzyme is fixed? -Enzyme.API.typeWarning!(false) -# Enable runtime activity (workaround) -Enzyme.API.runtimeActivity!(true) - if VERSION < v"1.9" using Compat: stack end From 2163ec46b7e45f05e3d950c033e5b23f09703aa1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Jun 2024 11:52:20 +0100 Subject: [PATCH 09/20] Code style --- test/ad/utils.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index ee82fdf4..77c0b021 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -19,7 +19,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) else ∇zygote = Zygote.gradient(f, x)[1] @test (all(finitediff .== 0) && ∇zygote === nothing) || - isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) + isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) end end @@ -35,34 +35,34 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if :EnzymeReverse in broken @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol=rtol, atol=atol + rtol = rtol, atol = atol ) @test_broken( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol ) elseif :EnzymeForward in broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol=rtol, atol=atol + rtol = rtol, atol = atol ) @test( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol ) elseif :Enzyme in broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol=rtol, atol=atol + rtol = rtol, atol = atol ) @test_broken( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol ) else @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol=rtol, atol=atol + rtol = rtol, atol = atol ) @test( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol=rtol, atol=atol + Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol ) end end From 86544752e41a6a4aab0e1607fbf1e53a8c548a64 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 28 Jun 2024 12:00:39 +0100 Subject: [PATCH 10/20] Apply suggestions from reviewdog Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/ad/utils.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 77c0b021..781a8b6f 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -35,7 +35,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if :EnzymeReverse in broken @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test_broken( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol @@ -43,7 +44,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) elseif :EnzymeForward in broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol @@ -51,7 +53,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) elseif :Enzyme in broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test_broken( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol @@ -59,7 +62,8 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) else @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, atol = atol + rtol = rtol, + atol = atol ) @test( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol From e43d71e4dcbaec137cfabe8cb7e900c1214fed01 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 8 Jul 2024 17:23:16 +0100 Subject: [PATCH 11/20] Check broken symbols in tests --- test/ad/utils.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 781a8b6f..31274f5d 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -2,6 +2,12 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) + for b in broken + if !(b in (:ForwardDiff, :Zygote, :ReverseDiff, :Enzyme, :EnzymeForward, :EnzymeReverse)) + error("Unknown broken AD backend: $b") + end + end + finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] et = eltype(finitediff) From deec0a6108235c8524b9226b1d023f95fc3efabe Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 8 Jul 2024 18:04:46 +0100 Subject: [PATCH 12/20] Add :Tapir to list of valid broken test marks Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- test/ad/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 31274f5d..ae11bcb1 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -3,7 +3,7 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) for b in broken - if !(b in (:ForwardDiff, :Zygote, :ReverseDiff, :Enzyme, :EnzymeForward, :EnzymeReverse)) + if !(b in (:ForwardDiff, :Zygote, :Tapir, :ReverseDiff, :Enzyme, :EnzymeForward, :EnzymeReverse)) error("Unknown broken AD backend: $b") end end From cd3000a205bf6fd203b1e62854b01a4de1f630b7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 9 Jul 2024 10:12:19 +0100 Subject: [PATCH 13/20] Add Enzyme compat bounds --- Project.toml | 4 ++-- test/Project.toml | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index b6333650..c77d5b5d 100644 --- a/Project.toml +++ b/Project.toml @@ -50,8 +50,7 @@ Compat = "3.46, 4.2" Distributions = "0.25.33" DistributionsAD = "0.6" DocStringExtensions = "0.9" -# TODO(mhauru) Set Enzyme compat once rapid iteration is done. -#Enzyme = "0.12.20" +Enzyme = "0.12.22" ForwardDiff = "0.10" Functors = "0.1, 0.2, 0.3, 0.4" InverseFunctions = "0.1" @@ -70,6 +69,7 @@ julia = "1.6" [extras] DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/test/Project.toml b/test/Project.toml index 5c87fa4a..3bac4792 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,8 +31,7 @@ ChangesOfVariables = "0.1" Combinatorics = "1.0.2" Compat = "3.46, 4.2" DistributionsAD = "0.6.3" -# TODO(mhauru) Enzyme needs a compat bound, but only once the rapid iteration to fix issues is done. -#Enzyme = "0.12.19" +Enzyme = "0.12.22" FillArrays = "1" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" From 44bd7396c9dfac17a6a050b2751879b341a3b14d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 9 Jul 2024 13:29:44 +0100 Subject: [PATCH 14/20] Code style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/ad/utils.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index d0815b22..995470e4 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -3,7 +3,17 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) for b in broken - if !(b in (:ForwardDiff, :Zygote, :Tapir, :ReverseDiff, :Enzyme, :EnzymeForward, :EnzymeReverse)) + if !( + b in ( + :ForwardDiff, + :Zygote, + :Tapir, + :ReverseDiff, + :Enzyme, + :EnzymeForward, + :EnzymeReverse, + ) + ) error("Unknown broken AD backend: $b") end end @@ -25,7 +35,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) else ∇zygote = Zygote.gradient(f, x)[1] @test (all(finitediff .== 0) && ∇zygote === nothing) || - isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) + isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) end end From 976874ba03e4981d455d12d0e203cc9a7929db8c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 9 Jul 2024 13:33:03 +0100 Subject: [PATCH 15/20] Adjust when Enzyme tests are run --- .github/workflows/AD.yml | 13 +++++++------ test/ad/utils.jl | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index 8458bca8..a61833b5 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -22,15 +22,16 @@ jobs: - x64 AD: - Enzyme - # TODO(mhauru) Disabled temporarily for Enzyme testing - #- ForwardDiff - #- Tapir - #- Tracker - #- ReverseDiff - #- Zygote + - ForwardDiff + - Tapir + - Tracker + - ReverseDiff + - Zygote exclude: - version: 1.6 AD: Tapir + - version: 1.6 + AD: Enzyme steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 995470e4..cceb3355 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -47,7 +47,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if AD == "All" || AD == "Enzyme" + if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10" if :EnzymeReverse in broken @test( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, From f5fd835fe5821910c558994228b545968fd85fbb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 10 Jul 2024 09:59:48 +0100 Subject: [PATCH 16/20] Improve Enzyme brokenness check --- test/ad/utils.jl | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index cceb3355..990e59f1 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -47,40 +47,27 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10" - if :EnzymeReverse in broken - @test( - collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, - atol = atol - ) - @test_broken( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol - ) - elseif :EnzymeForward in broken + if AD == "All" || AD == "Enzyme" + forward_broken = :EnzymeForward in broken || :Enzyme in broken || VERSION <= v"1.6" + reverse_broken = :EnzymeReverse in broken || :Enzyme in broken + if forward_broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, rtol = rtol, atol = atol ) + else @test( - Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol - ) - elseif :Enzyme in broken - @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, rtol = rtol, atol = atol ) + end + if reverse_broken @test_broken( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol ) else - @test( - collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, - rtol = rtol, - atol = atol - ) @test( Enzyme.gradient(Enzyme.Reverse, f, x) ≈ finitediff, rtol = rtol, atol = atol ) From 41d643c14a0a7cd7acdc32c6d58e9d4fb0f1cd2f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Jul 2024 14:25:44 +0100 Subject: [PATCH 17/20] Don't check Enzyme AD for Julia < v1.10 --- test/ad/utils.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 990e59f1..6c0f577d 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -48,8 +48,12 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end if AD == "All" || AD == "Enzyme" - forward_broken = :EnzymeForward in broken || :Enzyme in broken || VERSION <= v"1.6" - reverse_broken = :EnzymeReverse in broken || :Enzyme in broken + # TODO(mhauru) The version bounds should be relaxed once some Enzyme issues get + # sorted out. I think forward mode will remain broken for versions <= 1.6 due to + # some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and + # discussion in https://github.com/TuringLang/Bijectors.jl/pull/318. + forward_broken = :EnzymeForward in broken || :Enzyme in broken || VERSION < v"1.10" + reverse_broken = :EnzymeReverse in broken || :Enzyme in broken || VERSION < v"1.10" if forward_broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff, From 5124bd70817d35118f04d8d7721be5a3c0f6e34e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Jul 2024 14:28:51 +0100 Subject: [PATCH 18/20] Reenable CI for Julia 1.6 --- .github/workflows/AD.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/AD.yml b/.github/workflows/AD.yml index a61833b5..47ef8549 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/AD.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: version: - #- '1.6' # TODO(mhauru) Disabled temporarily for Enzyme testing + - '1.6' - '1' os: - ubuntu-latest @@ -30,6 +30,8 @@ jobs: exclude: - version: 1.6 AD: Tapir + # TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see + # discussion in https://github.com/TuringLang/Bijectors.jl/pull. - version: 1.6 AD: Enzyme steps: From 73711d66b537aacf4fb87d7c0b8ac51577f19371 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Jul 2024 14:29:38 +0100 Subject: [PATCH 19/20] Misc tiny typos --- src/interface.jl | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index fa4363ff..d600cd6c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -93,7 +93,7 @@ function transform(t::Transform, x) res = with_logabsdet_jacobian(t, x) if res isa ChangesOfVariables.NoLogAbsDetJacobian error( - "`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.", + "`transform` not implemented for $(typeof(t)); implement `transform` and/or `with_logabsdet_jacobian`.", ) end diff --git a/test/Project.toml b/test/Project.toml index dea357c3..d4d9a3df 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,8 +16,8 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 04d98e9ccc664b4ab390273156ca23096fcdceab Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 11 Jul 2024 14:49:31 +0100 Subject: [PATCH 20/20] Don't run Enzyme at all for Julia < v1.10 --- test/ad/utils.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6c0f577d..3e21e693 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -47,13 +47,13 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) end end - if AD == "All" || AD == "Enzyme" - # TODO(mhauru) The version bounds should be relaxed once some Enzyme issues get - # sorted out. I think forward mode will remain broken for versions <= 1.6 due to - # some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and - # discussion in https://github.com/TuringLang/Bijectors.jl/pull/318. - forward_broken = :EnzymeForward in broken || :Enzyme in broken || VERSION < v"1.10" - reverse_broken = :EnzymeReverse in broken || :Enzyme in broken || VERSION < v"1.10" + # TODO(mhauru) The version bound should be relaxed once some Enzyme issues get + # sorted out. I think forward mode will remain broken for versions <= 1.6 due to + # some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and + # discussion in https://github.com/TuringLang/Bijectors.jl/pull/318. + if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10" + forward_broken = :EnzymeForward in broken || :Enzyme in broken + reverse_broken = :EnzymeReverse in broken || :Enzyme in broken if forward_broken @test_broken( collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) ≈ finitediff,