docs: try fixing nested autodiff
avik-pal committed Sep 6, 2024
1 parent e395ed9 commit 52c8880
Showing 3 changed files with 35 additions and 58 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -21,6 +21,7 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
28 changes: 14 additions & 14 deletions docs/src/manual/nested_autodiff.md
@@ -22,7 +22,7 @@ Let's explore this using some questions that were posted on the
[Julia Discourse forum](https://discourse.julialang.org/).

```@example nested_ad
- using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random
+ using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random, StableRNGs
using ComponentArrays, FiniteDiff
```

@@ -70,15 +70,15 @@ function loss_function1(model, x, ps, st, y)
loss_emp = sum(abs2, ŷ .- y)
# You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
J = ForwardDiff.jacobian(smodel, x)
- loss_reg = abs2(norm(J))
+ loss_reg = abs2(norm(J .* 0.01f0))
return loss_emp + loss_reg
end
# Using Batchnorm to show that it is possible
model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
- ps, st = Lux.setup(Xoshiro(0), model)
- x = rand(Xoshiro(0), Float32, 2, 10)
- y = rand(Xoshiro(11), Float32, 2, 10)
+ ps, st = Lux.setup(StableRNG(0), model)
+ x = randn(StableRNG(0), Float32, 2, 10)
+ y = randn(StableRNG(11), Float32, 2, 10)
loss_function1(model, x, ps, st, y)
```
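The outer differentiation call sits in a part of the example this hunk does not show. As a rough sketch of the pattern, assuming the `loss_function1`, `model`, `x`, `ps`, `st`, and `y` defined above, the outer pass is a plain Zygote gradient over a loss that internally calls `ForwardDiff.jacobian`:

```julia
# Outer reverse-mode pass; Zygote returns one gradient per positional argument,
# and only the input and parameter gradients are of interest here.
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)
```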
@@ -97,9 +97,9 @@ Now let's verify the gradients using finite differences:
ComponentArray(ps))
println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
- @assert norm(∂x .- ∂x_fd, Inf) < 1e-1 # hide
+ @assert norm(∂x .- ∂x_fd, Inf) < 1e-2 # hide
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
- @assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-1 # hide
+ @assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-2 # hide
nothing; # hide
```
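The finite-difference baselines `∂x_fd` and `∂ps_fd` are computed in the elided part of the example; the visible `ComponentArray(ps))` fragment above is the tail of such a call. A minimal sketch with FiniteDiff, assuming the same definitions as above:

```julia
# Finite-difference reference gradients for the same loss; the parameters are
# flattened into a ComponentArray so the result is one flat vector.
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function1(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function1(model, x, ps, st, y),
    ComponentArray(ps))
```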

@@ -123,8 +123,8 @@ end
model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
Dense(12 => 1))
- ps, st = Lux.setup(Xoshiro(0), model)
- t = rand(Xoshiro(0), Float32, 1, 16)
+ ps, st = Lux.setup(StableRNG(0), model)
+ t = rand(StableRNG(0), Float32, 1, 16)
```

Now the moment of truth:
@@ -164,9 +164,9 @@ end
model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
Dense(12 => 1))
- ps, st = Lux.setup(Xoshiro(0), model)
+ ps, st = Lux.setup(StableRNG(0), model)
ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions
- x = rand(Xoshiro(0), Float32, 1, 16)
+ x = rand(StableRNG(0), Float32, 1, 16)
```

We can as usual compute the gradient/jacobian of the loss function:
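The loss definition this refers to is elided from the hunk. Purely as an illustration of why `ps` is flattened into a `ComponentArray` above, and not the manual's actual code, a Jacobian with respect to the parameters can be taken through a `StatefulLuxLayer` by fixing the input:

```julia
smodel = StatefulLuxLayer{true}(model, ps, st)
# Base.Fix1(smodel, x) is the closure ps -> smodel(x, ps), so this is the
# Jacobian of the model output with respect to the flattened parameter vector.
J_ps = only(Zygote.jacobian(Base.Fix1(smodel, x), ps))
```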
@@ -260,9 +260,9 @@ Now let's compute the trace and compare the results:
```@example nested_ad
model = Chain(Dense(4 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
Dense(12 => 4))
- ps, st = Lux.setup(Xoshiro(0), model)
- x = rand(Xoshiro(0), Float32, 4, 12)
- v = (rand(Xoshiro(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample
+ ps, st = Lux.setup(StableRNG(0), model)
+ x = rand(StableRNG(0), Float32, 4, 12)
+ v = (rand(StableRNG(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample
nothing; # hide
```
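Here `v` is the probe vector for a Hutchinson-style estimate, tr(J) ≈ E[vᵀ J v] for Rademacher-distributed `v`. The estimator built from nested-AD primitives lives in the elided part of the example; a brute-force reference that simply materializes the Jacobian (fine at this size, assuming the definitions above and `LinearAlgebra` for `tr`) looks like:

```julia
smodel = StatefulLuxLayer{true}(model, ps, st)
J = ForwardDiff.jacobian(smodel, x)       # full 48×48 Jacobian of vec(output) w.r.t. vec(x)
exact_trace = tr(J)                       # exact trace, for comparison
hutchinson_est = vec(v)' * J * vec(v)     # one-sample Hutchinson estimate (noisy; average many probes in practice)
```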

64 changes: 20 additions & 44 deletions test/autodiff/nested_autodiff_tests.jl
@@ -27,28 +27,15 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m
!iszero(ComponentArray(∂ps |> cpu_device())) &&
all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device()))

- __f = (x, ps) -> loss_fn(model, x, ps, st)
-
allow_unstable() do
- FDIFF_WORKS = try
- LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps)
- true
- catch
- false
- end
- skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()]
- if FDIFF_WORKS
- test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
- rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends)
- else
- test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
- rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()]))
- end
+ test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps;
+ atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()],
+ skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
end
end

- const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4),
- randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4))
+ const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2),
+ randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2))

const models = (
Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4),
@@ -63,25 +50,25 @@ const models = (
# smodel | ForwardDiff.jacobian
function loss_function1(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, ForwardDiff.jacobian(smodel, x))
+ return sum(abs2, ForwardDiff.jacobian(smodel, x) .* 0.01f0)
end

# smodel | Zygote.jacobian
function loss_function2(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, only(Zygote.jacobian(smodel, x)))
+ return sum(abs2, only(Zygote.jacobian(smodel, x)) .* 0.01f0)
end

# sum(abs2) ∘ smodel | ForwardDiff.gradient
function loss_function3(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x))
+ return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x) .* 0.01f0)
end

# sum(abs2) ∘ smodel | Zygote.gradient
function loss_function4(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)))
+ return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)) .* 0.01f0)
end

const ALL_TEST_CONFIGS = Iterators.product(
@@ -165,28 +152,15 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn,
!iszero(ComponentArray(∂ps |> cpu_device())) &&
all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device()))

- __f = (x, ps) -> loss_fn(model, x, ps, st)
-
allow_unstable() do
- FDIFF_WORKS = try
- LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps)
- true
- catch
- false
- end
- skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()]
- if FDIFF_WORKS
- test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
- rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends)
- else
- test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3,
- rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()]))
- end
+ test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps;
+ atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()],
+ skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
end
end

- const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4),
- randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4))
+ const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2),
+ randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2))

const models = (
Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4),
@@ -201,25 +175,27 @@ const models = (
# smodel | ForwardDiff.jacobian
function loss_function1(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps))
+ return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps) .* 0.01f0)
end

# smodel | Zygote.jacobian
function loss_function2(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps)))
+ return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) .* 0.01f0)
end

# sum(abs2) ∘ smodel | ForwardDiff.gradient
function loss_function3(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps))
+ return sum(abs2,
+ ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps) .* 0.01f0)
end

# sum(abs2) ∘ smodel | Zygote.gradient
function loss_function4(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
- return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)))
+ return sum(abs2,
+ only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)) .* 0.01f0)
end

const ALL_TEST_CONFIGS = Iterators.product(
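A reading note on the pattern used throughout these test loss functions (not part of the diff): `Base.Fix1(sum, abs2)` is the callable `y -> sum(abs2, y)`, and `Base.Fix1(smodel, x)` fixes the input so the closure maps parameters to outputs, `ps -> smodel(x, ps)`. Composed with `∘`, they give a scalar-valued function of `ps` for the inner ForwardDiff/Zygote call to differentiate. Spelled out:

```julia
# Equivalent closures, written out explicitly (illustration only; assumes an
# `smodel::StatefulLuxLayer` and an input `x` as in the tests above).
f1 = Base.Fix1(sum, abs2)      # y  -> sum(abs2, y)
f2 = Base.Fix1(smodel, x)      # ps -> smodel(x, ps)
g  = f1 ∘ f2                   # ps -> sum(abs2, smodel(x, ps))
```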
