Skip to content

Commit

Permalink
Tag the original solution to sol.original and simplify dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 7, 2024
1 parent 49efc20 commit 81792ea
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 36 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down Expand Up @@ -47,7 +46,6 @@ CUDA = "5.2"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.8"
Cubature = "1.5"
DiffEqBase = "6.148"
DiffEqNoiseProcess = "5.20"
Distributions = "0.25.107"
DocStringExtensions = "0.9"
Expand Down
6 changes: 3 additions & 3 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
Expand Down Expand Up @@ -64,7 +64,7 @@ sol_lux_pestim = solve(prob, alg)
Note that the solution is evaluated at fixed time points according to the strategy chosen.
ensemble solution is evaluated and given at steps of `saveat`.
Dataset should only be provided when ODE parameter Estimation is being done.
Dataset should only be provided when ODE parameter Estimation is being done.
The neural network is a fully continuous solution so `BPINNsolution`
is an accurate interpolation (up to the neural network training result). In addition, the
`BPINNstats` is returned as `sol.fullsolution` for further analysis.
Expand Down Expand Up @@ -170,7 +170,7 @@ struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
end
end

function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
alg::BNNODE,
args...;
dt = nothing,
Expand Down
3 changes: 1 addition & 2 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module NeuralPDE

using DocStringExtensions
using Reexport, Statistics
@reexport using DiffEqBase
@reexport using SciMLBase
@reexport using ModelingToolkit

using Zygote, ForwardDiff, Random, Distributions
Expand All @@ -16,7 +16,6 @@ using Integrals, Cubature
using QuasiMonteCarlo: LatinHypercubeSample
import QuasiMonteCarlo
using RuntimeGeneratedFunctions
using SciMLBase
using Statistics
using ArrayInterface
import Optim
Expand Down
8 changes: 4 additions & 4 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
}
dim::Int
prob::DiffEqBase.ODEProblem
prob::SciMLBase.ODEProblem
chain::C
st::S
strategy::ST
Expand Down Expand Up @@ -336,12 +336,12 @@ end

"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining,
dataset = [nothing],init_params = nothing,
dataset = [nothing],init_params = nothing,
draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
Expand Down Expand Up @@ -431,7 +431,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
* AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
"""
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, l2std = [0.05],
Expand Down
14 changes: 8 additions & 6 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ of the physics-informed neural network which is used as a solver for a standard
By default, `GridTraining` is used with `dt` if given.
"""
struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy}
} <: DiffEqBase.AbstractDAEAlgorithm
} <: SciMLBase.AbstractDAEAlgorithm
chain::C
opt::O
init_params::P
Expand Down Expand Up @@ -79,7 +79,7 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
alg::NNDAE,
args...;
dt = nothing,
Expand Down Expand Up @@ -178,12 +178,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
u = [phi(t, res.u) for t in ts]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u;
sol = SciMLBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
calculate_error = false,
retcode = ReturnCode.Success)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
retcode = ReturnCode.Success,
original = res,
resid = res.objective)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end
22 changes: 12 additions & 10 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end

"""
NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
Expand All @@ -14,10 +14,10 @@ of the physics-informed neural network which is used as a solver for a standard
## Positional Arguments
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
`Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.
## Keyword Arguments
Expand All @@ -28,8 +28,8 @@ of the physics-informed neural network which is used as a solver for a standard
automatic differentiation (via Zygote), this is only for the derivative
in the loss function (the derivative with respect to time).
* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
`false` means which means the application of the neural network is done at individual time points one at a time.
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
`false` means which means the application of the neural network is done at individual time points one at a time.
This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand.
* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
* `strategy`: The training strategy used to choose the points for the evaluations.
Expand Down Expand Up @@ -339,7 +339,7 @@ end
SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
SciMLBase.allowscomplex(::NNODE) = true

function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
alg::NNODE,
args...;
dt = nothing,
Expand Down Expand Up @@ -479,13 +479,15 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
u = [phi(t, res.u) for t in ts]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u;
sol = SciMLBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
interp = NNODEInterpolation(phi, res.u),
calculate_error = false,
retcode = ReturnCode.Success)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
retcode = ReturnCode.Success,
original = res,
resid = res.objective)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end #solve
18 changes: 9 additions & 9 deletions src/rode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function NNRODE(chain, W, opt = Optim.BFGS(), init_params = nothing; autodiff =
NNRODE(chain, W, opt, init_params, autodiff, kwargs)
end

function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
function SciMLBase.solve(prob::SciMLBase.AbstractRODEProblem,
alg::NeuralPDEAlgorithm,
args...;
dt,
Expand All @@ -30,7 +30,7 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
abstol = 1.0f-6,
verbose = false,
maxiters = 100)
DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!")
SciMLBase.isinplace(prob) && error("Only out-of-place methods are allowed!")

u0 = prob.u0
tspan = prob.tspan
Expand All @@ -52,24 +52,24 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
if u0 isa Number
phi = (t, W, θ) -> u0 +
(t - tspan[1]) *
first(chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]),
first(chain(adapt(SciMLBase.parameterless_type(θ), [t, W]),
θ))
else
phi = (t, W, θ) -> u0 +
(t - tspan[1]) *
chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), θ)
chain(adapt(SciMLBase.parameterless_type(θ), [t, W]), θ)
end
else
_, re = Flux.destructure(chain)
#The phi trial solution
if u0 isa Number
phi = (t, W, θ) -> u0 +
(t - t0) *
first(re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W])))
first(re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W])))
else
phi = (t, W, θ) -> u0 +
(t - t0) *
re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W]))
re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W]))
end
end

Expand Down Expand Up @@ -108,9 +108,9 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
sol = SciMLBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end #solve

0 comments on commit 81792ea

Please sign in to comment.