
Efficient way to compute Jacobian in nested AD #963

Closed
facusapienza21 opened this issue Oct 1, 2024 · 7 comments

Comments

@facusapienza21
Contributor

Hi! I was looking at the example in the docs about how to perform nested AD with Lux. The code in the documentation definitely works, and I have included a full example on Discourse for completeness. However, I noticed that when we evaluate the Jacobian, it comes back full of zeros, since outputs of different inputs are not algebraically related (that is, there is no need to compute the full matrix).

For the following piece of code,

function loss(model, ps, st)

    # Compute predictions using model parameters
    # (predict, Xₙ, and U are defined elsewhere in the full example)
    X_pred = predict(ps)
    loss_emp = mean(abs2, Xₙ .- X_pred)

    # Wrap the network U in a stateful layer so it can be called as x -> U(x)
    smodel = StatefulLuxLayer{true}(U, ps, st)

    # Full (dense) Jacobian of the network outputs w.r.t. the inputs
    J = ForwardDiff.jacobian(smodel, Xₙ)
    loss_reg = 0.01f0 * abs2(norm(J))
    return loss_emp + loss_reg
end

this is what J looks like:
[screenshot of a mostly zero, block-diagonal Jacobian]

I think this is not very efficient, but maybe I am also missing something about how Lux internally manages these calculations. I tried computing the Jacobian/gradient for each individual input value, and that seems to be very inefficient. Other options off the top of my head include

  • Computing a VJP of this J times a vector of ones instead, to avoid calculating the zero entries. I think this is what Lux.jacobian_vector_product is doing? (see the sketch after this list)
  • Adding some sparsity pattern to the Jacobian calculation
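
For reference, this is roughly what I mean by a single JVP; a minimal sketch assuming the smodel and Xₙ from the snippet above (as far as I can tell, Lux's jacobian_vector_product only supports the AutoForwardDiff backend):

using Lux, ADTypes

# Compute J * u for a seed vector u without ever materializing J
u  = ones(Float32, size(Xₙ))
Ju = Lux.jacobian_vector_product(smodel, AutoForwardDiff(), Xₙ, u)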

Any suggestion here? I just would like to see this example using the best practices when using Lux.

Thank you so much! All this looks amazing.

@avik-pal
Member

avik-pal commented Oct 1, 2024

You can use batched_jacobian for this.

@avik-pal
Member

avik-pal commented Oct 1, 2024

The only thing to be aware of is that it gives you a 3D array (essentially a uniform block-diagonal matrix without storing the zeros).
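
As a minimal sketch of the call (my illustration, assuming the smodel and Xₙ from the original post, with the last dimension of Xₙ as the batch dimension):

using Lux, ADTypes

# One (out × in) Jacobian block per sample: returns an array of size
# (outdim, indim, batchsize) instead of a mostly-zero dense matrix
J = Lux.batched_jacobian(smodel, AutoForwardDiff(), Xₙ)
loss_reg = 0.01f0 * sum(abs2, J)  # same value as abs2(norm(J)) on the dense Jacobian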

@facusapienza21
Contributor Author

Lovely, just like this I guess?
[screenshot of the resulting 3D batched Jacobian]

Thank you @avik-pal! Would it make sense to update the documentation with this? Happy to open a PR.

@avik-pal
Member

avik-pal commented Oct 1, 2024

yes, add a section after the full jacobian

@facusapienza21
Contributor Author

Even with this improvement in how the calculation of the Jacobian is batched, I still observe that in this example the training of the UDE suffers drastically in running time due to the regularization term. I was expecting the adjoint method on the numerical solver to be more expensive than the computation of the Jacobian of the NN wrt the input layer for <100 input values (+ differentiation during the reverse pass). Does this make sense to you @avik-pal? Happy to provide the example.

@avik-pal
Member

avik-pal commented Oct 2, 2024

I was expecting the adjoint method on the numerical solver to be more expensive than the computation of the Jacobian of the NN wrt the input layer for <100 input values (+differentiation during the reverse pass).

That is not generally true. Just computing the Jacobian will take ~ (100 / (batch_size * chunksize)) JVPs, and differentiating that would take the same number of JVPs over VJPs. The adjoint method is roughly bottlenecked by N VJPs, where N is the number of (backward) timesteps. Unless the ODE is extremely expensive to solve (and similarly the adjoint is expensive), the former would be more expensive.

Generally, for regularization, I would recommend using some form of stochastic approximator built from JVPs or VJPs.
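
For illustration (my sketch, not code from this thread): with a Rademacher vector v, E[‖J v‖²] = tr(JᵀJ) = ‖J‖²_F, so a single JVP per loss evaluation gives an unbiased estimate of the Frobenius-norm regularizer. Assuming the smodel and Xₙ from the original post:

using Lux, ADTypes, Random

# Hutchinson-style estimator: E[‖J v‖²] = ‖J‖²_F whenever E[v vᵀ] = I
v  = rand(Random.default_rng(), (-1.0f0, 1.0f0), size(Xₙ))  # Rademacher seed
Jv = Lux.jacobian_vector_product(smodel, AutoForwardDiff(), Xₙ, v)
loss_reg_est = 0.01f0 * sum(abs2, Jv)  # unbiased estimate of 0.01f0 * abs2(norm(J))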

@avik-pal
Member

avik-pal commented Oct 7, 2024

Since there is nothing actionable here, I am closing this. Feel free to post any questions at https://github.com/orgs/LuxDL/discussions, and I can take a look at them.

@avik-pal avik-pal closed this as completed Oct 7, 2024