
Commit

Fixed some of the issues with problem definitions and loss function generation, still have an error
Samedh Desai committed Jul 26, 2023
1 parent 595359f commit 41c9084
Showing 2 changed files with 73 additions and 23 deletions.
46 changes: 23 additions & 23 deletions src/ode_solve.jl
@@ -278,28 +278,28 @@ L2 inner loss for DAEProblems
function inner_loss_DAE end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
- p, u) where {C, T, U <: Number}
- sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), u, p, t))
+ p) where {C, T, U <: Number}
+ sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(ode_dfdx(phi, t, θ, autodiff), phi, p, t))
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
- p, u) where {C, T, U <: Number}
+ p) where {C, T, U <: Number}
out = phi(t, θ)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
- sum(abs2, f(dxdtguess[i], u, p, t[i]) for i in 1:size(out, 2)) / length(t)
+ sum(abs2, f(dxdtguess[i], phi, p, t[i]) for i in 1:size(out, 2)) / length(t)
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
- p, u) where {C, T, U}
- sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), u, p, t))
+ p) where {C, T, U}
+ sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(ode_dfdx(phi, t, θ, autodiff), phi, p, t))
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
- p, u) where {C, T, U}
+ p) where {C, T, U}
out = Array(phi(t, θ))
arrt = Array(t)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
- sum(abs2, f(dxdtguess[:, i], u, p, arrt[i]) for i in 1:size(out, 2)) / length(t)
+ sum(abs2, f(dxdtguess[:, i], phi, p, arrt[i]) for i in 1:size(out, 2)) / length(t)
end

"""
@@ -387,10 +387,10 @@ end
"""
Representation of the loss function, parametric on the training strategy `strategy` for DAE problems
"""
- function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p, u,
+ function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
batch)
- integrand(t::Number, θ) = abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p, u))
- integrand(ts, θ) = [abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p, u)) for t in ts]
+ integrand(t::Number, θ) = abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p))
+ integrand(ts, θ) = [abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p)) for t in ts]
@assert batch == 0 # not implemented

function loss(θ, _)
@@ -402,36 +402,36 @@ function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool,
return loss
end

- function generate_loss_DAE(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, u, batch)
+ function generate_loss_DAE(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
ts = tspan[1]:(strategy.dx):tspan[2]

# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
if batch
- sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p, u))
+ sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))
else
- sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p, u) for t in ts])
+ sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])
end
end
return loss
end

- function generate_loss_DAE(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, u,
+ function generate_loss_DAE(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
batch)
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
if batch
- sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p, u))
+ sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))
else
- sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p, u) for t in ts])
+ sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])
end
end
return loss
end

- function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, u,
+ function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
batch)
minT = tspan[1]
maxT = tspan[2]
@@ -454,15 +454,15 @@ function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff:

function loss(θ, _)
if batch
- sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p, u))
+ sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))
else
- sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p, u) for t in ts])
+ sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])
end
end
return loss
end

- function generate_loss_DAE(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan, p, u, batch)
+ function generate_loss_DAE(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan, p, batch)
error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end

@@ -685,9 +685,9 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
end
else
alg.batch
- endx
+ end

- inner_f = generate_loss_DAE(strategy, phi, f, autodiff, tspan, p, u0, batch)
+ inner_f = generate_loss_DAE(strategy, phi, f, autodiff, tspan, p, batch)
additional_loss = alg.additional_loss

# Creates OptimizationFunction Object from total_loss
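Taken together, these hunks remove the extra `u` argument threaded through `inner_loss_DAE` and the `generate_loss_DAE` methods, so the residual is now built only from the network derivative `ode_dfdx(phi, t, θ, autodiff)`, `phi`, the DAE function `f`, the parameters `p`, and the training points supplied by the strategy. As a rough illustration of the collocation idea behind these functions, here is a minimal, hedged sketch of an L2 residual loss for an implicit DAE `f(du, u, p, t) = 0`; the toy problem, trial solution, and all names below (`f_toy`, `u_trial`, `dae_loss`) are illustrative assumptions, not NeuralPDE internals, and `ForwardDiff` stands in for the package's own `ode_dfdx`.

```julia
# Minimal sketch of a grid-style L2 collocation loss for an implicit DAE
# f(du, u, p, t) = 0. Toy stand-ins only, not NeuralPDE's API.
using ForwardDiff

f_toy(du, u, p, t) = du + u                     # toy residual: du/dt + u = 0
u_trial(t, θ) = θ[1] * exp(-t)                  # hypothetical trial solution
du_trial(t, θ) = ForwardDiff.derivative(s -> u_trial(s, θ), t)

# Mean squared residual over a fixed grid of collocation points (GridTraining-like).
dae_loss(θ, ts, p) = sum(abs2, f_toy(du_trial(t, θ), u_trial(t, θ), p, t) for t in ts) / length(ts)

dae_loss([1.0], 0.0:0.05:1.0, nothing)          # ≈ 0, since θ·exp(-t) solves the toy DAE exactly
```

In the diff above, `inner_loss_DAE` plays the role of the squared residual at a point, and each training strategy (`GridTraining`, `StochasticTraining`, `WeightedIntervalTraining`, `QuadratureTraining`) only changes how the collocation points are chosen and aggregated.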
50 changes: 50 additions & 0 deletions test/dae_problem_test.jl
@@ -0,0 +1,50 @@
using DAEProblemLibrary, Sundials, Optimisers, OptimizationOptimisers, DifferentialEquations
using NeuralPDE, Lux, Test, Statistics, Plots

f = function (yp, y, p, tres)
[-0.04 * y[1] + 1.0e4 * y[2] * y[3] - yp[1],
-(-0.04 * y[1] + 1.0e4 * y[2] * y[3]) - 3.0e7 * y[2] * y[2] - yp[2],
y[1] + y[2] + y[3] - 1.0]
end
u0 = [1.0, 0, 0]
du0 = [-0.04, 0.04, 0.0]

println("f defined")
"""
The Robertson biochemical reactions in DAE form
```math
\frac{dy₁}{dt} = -k₁y₁+k₃y₂y₃
```
```math
\frac{dy₂}{dt} = k₁y₁-k₂y₂^2-k₃y₂y₃
```
```math
1 = y₁ + y₂ + y₃
```
where ``k₁=0.04``, ``k₂=3\times10^7``, ``k₃=10^4``. For details, see:
Hairer Norsett Wanner Solving Ordinary Differential Equations I - Nonstiff Problems Page 129
Usually solved on ``[0,1e11]``
"""

prob_oop = DAEProblem{false}(f, du0, u0, (0.0, 100000.0))
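# Editor's note (illustrative, not part of the committed file): in the implicit form
# expected by DAEProblem, the function `f` above returns the residual F(y', y) with
#   F₁ = -k₁y₁ + k₃y₂y₃ - y₁'
#   F₂ =  k₁y₁ - k₃y₂y₃ - k₂y₂² - y₂'
#   F₃ =  y₁ + y₂ + y₃ - 1
# so a solution satisfies F ≡ 0, matching the docstring equations with k₁ = 0.04,
# k₂ = 3e7, k₃ = 1e4.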
true_sol = solve(prob_oop, IDA(), saveat = 0.01)

u0 = [1.0, 1.0, 1.0]
func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
dx = 0.05
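# NNDAE: physics-informed DAE solver using grid collocation (GridTraining, dx = 0.05), trained with Adam (lr = 0.01).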
alg = NeuralPDE.NNDAE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
sol = solve(prob_oop, alg, verbose=true, maxiters = 100000, saveat = 0.01)

# println(abs(mean(true_sol .- sol)))

# using Plots

# plot(sol)
# plot!(true_sol)
# # ylims!(0,8)
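
As a quick sanity check on this setup (an editorial sketch, not part of the committed test), the initial conditions handed to `DAEProblem` should make the Robertson residual vanish at `t = 0`:

```julia
# Verify f(du0, u0, p, 0) ≈ 0 for the Robertson initial conditions used above.
robertson = function (yp, y, p, t)
    [-0.04 * y[1] + 1.0e4 * y[2] * y[3] - yp[1],
     -(-0.04 * y[1] + 1.0e4 * y[2] * y[3]) - 3.0e7 * y[2] * y[2] - yp[2],
     y[1] + y[2] + y[3] - 1.0]
end
res = robertson([-0.04, 0.04, 0.0], [1.0, 0.0, 0.0], nothing, 0.0)
@assert all(abs.(res) .< 1e-12)   # -0.04 + 0.04, 0.04 - 0.04, and 1 - 1 are all zero
```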
