From 81792ea77d1afb5568667413ecbc7b00f89bf399 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 7 Apr 2024 08:38:18 -0400 Subject: [PATCH] Tag the original solution to sol.original and simplify dependencies Gives a nice solution to https://discourse.julialang.org/t/continue-solving-ode-problem-when-using-neuralpde-nnode/101580 --- Project.toml | 2 -- src/BPINN_ode.jl | 6 +++--- src/NeuralPDE.jl | 3 +-- src/advancedHMC_MCMC.jl | 8 ++++---- src/dae_solve.jl | 14 ++++++++------ src/ode_solve.jl | 22 ++++++++++++---------- src/rode_solve.jl | 18 +++++++++--------- 7 files changed, 37 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index 3541fae00..80e56de80 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Cubature = "667455a9-e2ce-5579-9412-b964f529a492" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -47,7 +46,6 @@ CUDA = "5.2" ChainRulesCore = "1.21" ComponentArrays = "0.15.8" Cubature = "1.5" -DiffEqBase = "6.148" DiffEqNoiseProcess = "5.20" Distributions = "0.25.107" DocStringExtensions = "0.9" diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl index e92d1b11c..226c3f329 100644 --- a/src/BPINN_ode.jl +++ b/src/BPINN_ode.jl @@ -4,7 +4,7 @@ BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000, priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05], phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0, - MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing, + MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing, Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric), Integratorkwargs = (Integrator = Leapfrog,), autodiff = false, progress = false, verbose = false) @@ -64,7 +64,7 @@ sol_lux_pestim = solve(prob, alg) Note that the solution is evaluated at fixed time points according to the strategy chosen. ensemble solution is evaluated and given at steps of `saveat`. -Dataset should only be provided when ODE parameter Estimation is being done. +Dataset should only be provided when ODE parameter Estimation is being done. The neural network is a fully continuous solution so `BPINNsolution` is an accurate interpolation (up to the neural network training result). In addition, the `BPINNstats` is returned as `sol.fullsolution` for further analysis. @@ -170,7 +170,7 @@ struct BPINNsolution{O <: BPINNstats, E, NP, OP, P} end end -function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem, +function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing, diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index cfc3b9367..1122afc83 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -5,7 +5,7 @@ module NeuralPDE using DocStringExtensions using Reexport, Statistics -@reexport using DiffEqBase +@reexport using SciMLBase @reexport using ModelingToolkit using Zygote, ForwardDiff, Random, Distributions @@ -16,7 +16,6 @@ using Integrals, Cubature using QuasiMonteCarlo: LatinHypercubeSample import QuasiMonteCarlo using RuntimeGeneratedFunctions -using SciMLBase using Statistics using ArrayInterface import Optim diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index 22292462d..252ca2f41 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -4,7 +4,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I, Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}} } dim::Int - prob::DiffEqBase.ODEProblem + prob::SciMLBase.ODEProblem chain::C st::S strategy::ST @@ -336,12 +336,12 @@ end """ ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, - dataset = [nothing],init_params = nothing, + dataset = [nothing],init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, autodiff = false, Kernel = HMC, Adaptorkwargs = (Adaptor = StanHMCAdaptor, - Metric = DiagEuclideanMetric, + Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,), @@ -431,7 +431,7 @@ Incase you are only solving the Equations for solution, do not provide dataset * AdvancedHMC.jl is still developing convenience structs so might need changes on new releases. """ -function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; +function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing], init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05], diff --git a/src/dae_solve.jl b/src/dae_solve.jl index f9edbb948..5a5ee83be 100644 --- a/src/dae_solve.jl +++ b/src/dae_solve.jl @@ -31,7 +31,7 @@ of the physics-informed neural network which is used as a solver for a standard By default, `GridTraining` is used with `dt` if given. """ struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy} -} <: DiffEqBase.AbstractDAEAlgorithm +} <: SciMLBase.AbstractDAEAlgorithm chain::C opt::O init_params::P @@ -79,7 +79,7 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, return loss end -function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem, +function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem, alg::NNDAE, args...; dt = nothing, @@ -178,12 +178,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem, u = [phi(t, res.u) for t in ts] end - sol = DiffEqBase.build_solution(prob, alg, ts, u; + sol = SciMLBase.build_solution(prob, alg, ts, u; k = res, dense = true, calculate_error = false, - retcode = ReturnCode.Success) - DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, + retcode = ReturnCode.Success, + original = res, + resid = res.objective) + SciMLBase.has_analytic(prob.f) && + SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true, dense_errors = false) sol end diff --git a/src/ode_solve.jl b/src/ode_solve.jl index e10793647..64d7b3ac6 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -1,4 +1,4 @@ -abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end +abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end """ NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...) @@ -14,10 +14,10 @@ of the physics-informed neural network which is used as a solver for a standard ## Positional Arguments -* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`. +* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`. `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`. * `opt`: The optimizer to train the neural network. -* `init_params`: The initial parameter of the neural network. By default, this is `nothing` +* `init_params`: The initial parameter of the neural network. By default, this is `nothing` which thus uses the random initialization provided by the neural network library. ## Keyword Arguments @@ -28,8 +28,8 @@ of the physics-informed neural network which is used as a solver for a standard automatic differentiation (via Zygote), this is only for the derivative in the loss function (the derivative with respect to time). * `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values - `t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data. - `false` means which means the application of the neural network is done at individual time points one at a time. + `t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data. + `false` means which means the application of the neural network is done at individual time points one at a time. This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand. * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network. * `strategy`: The training strategy used to choose the points for the evaluations. @@ -339,7 +339,7 @@ end SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation" SciMLBase.allowscomplex(::NNODE) = true -function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, +function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, alg::NNODE, args...; dt = nothing, @@ -479,13 +479,15 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, u = [phi(t, res.u) for t in ts] end - sol = DiffEqBase.build_solution(prob, alg, ts, u; + sol = SciMLBase.build_solution(prob, alg, ts, u; k = res, dense = true, interp = NNODEInterpolation(phi, res.u), calculate_error = false, - retcode = ReturnCode.Success) - DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, + retcode = ReturnCode.Success, + original = res, + resid = res.objective) + SciMLBase.has_analytic(prob.f) && + SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true, dense_errors = false) sol end #solve diff --git a/src/rode_solve.jl b/src/rode_solve.jl index 85b496880..863a0d1be 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -20,7 +20,7 @@ function NNRODE(chain, W, opt = Optim.BFGS(), init_params = nothing; autodiff = NNRODE(chain, W, opt, init_params, autodiff, kwargs) end -function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, +function SciMLBase.solve(prob::SciMLBase.AbstractRODEProblem, alg::NeuralPDEAlgorithm, args...; dt, @@ -30,7 +30,7 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, abstol = 1.0f-6, verbose = false, maxiters = 100) - DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!") + SciMLBase.isinplace(prob) && error("Only out-of-place methods are allowed!") u0 = prob.u0 tspan = prob.tspan @@ -52,12 +52,12 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, if u0 isa Number phi = (t, W, θ) -> u0 + (t - tspan[1]) * - first(chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), + first(chain(adapt(SciMLBase.parameterless_type(θ), [t, W]), θ)) else phi = (t, W, θ) -> u0 + (t - tspan[1]) * - chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), θ) + chain(adapt(SciMLBase.parameterless_type(θ), [t, W]), θ) end else _, re = Flux.destructure(chain) @@ -65,11 +65,11 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, if u0 isa Number phi = (t, W, θ) -> u0 + (t - t0) * - first(re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W]))) + first(re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W]))) else phi = (t, W, θ) -> u0 + (t - t0) * - re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W])) + re(θ)(adapt(SciMLBase.parameterless_type(θ), [t, W])) end end @@ -108,9 +108,9 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)] end - sol = DiffEqBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false) - DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, + sol = SciMLBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false) + SciMLBase.has_analytic(prob.f) && + SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true, dense_errors = false) sol end #solve