From 52c8880c76949e347f7cbfc84fda35f7d2cc1079 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 6 Sep 2024 18:38:45 -0400
Subject: [PATCH] docs: try fixing nested autodiff

---
 docs/Project.toml                      |  1 +
 docs/src/manual/nested_autodiff.md     | 28 +++++------
 test/autodiff/nested_autodiff_tests.jl | 64 ++++++++------------------
 3 files changed, 35 insertions(+), 58 deletions(-)

diff --git a/docs/Project.toml b/docs/Project.toml
index 2126c3122..85ac205ae 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -21,6 +21,7 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md
index 0a5e074a4..497179c11 100644
--- a/docs/src/manual/nested_autodiff.md
+++ b/docs/src/manual/nested_autodiff.md
@@ -22,7 +22,7 @@ Let's explore this using some questions that were posted on the
 [Julia Discourse forum](https://discourse.julialang.org/).
 
 ```@example nested_ad
-using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random
+using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random, StableRNGs
 using ComponentArrays, FiniteDiff
 ```
 
@@ -70,15 +70,15 @@ function loss_function1(model, x, ps, st, y)
     loss_emp = sum(abs2, ŷ .- y)
     # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
     J = ForwardDiff.jacobian(smodel, x)
-    loss_reg = abs2(norm(J))
+    loss_reg = abs2(norm(J .* 0.01f0))
     return loss_emp + loss_reg
 end
 
 # Using Batchnorm to show that it is possible
 model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
-ps, st = Lux.setup(Xoshiro(0), model)
-x = rand(Xoshiro(0), Float32, 2, 10)
-y = rand(Xoshiro(11), Float32, 2, 10)
+ps, st = Lux.setup(StableRNG(0), model)
+x = randn(StableRNG(0), Float32, 2, 10)
+y = randn(StableRNG(11), Float32, 2, 10)
 loss_function1(model, x, ps, st, y)
 ```
 
@@ -97,9 +97,9 @@ Now let's verify the gradients using finite differences:
     ComponentArray(ps))
 
 println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
-@assert norm(∂x .- ∂x_fd, Inf) < 1e-1 # hide
+@assert norm(∂x .- ∂x_fd, Inf) < 1e-2 # hide
 println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
-@assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-1 # hide
+@assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-2 # hide
 nothing; # hide
 ```
 
@@ -123,8 +123,8 @@ end
 
 model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
     Dense(12 => 1))
-ps, st = Lux.setup(Xoshiro(0), model)
-t = rand(Xoshiro(0), Float32, 1, 16)
+ps, st = Lux.setup(StableRNG(0), model)
+t = rand(StableRNG(0), Float32, 1, 16)
 ```
 
 Now the moment of truth:
@@ -164,9 +164,9 @@ end
 
 model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
     Dense(12 => 1))
-ps, st = Lux.setup(Xoshiro(0), model)
+ps, st = Lux.setup(StableRNG(0), model)
 ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions
-x = rand(Xoshiro(0), Float32, 1, 16)
+x = rand(StableRNG(0), Float32, 1, 16)
 ```
 
 We can as usual compute the gradient/jacobian of the loss function:
@@ -260,9 +260,9 @@ Now let's compute the trace and compare the results:
 ```@example nested_ad
 model = Chain(Dense(4 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
     Dense(12 => 4))
-ps, st = Lux.setup(Xoshiro(0), model)
-x = rand(Xoshiro(0), Float32, 4, 12)
-v = (rand(Xoshiro(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample
+ps, st = Lux.setup(StableRNG(0), model)
+x = rand(StableRNG(0), Float32, 4, 12)
+v = (rand(StableRNG(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample
 nothing; # hide
 ```
 
diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl
index 834400676..850c31387 100644
--- a/test/autodiff/nested_autodiff_tests.jl
+++ b/test/autodiff/nested_autodiff_tests.jl
@@ -27,28 +27,15 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m
         !iszero(ComponentArray(∂ps |> cpu_device())) &&
         all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device()))
 
-    __f = (x, ps) -> loss_fn(model, x, ps, st)
-
     allow_unstable() do
-        FDIFF_WORKS = try
-            LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps)
-            true
-        catch
-            false
-        end
-        skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()]
-        if FDIFF_WORKS
-            test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
-                rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends)
-        else
-            test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
-                rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()]))
-        end
+        test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps;
+            atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()],
+            skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
     end
 end
 
-const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4),
-    randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4))
+const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2),
+    randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2))
 
 const models = (
     Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4),
@@ -63,25 +50,25 @@ const models = (
 # smodel | ForwardDiff.jacobian
 function loss_function1(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, ForwardDiff.jacobian(smodel, x))
+    return sum(abs2, ForwardDiff.jacobian(smodel, x) .* 0.01f0)
 end
 
 # smodel | Zygote.jacobian
 function loss_function2(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, only(Zygote.jacobian(smodel, x)))
+    return sum(abs2, only(Zygote.jacobian(smodel, x)) .* 0.01f0)
 end
 
 # sum(abs2) ∘ smodel | ForwardDiff.gradient
 function loss_function3(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x))
+    return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x) .* 0.01f0)
 end
 
 # sum(abs2) ∘ smodel | Zygote.gradient
 function loss_function4(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)))
+    return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)) .* 0.01f0)
 end
 
 const ALL_TEST_CONFIGS = Iterators.product(
@@ -165,28 +152,15 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn,
         !iszero(ComponentArray(∂ps |> cpu_device())) &&
         all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device()))
 
-    __f = (x, ps) -> loss_fn(model, x, ps, st)
-
     allow_unstable() do
-        FDIFF_WORKS = try
-            LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps)
-            true
-        catch
-            false
-        end
-        skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()]
-        if FDIFF_WORKS
-            test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
-                rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends)
-        else
-            test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
-                rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()]))
-        end
+        test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps;
+            atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()],
+            skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
     end
 end
 
-const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4),
-    randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4))
+const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2),
+    randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2))
 
 const models = (
     Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4),
@@ -201,25 +175,27 @@ const models = (
 # smodel | ForwardDiff.jacobian
 function loss_function1(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps))
+    return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps) .* 0.01f0)
 end
 
 # smodel | Zygote.jacobian
 function loss_function2(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps)))
+    return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) .* 0.01f0)
 end
 
 # sum(abs2) ∘ smodel | ForwardDiff.gradient
 function loss_function3(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps))
+    return sum(abs2,
+        ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps) .* 0.01f0)
 end
 
 # sum(abs2) ∘ smodel | Zygote.gradient
 function loss_function4(model, x, ps, st)
     smodel = StatefulLuxLayer{true}(model, ps, st)
-    return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)))
+    return sum(abs2,
+        only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)) .* 0.01f0)
 end
 
 const ALL_TEST_CONFIGS = Iterators.product(