Differentiating Zygote.pullback
#621
It works with

Pullbacks are hard to differentiate directly; see #610 (comment). We just need some rrules for DifferentiationInterface.pullback. Zygote.pullback is almost never going to work, unless someone finds a nice trick to write the tangent for the pullback function. DI.pullback, on the other hand, is quite simple, and DEQs.jl already does that.

Zygote.pullback
I'm still getting the error with:

(br-3) pkg> st
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
[b0b7db55] ComponentArrays v0.15.11
[a0c0ee7d] DifferentiationInterface v0.3.3
[f6369f11] ForwardDiff v0.10.36
[b2108857] Lux v0.5.42
[e88e6eb3] Zygote v0.6.69
[9a3f8284] Random
(br-3) pkg> st --outdated
Status `D:\Codes\Mine\bug-report\br-3\Project.toml`
(br-3) pkg> st --outdated -m
Status `D:\Codes\Mine\bug-report\br-3\Manifest.toml`

Even after using DI:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)
function fn1(u, p)
z, uJ = DifferentiationInterface.value_and_pullback(x -> snn(x, p), AutoZygote(), u, u)
sum(uJ) + sum(z)
end
fn1(r, ps)
DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)

Use Lux.vector_jacobian_product.

I tried:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
snn = StatefulLuxLayer(nn, ps, st)
function fn1(u, p)
z, uJ = Lux.vector_jacobian_product(x -> snn(x, p), AutoZygote(), u, u)
sum(uJ) + sum(z)
end
fn1(r, ps)
# DifferentiationInterface.gradient(x -> fn1(r, x), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps)

But the problem is still there. 😬

Closures don't work; see the first part of https://lux.csail.mit.edu/stable/manual/nested_autodiff. Also, Lux.vector_jacobian_product (https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.vector_jacobian_product) returns only the VJP, not the value and the VJP, so it cannot be destructured as z, uJ = .... Returning the value as well could be added later, but it would not change the code much.

Thanks, the problem is resolved:

using ComponentArrays, Lux, Random, Zygote, ForwardDiff, DifferentiationInterface
nn = Dense(2, 2, tanh)
r = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), nn)
ps = ComponentArray(ps)
function fn1(u, ps, st)
snn = StatefulLuxLayer(nn, ps, st)
z = snn(u)
uJ = Lux.vector_jacobian_product(snn, AutoZygote(), u, u)
sum(uJ) + sum(z)
end
fn1(r, ps, st)
# DifferentiationInterface.gradient(x -> fn1(r, x, st), AutoZygote(), ps)
Zygote.gradient(fn1, r, ps, st)

Capturing DI would have been the ideal solution, but it causes method ambiguities, and I would have to manually define the functions for every possible combination, which would get messy; see #600 (comment).