
Lux support for complex differentiation #977

Open
facusapienza21 opened this issue Oct 11, 2024 · 1 comment

Comments

@facusapienza21
Contributor

Hi! While assessing different modes of nested differentiation, I am interested in combining complex-step differentiation with reverse-mode AD. Below is an example of a small neural network to which I pass complex numbers; computing the gradient in reverse mode (with Zygote) throws an error.

I am posting this here because it used to work with previous versions of Lux, so I am wondering whether recent changes broke this kind of use.
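For context, the complex-step trick relies on the Taylor expansion of a real-analytic f along the imaginary axis. Taking the imaginary part involves no subtraction of nearby values, so ϵ can be very small (1e-5 in the code below) without cancellation error. The price is that every operation inside f, here the whole Lux network, must accept complex inputs, and so must any reverse-mode pass through it.

```latex
f(x + i\epsilon) = f(x) + i\epsilon f'(x) - \frac{\epsilon^2}{2} f''(x) - \frac{i\epsilon^3}{6} f'''(x) + O(\epsilon^4)

\frac{\operatorname{Im} f(x + i\epsilon)}{\epsilon} = f'(x) - \frac{\epsilon^2}{6} f'''(x) + O(\epsilon^4)
```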

```julia
using Lux, Zygote
using Random
rng = Random.default_rng()
Random.seed!(rng, 666)

# Gaussian radial basis activation exp(-x²)
rbf(x) = exp.(-(x .^ 2))

U = Lux.Chain(
    Lux.Dense(1, 10, rbf),
    Lux.Dense(10, 3, rbf)
)

θ, st = Lux.setup(rng, U)

@show U([1.0], θ, st)[begin]

# Complex-step derivative: f'(x) ≈ imag(f(x + iϵ)) / ϵ
function complex_step_differentiation(f::Function, x::Float64, ϵ::Float64)
    return imag(f(x + ϵ * im)) / ϵ
end

@show complex_step_differentiation(t -> U([t], θ, st)[begin], 1.0, 1e-5)

loss(t) = complex_step_differentiation(τ -> U([τ], θ, st)[begin], t, 1e-5)

# Reverse-mode gradient through the complex-step derivative (this is the call that errors)
Zygote.gradient(t -> loss(t), 1.0)
```
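As a sanity check independent of Lux and Zygote, the complex-step helper from the snippet can be applied to the scalar `rbf` activation alone and compared with the analytic derivative −2x·exp(−x²) and with a plain forward difference. This is a minimal sketch; `drbf` is a hypothetical helper introduced only for the comparison.

```julia
# Same activation as above; for a scalar input the broadcasts simply return a scalar
rbf(x) = exp.(-(x .^ 2))

# Hypothetical helper: analytic derivative of exp(-x^2), used as the reference value
drbf(x) = -2x * exp(-x^2)

# Complex-step derivative, as in the issue
function complex_step_differentiation(f::Function, x::Float64, ϵ::Float64)
    return imag(f(x + ϵ * im)) / ϵ
end

x = 1.0
cs = complex_step_differentiation(rbf, x, 1e-5)
fd = (rbf(x + 1e-5) - rbf(x)) / 1e-5  # forward difference with the same step, for comparison
println((cs, drbf(x), fd))
```

The complex-step value agrees with the analytic derivative to roughly ϵ² relative error, while the forward difference is only accurate to roughly ϵ.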


@avik-pal, this is essentially what I was using before the new nested AD feature in Lux. I would like all the differentiation modes to work with the same architecture and Lux version so that I can compare them. Happy to provide more information!

@avik-pal
Member

This is Zygote trying to differentiate the broadcast over complex numbers using ForwardDiff, which won't work nicely. Can you add an `rrule` for `rbf`? That will hit a more optimized path that doesn't use ForwardDiff.
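A minimal sketch of such a rule, assuming ChainRulesCore is available and that the rule is needed for the scalar case (Lux's `Dense` applies its activation elementwise). For a holomorphic function, the ChainRules convention is that the pullback returns Δ·conj(f′(x)); here f′(x) = −2x·exp(−x²). This is an illustration of the suggestion, not a confirmed fix.

```julia
using ChainRulesCore

# Activation from the issue; for a scalar input the broadcasts return a scalar
rbf(x) = exp.(-(x .^ 2))

# Hand-written reverse rule for scalar (real or complex) inputs
function ChainRulesCore.rrule(::typeof(rbf), x::Number)
    y = rbf(x)
    # d/dx exp(-x^2) = -2x * exp(-x^2) = -2x * y; conj follows the holomorphic convention
    rbf_pullback(Δ) = (NoTangent(), unthunk(Δ) * conj(-2 * x * y))
    return y, rbf_pullback
end
```

Whether defining this rule is enough to route `Dense`'s activation away from the ForwardDiff broadcast path depends on how the installed Lux/LuxLib version detects custom rules, and would need checking.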
