diff --git a/src/ode_solve.jl b/src/ode_solve.jl index 8fc621cda..a70d453d5 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -278,28 +278,28 @@ L2 inner loss for DAEProblems function inner_loss_DAE end function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ, - p, u) where {C, T, U <: Number} - sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), u, p, t)) + p) where {C, T, U <: Number} + sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(ode_dfdx(phi, t, θ, autodiff), phi, p, t)) end function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ, - p, u) where {C, T, U <: Number} + p) where {C, T, U <: Number} out = phi(t, θ) dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff)) - sum(abs2, f(dxdtguess[i], u, p, t[i]) for i in 1:size(out, 2)) / length(t) + sum(abs2, f(dxdtguess[i], phi, p, t[i]) for i in 1:size(out, 2)) / length(t) end function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ, - p, u) where {C, T, U} - sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), u, p, t)) + p) where {C, T, U} + sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(ode_dfdx(phi, t, θ, autodiff), phi, p, t)) end function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ, - p, u) where {C, T, U} + p) where {C, T, U} out = Array(phi(t, θ)) arrt = Array(t) dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff)) - sum(abs2, f(dxdtguess[:, i], u, p, arrt[i]) for i in 1:size(out, 2)) / length(t) + sum(abs2, f(dxdtguess[:, i], phi, p, arrt[i]) for i in 1:size(out, 2)) / length(t) end """ @@ -387,10 +387,10 @@ end """ Representation of the loss function, parametric on the training strategy `strategy` for DAE problems """ -function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p, u, +function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p, batch) - integrand(t::Number, θ) = abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p, u)) - integrand(ts, θ) = [abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p, u)) for t in ts] + integrand(t::Number, θ) = abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p)) + integrand(ts, θ) = [abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p)) for t in ts] @assert batch == 0 # not implemented function loss(θ, _) @@ -402,36 +402,36 @@ function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, return loss end -function generate_loss_DAE(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, u, batch) +function generate_loss_DAE(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch) ts = tspan[1]:(strategy.dx):tspan[2] # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken function loss(θ, _) if batch - sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p, u)) + sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p)) else - sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p, u) for t in ts]) + sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts]) end end return loss end -function generate_loss_DAE(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, u, +function generate_loss_DAE(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, batch) # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken function loss(θ, _) ts = adapt(parameterless_type(θ), [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)]) if batch - sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p, u)) + sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p)) else - sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p, u) for t in ts]) + sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts]) end end return loss end -function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, u, +function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, batch) minT = tspan[1] maxT = tspan[2] @@ -454,15 +454,15 @@ function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff: function loss(θ, _) if batch - sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p, u)) + sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p)) else - sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p, u) for t in ts]) + sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts]) end end return loss end -function generate_loss_DAE(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan, p, u, batch) +function generate_loss_DAE(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan, p, batch) error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.") end @@ -685,9 +685,9 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem, end else alg.batch - endx + end - inner_f = generate_loss_DAE(strategy, phi, f, autodiff, tspan, p, u0, batch) + inner_f = generate_loss_DAE(strategy, phi, f, autodiff, tspan, p, batch) additional_loss = alg.additional_loss # Creates OptimizationFunction Object from total_loss diff --git a/test/dae_problem_test.jl b/test/dae_problem_test.jl new file mode 100644 index 000000000..1ca7eb096 --- /dev/null +++ b/test/dae_problem_test.jl @@ -0,0 +1,50 @@ +using DAEProblemLibrary, Sundials, Optimisers, OptimizationOptimisers, DifferentialEquations +using NeuralPDE, Lux, Test, Statistics, Plots + +f = function (yp, y, p, tres) + [-0.04 * y[1] + 1.0e4 * y[2] * y[3] - yp[1], + -(-0.04 * y[1] + 1.0e4 * y[2] * y[3]) - 3.0e7 * y[2] * y[2] - yp[2], + y[1] + y[2] + y[3] - 1.0] +end +u0 = [1.0, 0, 0] +du0 = [-0.04, 0.04, 0.0] + +println("f defined") +""" +The Robertson biochemical reactions in DAE form + +```math +\frac{dy₁}{dt} = -k₁y₁+k₃y₂y₃ +``` +```math +\frac{dy₂}{dt} = k₁y₁-k₂y₂^2-k₃y₂y₃ +``` +```math +1 = y₁ + y₂ + y₃ +``` +where ``k₁=0.04``, ``k₂=3\times10^7``, ``k₃=10^4``. For details, see: +Hairer Norsett Wanner Solving Ordinary Differential Equations I - Nonstiff Problems Page 129 +Usually solved on ``[0,1e11]`` +""" + +prob_oop = DAEProblem{false}(f, du0, u0, (0.0, 100000.0)) +true_sol = solve(prob_oop, IDA(), saveat = 0.01) + +u0 = [1.0, 1.0, 1.0] +func = Lux.σ +N = 12 +chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func), + Lux.Dense(N, N, func), Lux.Dense(N, length(u0))) + +opt = Optimisers.Adam(0.01) +dx = 0.05 +alg = NeuralPDE.NNDAE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx)) +sol = solve(prob_oop, alg, verbose=true, maxiters = 100000, saveat = 0.01) + +# println(abs(mean(true_sol .- sol))) + +# using Plots + +# plot(sol) +# plot!(true_sol) +# # ylims!(0,8) \ No newline at end of file