Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NNODE for DAE Problems #695

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ include("transform_inf_integral.jl")
include("discretize.jl")
include("neural_adapter.jl")

export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
export NNODE, NNDAE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
PhysicsInformedNN, discretize,
Expand Down
25 changes: 25 additions & 0 deletions src/dae_problem_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using DAEProblemLibrary, Sundials, Optimisers, OptimizationOptimisers, DifferentialEquations
using NeuralPDE, Lux, Test, Statistics, Plots

prob = DAEProblemLibrary.prob_dae_resrob
true_sol = solve(prob, IDA(), saveat = 0.01)
# sol = solve(prob, IDA())

u0 = [1.0, 1.0, 1.0]
func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
dx = 0.05
alg = NeuralPDE.NNDAE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
sol = solve(prob, alg, verbose=true, maxiters = 100000, saveat = 0.01)

println(abs(mean(true_sol .- sol)))

using Plots

plot(sol)
plot!(true_sol)
# ylims!(0,8)
287 changes: 275 additions & 12 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
abstract type NeuralPDEAlgorithmDAE <: DiffEqBase.AbstractDAEAlgorithm end

"""
```julia
Expand Down Expand Up @@ -29,7 +30,8 @@

## Example

```julia
```juliap = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
ts=[t for t in 1:100]
(u_, t_) = (analytical_func(ts), ts)
function additional_loss(phi, θ)
Expand Down Expand Up @@ -96,6 +98,25 @@
NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
end

struct NNDAE{C, O, P, B, K, AL <: Union{Nothing, Function},
S <: Union{Nothing, AbstractTrainingStrategy}
} <:
NeuralPDEAlgorithmDAE
chain::C
opt::O
init_params::P
autodiff::Bool
batch::B
strategy::S
additional_loss::AL
kwargs::K
end
function NNDAE(chain, opt, init_params = nothing;

Check warning on line 114 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L114

Added line #L114 was not covered by tests
strategy = nothing,
autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
NNDAE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)

Check warning on line 117 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L117

Added line #L117 was not covered by tests
end

"""
```julia
ODEPhi(chain::Lux.AbstractExplicitLayer, t, u0, st)
Expand All @@ -120,23 +141,21 @@
end
end

function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing)
θ, st = Lux.setup(Random.default_rng(), chain)
ODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ)
end

function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
θ, st = Lux.setup(Random.default_rng(), chain)
ODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params)
end

function generate_phi_θ(chain::Flux.Chain, t, u0, init_params::Nothing)
θ, re = Flux.destructure(chain)
ODEPhi(re, t, u0), θ
if init_params === nothing
init_params = ComponentArrays.ComponentArray(θ)

Check warning on line 147 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L146-L147

Added lines #L146 - L147 were not covered by tests
else
init_params = ComponentArrays.ComponentArray(init_params)

Check warning on line 149 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L149

Added line #L149 was not covered by tests
end
ODEPhi(chain, t, u0, st), init_params

Check warning on line 151 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L151

Added line #L151 was not covered by tests
end

function generate_phi_θ(chain::Flux.Chain, t, u0, init_params)
θ, re = Flux.destructure(chain)
if init_params === nothing
init_params = θ
end
ODEPhi(re, t, u0), init_params
end

Expand Down Expand Up @@ -252,6 +271,35 @@
sum(abs2, dxdtguess .- fs) / length(t)
end

"""
L2 inner loss for DAEProblems
"""

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,

Check warning on line 278 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L278

Added line #L278 was not covered by tests
p) where {C, T, U <: Number}
sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p, t))

Check warning on line 280 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L280

Added line #L280 was not covered by tests
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,

Check warning on line 283 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L283

Added line #L283 was not covered by tests
p) where {C, T, U <: Number}
out = phi(t, θ)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, f(dxdtguess[i], p, t[i]) for i in 1:size(out, 2)) / length(t)

Check warning on line 287 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L285-L287

Added lines #L285 - L287 were not covered by tests
sdesai1287 marked this conversation as resolved.
Show resolved Hide resolved
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,

Check warning on line 290 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L290

Added line #L290 was not covered by tests
p) where {C, T, U}
sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p, t))

Check warning on line 292 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L292

Added line #L292 was not covered by tests
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,

Check warning on line 295 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L295

Added line #L295 was not covered by tests
p) where {C, T, U}
out = Array(phi(t, θ))
arrt = Array(t)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, f(dxdtguess[:, i], p, arrt[i]) for i in 1:size(out, 2)) / length(t)

Check warning on line 300 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L297-L300

Added lines #L297 - L300 were not covered by tests
sdesai1287 marked this conversation as resolved.
Show resolved Hide resolved
end

"""
Representation of the loss function, parametric on the training strategy `strategy`
"""
Expand Down Expand Up @@ -334,6 +382,89 @@
error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end

"""
Representation of the loss function, parametric on the training strategy `strategy` for DAE problems
"""
function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,

Check warning on line 388 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L388

Added line #L388 was not covered by tests
batch)
integrand(t::Number, θ) = abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p))
integrand(ts, θ) = [abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p)) for t in ts]
@assert batch == 0 # not implemented

Check warning on line 392 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L390-L392

Added lines #L390 - L392 were not covered by tests

function loss(θ, _)
intprob = IntegralProblem(integrand, tspan[1], tspan[2], θ)
sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol)
sol.u

Check warning on line 397 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L394-L397

Added lines #L394 - L397 were not covered by tests
end

return loss

Check warning on line 400 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L400

Added line #L400 was not covered by tests
end

function generate_loss_DAE(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
ts = tspan[1]:(strategy.dx):tspan[2]

Check warning on line 404 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L403-L404

Added lines #L403 - L404 were not covered by tests

# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
if batch
sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))

Check warning on line 409 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L407-L409

Added lines #L407 - L409 were not covered by tests
else
sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])

Check warning on line 411 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L411

Added line #L411 was not covered by tests
end
end
return loss

Check warning on line 414 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L414

Added line #L414 was not covered by tests
end

function generate_loss_DAE(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,

Check warning on line 417 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L417

Added line #L417 was not covered by tests
batch)
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
ts = adapt(parameterless_type(θ),

Check warning on line 421 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L420-L421

Added lines #L420 - L421 were not covered by tests
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
if batch
sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))

Check warning on line 424 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L423-L424

Added lines #L423 - L424 were not covered by tests
else
sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])

Check warning on line 426 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L426

Added line #L426 was not covered by tests
end
end
return loss

Check warning on line 429 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L429

Added line #L429 was not covered by tests
end

function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,

Check warning on line 432 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L432

Added line #L432 was not covered by tests
batch)
minT = tspan[1]
maxT = tspan[2]

Check warning on line 435 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L434-L435

Added lines #L434 - L435 were not covered by tests

weights = strategy.weights ./ sum(strategy.weights)

Check warning on line 437 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L437

Added line #L437 was not covered by tests

N = length(weights)
samples = strategy.samples

Check warning on line 440 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L439-L440

Added lines #L439 - L440 were not covered by tests

difference = (maxT - minT) / N

Check warning on line 442 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L442

Added line #L442 was not covered by tests

data = Float64[]
for (index, item) in enumerate(weights)
temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+

Check warning on line 446 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L444-L446

Added lines #L444 - L446 were not covered by tests
((index - 1) * difference)
data = append!(data, temp_data)
end

Check warning on line 449 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L448-L449

Added lines #L448 - L449 were not covered by tests

ts = data

Check warning on line 451 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L451

Added line #L451 was not covered by tests

function loss(θ, _)
if batch
sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))

Check warning on line 455 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L453-L455

Added lines #L453 - L455 were not covered by tests
else
sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])

Check warning on line 457 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L457

Added line #L457 was not covered by tests
end
end
return loss

Check warning on line 460 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L460

Added line #L460 was not covered by tests
end

function generate_loss_DAE(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")

Check warning on line 464 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L463-L464

Added lines #L463 - L464 were not covered by tests
end


struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
θ::T2
Expand Down Expand Up @@ -483,3 +614,135 @@
dense_errors = false)
sol
end #solve

function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,

Check warning on line 618 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L618

Added line #L618 was not covered by tests
alg::NNDAE,
args...;
dt = nothing,
timeseries_errors = true,
save_everystep = true,
adaptive = false,
abstol = 1.0f-6,
reltol = 1.0f-3,
verbose = false,
saveat = nothing,
maxiters = nothing)

u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p
t0 = tspan[1]

Check warning on line 635 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L631-L635

Added lines #L631 - L635 were not covered by tests

#hidden layer
chain = alg.chain
opt = alg.opt
autodiff = alg.autodiff

Check warning on line 640 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L638-L640

Added lines #L638 - L640 were not covered by tests

#train points generation
init_params = alg.init_params

Check warning on line 643 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L643

Added line #L643 was not covered by tests

if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)

Check warning on line 646 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L645-L646

Added lines #L645 - L646 were not covered by tests
else
error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported")

Check warning on line 648 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L648

Added line #L648 was not covered by tests
end

# if isinplace(prob)
# throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
# end

try
phi(t0, init_params)

Check warning on line 656 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L655-L656

Added lines #L655 - L656 were not covered by tests
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))

Check warning on line 659 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L658-L659

Added lines #L658 - L659 were not covered by tests
else
throw(err)

Check warning on line 661 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L661

Added line #L661 was not covered by tests
end
end

strategy = if alg.strategy === nothing
if dt !== nothing
GridTraining(dt)

Check warning on line 667 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L665-L667

Added lines #L665 - L667 were not covered by tests
else
QuadratureTraining(; quadrature_alg = QuadGKJL(),

Check warning on line 669 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L669

Added line #L669 was not covered by tests
reltol = convert(eltype(u0), reltol),
abstol = convert(eltype(u0), abstol), maxiters = maxiters,
batch = 0)
end
else
alg.strategy

Check warning on line 675 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L675

Added line #L675 was not covered by tests
end

batch = if alg.batch === nothing
if strategy isa QuadratureTraining
strategy.batch

Check warning on line 680 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L678-L680

Added lines #L678 - L680 were not covered by tests
else
true

Check warning on line 682 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L682

Added line #L682 was not covered by tests
end
else
alg.batch

Check warning on line 685 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L685

Added line #L685 was not covered by tests
end

inner_f = generate_loss_DAE(strategy, phi, f, autodiff, tspan, p, batch)
additional_loss = alg.additional_loss

Check warning on line 689 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L688-L689

Added lines #L688 - L689 were not covered by tests

# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
L2_loss = inner_f(θ, phi)
if !(additional_loss isa Nothing)
return additional_loss(phi, θ) + L2_loss

Check warning on line 695 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L692-L695

Added lines #L692 - L695 were not covered by tests
end
L2_loss

Check warning on line 697 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L697

Added line #L697 was not covered by tests
end

# Choice of Optimization Algo for Training Strategies
opt_algo = if strategy isa QuadratureTraining
Optimization.AutoForwardDiff()

Check warning on line 702 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L701-L702

Added lines #L701 - L702 were not covered by tests
else
Optimization.AutoZygote()

Check warning on line 704 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L704

Added line #L704 was not covered by tests
end

# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

Check warning on line 708 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L708

Added line #L708 was not covered by tests

iteration = 0
callback = function (p, l)
iteration += 1
verbose && println("Current loss is: $l, Iteration: $iteration")
l < abstol

Check warning on line 714 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L710-L714

Added lines #L710 - L714 were not covered by tests
end

optprob = OptimizationProblem(optf, init_params)
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

Check warning on line 718 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L717-L718

Added lines #L717 - L718 were not covered by tests

#solutions at timepoints
if saveat isa Number
ts = tspan[1]:saveat:tspan[2]
elseif saveat isa AbstractArray
ts = saveat
elseif dt !== nothing
ts = tspan[1]:dt:tspan[2]
elseif save_everystep
ts = range(tspan[1], tspan[2], length = 100)

Check warning on line 728 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L721-L728

Added lines #L721 - L728 were not covered by tests
else
ts = [tspan[1], tspan[2]]

Check warning on line 730 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L730

Added line #L730 was not covered by tests
end

if u0 isa Number
u = [first(phi(t, res.u)) for t in ts]

Check warning on line 734 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L733-L734

Added lines #L733 - L734 were not covered by tests
else
u = [phi(t, res.u) for t in ts]

Check warning on line 736 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L736

Added line #L736 was not covered by tests
end

sol = DiffEqBase.build_solution(prob, alg, ts, u;

Check warning on line 739 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L739

Added line #L739 was not covered by tests
k = res, dense = true,
interp = NNODEInterpolation(phi, res.u),
calculate_error = false,
retcode = ReturnCode.Success)
DiffEqBase.has_analytic(prob.f) &&

Check warning on line 744 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L744

Added line #L744 was not covered by tests
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol

Check warning on line 747 in src/ode_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ode_solve.jl#L747

Added line #L747 was not covered by tests
end #solve