Capture DifferentiationInterface calls for efficient Nested AD #600
Comments
Commenting so I keep track of this. If you think something deserves to be in DI, let me know!
@gdalle what do you think about the last part in https://discourse.julialang.org/t/ann-lux-jl-explicitly-parameterized-neural-networks-in-julia/81689/65?u=avikpal? If that exists in DI, I can just unwrap `StatefulLuxLayer` into that DI struct and forward the call.
Can't we just specialize on …
I was trying that for the gradient calls, but DI specializes on the extras type, which means we would also have to specialize on each extras type for every backend.
To support second order for Enzyme, I introduced …
I'm not sure, because there are several things one might want to do with nested backends, and depending on the situation, this Lux replacement trick may not always be appropriate?
Just putting it out there in case Avik is inspired. Essentially, modifying the backend is the cleanest approach I could think of for this type of problem.
To clarify how nested AD works in Lux: it doesn't simply switch the backends, i.e. we don't take a … The only case where replacement is not ideal is … All the other forms of Zygote over ForwardDiff or Zygote over Zygote (or any reverse mode over X-mode) have no computational benefit and will error in most cases, so it does make sense to switch. Even doing an …
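As a rough illustration of why forward mode is an attractive choice for the outer pass, a Hessian-vector product H(x)v can be obtained from a single forward-mode pass over a gradient function. This is a toy sketch under stated assumptions, not how Lux implements its replacement: `Dual` is a hypothetical hand-rolled dual number standing in for ForwardDiff.jl, and `grad_f` is a hand-written gradient standing in for an inner reverse-mode gradient.

```julia
# Hypothetical minimal dual number: a value plus one directional derivative.
# (A toy stand-in for ForwardDiff.jl's dual numbers, not its API.)
struct Dual
    val::Float64
    der::Float64
end

# Arithmetic rules needed by grad_f below (sum and product rules).
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, b::Dual) = Dual(c * b.val, c * b.der)

# Hand-written gradient of f(x) = x[1]^2 * x[2] + x[2]^3, standing in for an
# inner reverse-mode gradient. It is generic, so it also accepts Dual inputs.
grad_f(x) = [2 * x[1] * x[2], x[1] * x[1] + 3 * (x[2] * x[2])]

# Hessian-vector product H(x) * v from one forward pass over the gradient:
# seed each input with its component of v and read off the derivative parts.
function hvp(g, x::Vector{Float64}, v::Vector{Float64})
    duals = [Dual(xi, vi) for (xi, vi) in zip(x, v)]
    return [d.der for d in g(duals)]
end

# H(x) = [2x2 2x1; 2x1 6x2], so at x = (1, 2) and v = (1, 0): H*v = [4, 2].
println(hvp(grad_f, [1.0, 2.0], [1.0, 0.0]))  # prints [4.0, 2.0]
```

The forward pass costs a small constant multiple of one gradient evaluation, with no second reverse sweep or tape over the inner computation.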
Oh right, my …
We capture: … after #598. We should capture the DI `jacobian`, `gradient`, and, most importantly, `pullback` calls to augment them with the faster versions.
An important question here is whether we should switch all calls or only calls with `SecondOrder`. I prefer the former, where we can just use ForwardDiff to do the AD. Maybe for `SecondOrder` we respect the user choice.
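The preferred policy (switch every captured `jacobian`, `gradient`, and `pullback` call to forward mode, but respect an explicit `SecondOrder`) can be sketched as plain Julia dispatch. This is a minimal, self-contained sketch: `ForwardBackend`, `ReverseBackend`, `SecondOrderBackend`, `replace_inner_backend`, and `rewrite_call` are hypothetical names, not the DifferentiationInterface.jl or ADTypes.jl API, and the sketch only selects a backend rather than computing derivatives.

```julia
# Hypothetical stand-ins for AD backend types (not the ADTypes.jl / DI API).
abstract type AbstractBackend end
struct ForwardBackend <: AbstractBackend end   # plays the role of, e.g., ForwardDiff
struct ReverseBackend <: AbstractBackend end   # plays the role of, e.g., Zygote
struct SecondOrderBackend{O<:AbstractBackend,I<:AbstractBackend} <: AbstractBackend
    outer::O
    inner::I
end

# Operators whose nested calls would be captured, per the issue text.
const CAPTURED_OPS = (:jacobian, :gradient, :pullback)

# Default policy: switch every captured call to forward mode ...
replace_inner_backend(::AbstractBackend) = ForwardBackend()
# ... but respect an explicit SecondOrder choice made by the user.
replace_inner_backend(b::SecondOrderBackend) = b

# Decide which backend a nested call should use; calls that are not
# captured keep the backend they were given.
function rewrite_call(op::Symbol, backend::AbstractBackend)
    op in CAPTURED_OPS || return (op, backend)
    return (op, replace_inner_backend(backend))
end

rewrite_call(:gradient, ReverseBackend())   # returns (:gradient, ForwardBackend())
```

Keying the decision on the backend type alone gives one method per case; how this would interact with DI's extras objects, which DI specializes on, is not addressed by this sketch.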