
usage of universal Differential Algebraic Equations #842

Closed
ghost opened this issue Jul 18, 2023 · 1 comment

ghost commented Jul 18, 2023

Hello,

I have been trying to create a universal Differential Algebraic Equation (I want to enforce some physical constraints).

There is a test case here; however, I cannot run it, so I tried a couple of modifications to make it work.

Here is what I got:

script 1

using DiffEqFlux, OrdinaryDiffEq, Test, Plots, Lux, StableRNGs, ComponentArrays, Sundials

using Optimization, OptimizationOptimisers, OptimizationOptimJL

function f!(du, u, p, t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] =  y₁ + y₂ + y₃ - 1
    nothing
end

u₀ = [1.0, 0, 0]

M = [1. 0  0
     0  1. 0
     0  0  0]

tspan = (0.0, 10.0)
p_true = [0.04, 3e7, 1e4]

func = ODEFunction(f!, mass_matrix=M)
prob = ODEProblem(func, u₀, tspan, p_true)
sol = solve(prob, Rodas5(), saveat=0.1, 
            abstol = 1e-9, 
            reltol = 1e-9 )

t_true = sol.t

# Neural Network
rng = StableRNG(1111);
U = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
p, st = Lux.setup(rng, U)
p = ComponentArray(p)

function ude_dynamics!(du, u, p, t, p_true)
    NN = U(u, p, st)[1]                 # neural network replaces the known ODE terms
    du[1] = NN[1]
    du[2] = NN[2]
    du[3] = u[1] + u[2] + u[3] - 1.0    # algebraic constraint (zero row in M)
    nothing
end

# Closure with the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_true);

prob_nn = ODEProblem(ODEFunction(nn_dynamics!, mass_matrix = M), u₀ , tspan, p);

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)

    Array(solve(_prob, Rodas5(), saveat = T,
                abstol = 1e-4, reltol = 1e-4, verbose=false))
end

function loss(p)
    pred = predict(p)
    loss = sum(abs2, Array(sol) .- pred)   # compare against the reference solution
    loss, pred
end

losses = Float64[];

callback = function (p, l, pred)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
        println("Current min/max after $(length(losses)) iterations: $(extrema(pred[1, :] + pred[2, :] + pred[3, :]))")
    end
    return false
end

optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optfunc, p)
res = Optimization.solve(  optprob,
                            ADAM(), 
                            callback = callback, 
                            maxiters = 100
                         )

For the uDAE case I also created a script based on the uODE Lotka-Volterra example:

script 2

using DiffEqFlux, OrdinaryDiffEq, Test, Plots, Lux, StableRNGs, ComponentArrays

using Optimization, OptimizationOptimisers, OptimizationOptimJL

function f!(du, u, p, t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] =  y₁ + y₂ + y₃ - 1
    nothing
end

u₀ = [1.0, 0, 0]

M = [1. 0  0
     0  1. 0
     0  0  0]

tspan = (0.0, 10.0)
p_true = [0.04, 3e7, 1e4]

func = ODEFunction(f!, mass_matrix=M)
prob = ODEProblem(func, u₀, tspan, p_true)
sol = solve(prob, Rodas5(), saveat=0.1, 
            abstol = 1e-9, 
            reltol = 1e-9 )

t_true = sol.t

# Neural Network
rng = StableRNG(1111);
U = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
p, st = Lux.setup(rng, U)
p = ComponentArray(p)

function ude_dynamics!(du, u, p, t, p_true)
    NN = U(u, p, st)[1]                 # neural network replaces the known ODE terms
    du[1] = NN[1]
    du[2] = NN[2]
    du[3] = u[1] + u[2] + u[3] - 1.0    # algebraic constraint (zero row in M)
    nothing
end

# Closure with the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_true);

prob_nn = ODEProblem(ODEFunction(nn_dynamics!, mass_matrix = M), u₀ , tspan, p);

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)

    Array(solve(_prob, Rodas5(), saveat = T,
                abstol = 1e-4, reltol = 1e-4, verbose=false))
end

function loss(p)
    pred = predict(p)
    loss = sum(abs2, Array(sol) .- pred)   # compare against the reference solution
    loss, pred
end

losses = Float64[];

callback = function (p, l, pred)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
        println("Current min/max after $(length(losses)) iterations: $(extrema(pred[1, :] + pred[2, :] + pred[3, :]))")
    end
    return false
end

optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optfunc, p)
res = Optimization.solve(  optprob,
                            ADAM(), 
                            callback = callback, 
                            maxiters = 100
                         )

The end result of both scripts is similar but not exactly the same (I don't know why, but I guess it's floating-point round-off).

However, the computational time is very high. Did I do something wrong that is costing a lot of resources, or are uDAEs expensive by themselves? What can I do to make the code faster?

Best Regards

@ChrisRackauckas (Member) commented:
There is a test case in here. However, I cannot run it, so I tried a couple of modifications to make it work.

That test passes on latest versions, so just make sure you're on latest (Julia v1.9 with latest DiffEqFlux and SciMLSensitivity). I just ran the test suite and it went fine.
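A quick way to confirm the installed versions from the Julia REPL (the package names here are the ones used in the scripts above):

```julia
using Pkg
# Show the installed versions of the relevant SciML packages
Pkg.status(["DiffEqFlux", "SciMLSensitivity", "OrdinaryDiffEq"])
# Upgrade to the latest compatible releases if anything is outdated
Pkg.update()
```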

The end result of both scripts is similar but not exactly the same (I don't know why, but I guess it's floating-point round-off)

Solving to 1e-4 accuracy locally gives roughly 1e-3 to 1e-2 accuracy globally, so each of the 100 optimization steps only carries a few digits of accuracy. Yeah, that's not going to be the most stable. If you need more stability, then lower the tolerances.
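Concretely, that means tightening the tolerances on the solve inside `predict`; the values below are illustrative, not tuned:

```julia
# Tighter tolerances trade speed for a more stable loss between iterations
Array(solve(_prob, Rodas5(), saveat = T,
            abstol = 1e-8, reltol = 1e-8, verbose = false))
```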

However, the computational time is very high. Did I do something wrong that is costing a lot of resources? Or are uDAE expensive by themselves? What can I do to make the code faster?

Using Rodas5 will be quite expensive here with this choice of adjoint. Using sensealg=GaussAdjoint() (which just merged today) should be a lot faster for this use case. Also, you may want to look into using FBDF() as the solver here. Both should cut the cost down a lot in the adjoint pass.
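As a sketch, the `predict` function with those changes could look like this (assuming `SciMLSensitivity` is loaded for `GaussAdjoint`; the tolerance and VJP choices here are illustrative assumptions, not tuned settings):

```julia
using SciMLSensitivity  # provides GaussAdjoint and the VJP choices

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    # FBDF handles the mass-matrix DAE; GaussAdjoint cuts the adjoint cost
    Array(solve(_prob, FBDF(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = GaussAdjoint(autojacvec = ZygoteVJP()),
                verbose = false))
end
```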

But there doesn't seem to be anything actionable here, so I'm closing it. Feel free to keep asking questions, though for usage questions and non-bug reports we recommend using the Discourse: https://discourse.julialang.org/
