From dcb6c6d8be97a6cfee94a799e2012354d7f34700 Mon Sep 17 00:00:00 2001
From: Facundo Sapienza
Date: Tue, 1 Oct 2024 11:52:26 -0700
Subject: [PATCH] docs: added to Nested AD example how to use `batched_jacobian` (#964)

* Added to Nested AD example how to use `batched_jacobian`
* Complete example with loss function and tests
* Update docs/src/manual/nested_autodiff.md
---
 docs/src/manual/nested_autodiff.md | 47 ++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md
index 497179c11..826925e99 100644
--- a/docs/src/manual/nested_autodiff.md
+++ b/docs/src/manual/nested_autodiff.md
@@ -106,6 +106,53 @@ nothing; # hide

That's pretty good, of course you will have some error from the finite differences calculation.

### Using Batched Jacobian for Multiple Inputs

Notice that in this example the Jacobian `J` consists of the full matrix of derivatives of `smodel` with respect
to all of the inputs in `x`. In many cases we are only interested in the Jacobian with respect to each input
individually, which lets us avoid computing the cross-input entries that are zero whenever the model treats each
batch element independently. This can be achieved with [`batched_jacobian`](@ref), which computes the Jacobian
for each single input separately. Using the same example from the previous section:

```@example nested_ad
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)

function loss_function_batched(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `AutoZygote()` as well but `AutoForwardDiff()` tends to be more efficient here
    J = batched_jacobian(smodel, AutoForwardDiff(), x)
    # Regularize with the squared norm of the (scaled) per-input Jacobians
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end

loss_function_batched(model, x, ps, st, y)
```

Notice that in this last example we removed `BatchNorm()` from the neural network. This is done so that the
outputs corresponding to different inputs don't have an algebraic dependency on each other introduced by the
batch normalization, and the cross-input entries of the Jacobian are therefore exactly zero. We can now verify
the gradients against finite differences again:

```@example nested_ad
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function_batched(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function_batched(model, x, ps, st, y),
    ComponentArray(ps))

_, ∂x_b, ∂ps_b, _, _ = Zygote.gradient(loss_function_batched, model, x, ps, st, y)
println("∞-norm(∂x_b - ∂x_fd): ", norm(∂x_b .- ∂x_fd, Inf))
@assert norm(∂x_b .- ∂x_fd, Inf) < 1e-2 # hide
println("∞-norm(∂ps_b - ∂ps_fd): ", norm(ComponentArray(∂ps_b) .- ∂ps_fd, Inf))
@assert norm(ComponentArray(∂ps_b) .- ∂ps_fd, Inf) < 1e-2 # hide
```

It is important to remark that in this example `batched_jacobian` returns a 3D array, with one Jacobian slice for
each independent input value in `x`, rather than a single matrix for the whole batch.

## Loss Function contains Gradient Computation

Ok here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs