InexactError when using complex matrices #928

Closed
seadra opened this issue May 31, 2024 · 4 comments · Fixed by SciML/SciMLSensitivity.jl#1064
Labels
bug Something isn't working

Comments

seadra commented May 31, 2024

The following code fails with an error:

using DifferentialEquations, DiffEqFlux, Zygote, SciMLSensitivity, Optimization, OptimizationFlux, OptimizationOptimJL, ComponentArrays, Lux, Random, LinearAlgebra

const T = 10.0;
const ω = π/T;

const id = Matrix{Complex{Float64}}(I,2, 2);
const u0 = id;


const utarget = Matrix{Complex{Float64}}([im 0; 0 -im]);

ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);

function f_nn(u, p, t)
    local a, _ = ann([t/T],p,st);
    local A = [sin(a[1]) 0.0; 0.0 -a[1]];
    return -(im*A)*u;
end

tspan = (0.0, T)
prob_ode = ODEProblem(f_nn, u0, tspan, ComponentArray(ip));


function loss_adjoint(p)
    local prediction = solve(prob_ode, BS5(), p=p, abstol=1e-7, reltol=1e-7)
    local usol = last(prediction)
    local loss = abs(1.0 - abs(tr(usol*utarget')/2))
    return loss
end

opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray(ip));
optimized_sol_nn = Optimization.solve(opt_prob, AMSGrad(0.001), maxiters = 100, progress=true);

The error is

┌ Warning: Potential performance improvement omitted. ZygoteVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bspwn/src/concrete_solve.jl:100
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bspwn/src/concrete_solve.jl:117

┌ Warning: Potential performance improvement omitted. TrackerVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bspwn/src/concrete_solve.jl:135
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bspwn/src/concrete_solve.jl:145


InexactError: Float32(0.23771682913973122 - 0.0036578124321906547im)

I narrowed the problem down to the line `local A = [sin(a[1]) 0.0; 0.0 -a[1]];`: if we change it to `[a[1] 0.0; 0.0 -a[1]]` (as it was in this issue), it works.
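
For reference, the failure can be reproduced in isolation; it is just a complex value being forced into a real element type. A minimal sketch at the REPL:

# A complex number with a nonzero imaginary part cannot be converted
# to a real type; this is exactly the InexactError reported above:
julia> Float32(0.23771682913973122 - 0.0036578124321906547im)
ERROR: InexactError: Float32(0.23771682913973122 - 0.0036578124321906547im)

# The conversion succeeds only when the imaginary part is exactly zero:
julia> Float32(0.5 + 0.0im)
0.5f0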

seadra added the bug label May 31, 2024

seadra commented May 31, 2024

A bit tangential, but regarding those warnings about potential performance improvements: is there a way I can get those performance improvements for such a loss function?
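
To illustrate what I mean, would explicitly requesting one of the reverse-mode VJPs look something like this? (Just a sketch; I'm assuming `InterpolatingAdjoint` with `ZygoteVJP` is the relevant knob here.)

# Sketch: pass an explicit sensealg instead of relying on the automated
# AD choice, which currently falls back to numerical VJPs.
function loss_adjoint(p)
    local prediction = solve(prob_ode, BS5(), p=p, abstol=1e-7, reltol=1e-7,
                             sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))
    local usol = last(prediction)
    return abs(1.0 - abs(tr(usol*utarget')/2))
end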

ChrisRackauckas added a commit to SciML/SciMLSensitivity.jl that referenced this issue Jun 6, 2024
Last little bit to fix SciML/DiffEqFlux.jl#928 and make that nicer
ChrisRackauckas (Member) commented

Your code doesn't use DiffEqFlux, so just remove it from the `using` statement, and please submit this kind of issue to SciMLSensitivity in the future.

With SciML/SciMLSensitivity.jl#1064 your code is much faster. You do require complex coefficients in this case, so with that PR and the new patch, the following works really well:

using OrdinaryDiffEq, Zygote, SciMLSensitivity, Optimization, OptimizationOptimisers,
      ComponentArrays, Lux, Random, LinearAlgebra

const T = 10.0;
const ω = π/T;
const id = Matrix{Complex{Float64}}(I,2, 2);
const u0 = id;
const utarget = Matrix{Complex{Float64}}([im 0; 0 -im]);

ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);

function f_nn(u, p, t)
    local a, _ = ann([t/T],p,st);
    local A = [sin(a[1]) 0.0; 0.0 -a[1]];
    return -(im*A)*u;
end

tspan = (0.0, T)
# Key change: store the parameters with a complex element type
prob_ode = ODEProblem(f_nn, u0, tspan, ComponentArray{ComplexF32}(ip));

function loss_adjoint(p)
    local prediction = solve(prob_ode, BS5(), p=p, abstol=1e-7, reltol=1e-7)
    local usol = last(prediction)
    local loss = abs(1.0 - abs(tr(usol*utarget')/2))
    return loss
end

opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray{ComplexF32}(ip));
optimized_sol_nn = Optimization.solve(opt_prob, Adam(0.001), maxiters = 100, progress=true);


seadra commented Jun 11, 2024

Thank you very much for the great progress on supporting complex matrices.

I'm a bit lost on this comment that you made:

You do require complex coefficients in this case.

But I need the coefficients to be real: if `a[1]` becomes complex, the equation corresponds to something unphysical.
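
One pattern I could imagine (just a sketch, not tested): if the adjoint needs complex storage for `p`, project the network output back onto the reals before building `A`, so the generator stays physical:

# Hypothetical sketch: real(a[1]) discards any imaginary component, so the
# matrix A stays real even if the stored parameters are complex.
function f_nn(u, p, t)
    local a, _ = ann([t/T], p, st)
    local ar = real(a[1])
    local A = [sin(ar) 0.0; 0.0 -ar]
    return -(im*A)*u
end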


seadra commented Jun 11, 2024

Using real coefficients used to work with DiffEqSensitivity.

I just tried the following:

using OrdinaryDiffEq, Zygote, SciMLSensitivity, Optimization, OptimizationOptimisers,
      ComponentArrays, Lux, Random, LinearAlgebra

const T = 10.0;
const ω = π/T;
const id = Matrix{Complex{Float64}}(I,2, 2);
const u0 = id;
const utarget = Matrix{Complex{Float64}}([im 0; 0 -im]);

ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);

function f_nn(u, p, t)
    local a, _ = ann([t/T],p,st);
    local A = [sin(a[1]) 0.0; 0.0 -a[1]];
    return -(im*A)*u;
end

tspan = (0.0, T)
# Real Float32 parameters instead of the ComplexF32 used above
prob_ode = ODEProblem(f_nn, u0, tspan, ComponentArray{Float32}(ip));

function loss_adjoint(p)
    local prediction = solve(prob_ode, BS5(), p=p, abstol=1e-7, reltol=1e-7)
    local usol = last(prediction)
    local loss = abs(1.0 - abs(tr(usol*utarget')/2))
    return loss
end

opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray{Float32}(ip));
optimized_sol_nn = Optimization.solve(opt_prob, Adam(0.001), maxiters = 100, progress=true);

and it seems to have worked.

Is there something that I'm missing?
