Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tag the original solution to sol.original and simplify dependencies #846

Merged
merged 1 commit into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down Expand Up @@ -47,7 +46,6 @@ CUDA = "5.2"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.8"
Cubature = "1.5"
DiffEqBase = "6.148"
DiffEqNoiseProcess = "5.20"
Distributions = "0.25.107"
DocStringExtensions = "0.9"
Expand Down
6 changes: 3 additions & 3 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
Expand Down Expand Up @@ -64,7 +64,7 @@ sol_lux_pestim = solve(prob, alg)

Note that the solution is evaluated at fixed time points according to the strategy chosen.
The ensemble solution is evaluated and given at steps of `saveat`.
Dataset should only be provided when ODE parameter Estimation is being done.
A dataset should only be provided when ODE parameter estimation is being done.
The neural network is a fully continuous solution so `BPINNsolution`
is an accurate interpolation (up to the neural network training result). In addition, the
`BPINNstats` is returned as `sol.fullsolution` for further analysis.
Expand Down Expand Up @@ -170,7 +170,7 @@ struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
end
end

function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
alg::BNNODE,
args...;
dt = nothing,
Expand Down
3 changes: 1 addition & 2 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module NeuralPDE

using DocStringExtensions
using Reexport, Statistics
@reexport using DiffEqBase
@reexport using SciMLBase
@reexport using ModelingToolkit

using Zygote, ForwardDiff, Random, Distributions
Expand All @@ -16,7 +16,6 @@ using Integrals, Cubature
using QuasiMonteCarlo: LatinHypercubeSample
import QuasiMonteCarlo
using RuntimeGeneratedFunctions
using SciMLBase
using Statistics
using ArrayInterface
import Optim
Expand Down
8 changes: 4 additions & 4 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
}
dim::Int
prob::DiffEqBase.ODEProblem
prob::SciMLBase.ODEProblem
chain::C
st::S
strategy::ST
Expand Down Expand Up @@ -336,12 +336,12 @@ end

"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining,
dataset = [nothing],init_params = nothing,
dataset = [nothing],init_params = nothing,
draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
Expand Down Expand Up @@ -431,7 +431,7 @@ In case you are only solving the equations for the solution, do not provide a dataset

* AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
"""
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, l2std = [0.05],
Expand Down
14 changes: 8 additions & 6 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ of the physics-informed neural network which is used as a solver for a standard
By default, `GridTraining` is used with `dt` if given.
"""
struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy}
} <: DiffEqBase.AbstractDAEAlgorithm
} <: SciMLBase.AbstractDAEAlgorithm
chain::C
opt::O
init_params::P
Expand Down Expand Up @@ -79,7 +79,7 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
alg::NNDAE,
args...;
dt = nothing,
Expand Down Expand Up @@ -178,12 +178,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
u = [phi(t, res.u) for t in ts]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u;
sol = SciMLBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
calculate_error = false,
retcode = ReturnCode.Success)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
retcode = ReturnCode.Success,
original = res,
resid = res.objective)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end
22 changes: 12 additions & 10 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end

"""
NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
Expand All @@ -14,10 +14,10 @@ of the physics-informed neural network which is used as a solver for a standard

## Positional Arguments

* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
`Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.

## Keyword Arguments
Expand All @@ -28,8 +28,8 @@ of the physics-informed neural network which is used as a solver for a standard
automatic differentiation (via Zygote), this is only for the derivative
in the loss function (the derivative with respect to time).
* `batch`: The batch size for the loss computation. Defaults to `true`, which means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
`false` means which means the application of the neural network is done at individual time points one at a time.
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
`false` means the neural network is applied at individual time points, one at a time.
This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand.
* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
* `strategy`: The training strategy used to choose the points for the evaluations.
Expand Down Expand Up @@ -339,7 +339,7 @@ end
SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
SciMLBase.allowscomplex(::NNODE) = true

function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
alg::NNODE,
args...;
dt = nothing,
Expand Down Expand Up @@ -479,13 +479,15 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
u = [phi(t, res.u) for t in ts]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u;
sol = SciMLBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
interp = NNODEInterpolation(phi, res.u),
calculate_error = false,
retcode = ReturnCode.Success)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
retcode = ReturnCode.Success,
original = res,
resid = res.objective)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end #solve
18 changes: 9 additions & 9 deletions src/rode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function NNRODE(chain, W, opt = Optim.BFGS(), init_params = nothing; autodiff =
NNRODE(chain, W, opt, init_params, autodiff, kwargs)
end

function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
function SciMLBase.solve(prob::SciMLBase.AbstractRODEProblem,
alg::NeuralPDEAlgorithm,
args...;
dt,
Expand All @@ -30,7 +30,7 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
abstol = 1.0f-6,
verbose = false,
maxiters = 100)
DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!")
SciMLBase.isinplace(prob) && error("Only out-of-place methods are allowed!")

u0 = prob.u0
tspan = prob.tspan
Expand All @@ -52,24 +52,24 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
if u0 isa Number
phi = (t, W, θ) -> u0 +
(t - tspan[1]) *
first(chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]),
first(chain(adapt(SciMLBase.parameterless_type(θ), [t, W]),
θ))
else
phi = (t, W, θ) -> u0 +
(t - tspan[1]) *
chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), θ)
chain(adapt(SciMLBase.parameterless_type(θ), [t, W]), θ)
end
else
_, re = Flux.destructure(chain)
#The phi trial solution
if u0 isa Number
phi = (t, W, θ) -> u0 +
(t - t0) *
first(re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W])))
first(re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W])))
else
phi = (t, W, θ) -> u0 +
(t - t0) *
re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W]))
re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W]))
end
end

Expand Down Expand Up @@ -108,9 +108,9 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
sol = SciMLBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end #solve
Loading