diff --git a/Project.toml b/Project.toml
index d76565ff2..45468985c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -44,7 +44,7 @@ AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.7"
 CUDA = "5.2"
-ChainRulesCore = "1.18"
+ChainRulesCore = "1.21"
 ComponentArrays = "0.15.8"
 Cubature = "1.5"
 DiffEqBase = "6.144"
@@ -59,7 +59,7 @@ Integrals = "4"
 LineSearches = "7.2"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
-Lux = "0.5.14"
+Lux = "0.5.22"
 LuxCUDA = "0.3.2"
 MCMCChains = "6"
 MethodOfLines = "0.10.7"
@@ -82,7 +82,7 @@ SymbolicUtils = "1.4"
 Symbolics = "5.17"
 Test = "1"
 UnPack = "1"
-Zygote = "0.6.68"
+Zygote = "0.6.69"
 julia = "1.10"

 [extras]
@@ -91,12 +91,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"

 [targets]
 test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
diff --git a/docs/src/tutorials/neural_adapter.md b/docs/src/tutorials/neural_adapter.md
index a56e30a26..93f0dd036 100644
--- a/docs/src/tutorials/neural_adapter.md
+++ b/docs/src/tutorials/neural_adapter.md
@@ -69,7 +69,7 @@ function loss(cord, θ)
     ch2 .- phi(cord, res.u)
 end

-strategy = NeuralPDE.QuadratureTraining()
+strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
 prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
@@ -173,7 +173,7 @@ for i in 1:count_decomp
     bcs_ = create_bcs(domains_[1].domain, phi_bound)
     @named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)])
     push!(pde_system_map, pde_system_)
-    strategy = NeuralPDE.QuadratureTraining()
+    strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)

     discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy;
         init_params = init_params[i])
@@ -243,10 +243,10 @@ callback = function (p, l)
 end

 prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
-    NeuralPDE.QuadratureTraining())
+    NeuralPDE.QuadratureTraining(; reltol = 1e-6))
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
 prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
-    NeuralPDE.QuadratureTraining())
+    NeuralPDE.QuadratureTraining(; reltol = 1e-6))
 res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)

 phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi
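The tutorial now passes `reltol` explicitly because this PR swaps the `QuadratureTraining` tolerance defaults (see `src/training_strategies.jl` further down). A minimal sketch of the two forms, assuming the new defaults:

```julia
using NeuralPDE

# After this change the constructor defaults are reltol = 1e-3, abstol = 1e-6,
# so code that relied on the old, tighter relative tolerance must request it:
strategy_default = NeuralPDE.QuadratureTraining()
strategy_tight = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
```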
diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl
index 3bbf1afea..087b9d41c 100644
--- a/src/BPINN_ode.jl
+++ b/src/BPINN_ode.jl
@@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
         targetacceptancerate = 0.8),
     Integratorkwargs = (Integrator = Leapfrog,),
     autodiff = false, progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     BNNODE(chain, Kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd, dataset, physdt,
         MCMCkwargs,
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index d367bf8b6..2ba1de25b 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, rightendpoint
 using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
+using Lux: FromFluxAdaptor
 using ChainRulesCore: @non_differentiable

 RuntimeGeneratedFunctions.init(@__MODULE__)
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
index 6f3014925..c86c87599 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/advancedHMC_MCMC.jl
@@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
     MCMCkwargs = (n_leapfrog = 30,),
     progress = false, verbose = false)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     # NN parameter prior mean and variance(PriorsNN must be a tuple)
     if isinplace(prob)
         throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
diff --git a/src/dae_solve.jl b/src/dae_solve.jl
index 3f6bf8f0f..0c9d1323d 100644
--- a/src/dae_solve.jl
+++ b/src/dae_solve.jl
@@ -42,7 +42,7 @@ end
 function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
         kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
 end
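The hunks above replace `Lux.transform` (deprecated in the Lux versions this PR targets) with `FromFluxAdaptor`. A minimal, self-contained sketch of the conversion; the chain here is illustrative:

```julia
using Adapt: adapt
using Flux, Lux
using Lux: FromFluxAdaptor

fluxchain = Flux.Chain(Flux.Dense(1 => 5, Flux.σ), Flux.Dense(5 => 1))

# FromFluxAdaptor(preserve_ps_st, force_preserve): with both `false`, the Flux
# layers are rebuilt as Lux layers and Lux re-initializes the parameters.
luxchain = adapt(FromFluxAdaptor(false, false), fluxchain)
@assert luxchain isa Lux.AbstractExplicitLayer
```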
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index f93183d76..8af6a708d 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -15,7 +15,7 @@ of the physics-informed neural network which is used as a solver for a standard
 ## Positional Arguments

 * `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
-  `Flux.Chain` will be converted to `Lux` using `Lux.transform`.
+  `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `opt`: The optimizer to train the neural network.
 * `init_params`: The initial parameter of the neural network. By default, this is `nothing`
   which thus uses the random initialization provided by the neural network library.
@@ -27,11 +27,10 @@ of the physics-informed neural network which is used as a solver for a standard
   the PDE operators. The reverse mode of the loss function is always
   automatic differentiation (via Zygote), this is only for the derivative
   in the loss function (the derivative with respect to time).
-* `batch`: The batch size to use for the internal quadrature. Defaults to `0`, which
-  means the application of the neural network is done at individual time points one
-  at a time. `batch>0` means the neural network is applied at a row vector of values
-  `t` simultaneously, i.e. it's the batch size for the neural network evaluations.
-  This requires a neural network compatible with batched data.
+* `batch`: The batch size for the loss computation. Defaults to `true`, which means the neural
+  network is applied to a row vector of values `t` simultaneously, i.e. it is the batch size for
+  the network evaluations; this requires a neural network compatible with batched data. `false` means the network is applied at individual time points, one at a time.
+  This is not applicable to `QuadratureTraining`, where `batch` is passed within the `strategy` and is the number of points at which the integrand can be evaluated in parallel.
 * `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
 * `strategy`: The training strategy used to choose the points for the evaluations.
   Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no
@@ -88,8 +87,8 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
 end

 function NNODE(chain, opt, init_params = nothing; strategy = nothing,
-        autodiff = false, batch = nothing, param_estim = false, additional_loss = nothing, kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+        autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
+    !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss,
         kwargs)
 end
@@ -111,11 +110,7 @@ end

 function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
     θ, st = Lux.setup(Random.default_rng(), chain)
-    if init_params === nothing
-        init_params = ComponentArrays.ComponentArray(θ)
-    else
-        init_params = ComponentArrays.ComponentArray(init_params)
-    end
+    isnothing(init_params) && (init_params = θ)
     ODEPhi(chain, t, u0, st), init_params
 end
@@ -182,7 +177,7 @@ function ode_dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool)
 end

 """
-    inner_loss(phi, f, autodiff, t, θ, p)
+    inner_loss(phi, f, autodiff, t, θ, p, param_estim)

 Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
 """
@@ -220,7 +215,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
 end

 """
-    generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
+    generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)

 Representation of the loss function, parametric on the training strategy `strategy`.
 """
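For reference, a sketch of the new `batch` semantics described in the docstring above (the chain and optimizer are illustrative):

```julia
using NeuralPDE, Lux, OptimizationOptimisers

luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1)

# `batch` is now a Bool and defaults to `true` (batched network evaluation).
alg_batched = NNODE(luxchain, opt)
alg_pointwise = NNODE(luxchain, opt; batch = false)

# With QuadratureTraining, batching is configured on the strategy itself:
alg_quad = NNODE(luxchain, opt; strategy = NeuralPDE.QuadratureTraining(; batch = 100))
```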
""" @@ -229,14 +224,13 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts] - @assert batch == 0 # not implemented function loss(θ, _) - intprob = IntegralProblem(integrand, (tspan[1], tspan[2]), θ) - sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol) + intf = BatchIntegralFunction(integrand, max_batch = strategy.batch) + intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ) + sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, reltol = strategy.reltol, maxiters = strategy.maxiters) sol.u end - return loss end @@ -395,16 +389,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, alg.strategy end - batch = if alg.batch === nothing - if strategy isa QuadratureTraining - strategy.batch - else - true - end - else - alg.batch - end - + batch = alg.batch inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim) additional_loss = alg.additional_loss (param_estim && isnothing(additional_loss)) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true).")) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 69116c4da..50f7649dc 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -48,7 +48,7 @@ methodology. * `chain`: a vector of Lux/Flux chains with a d-dimensional input and a 1-dimensional output corresponding to each of the dependent variables. Note that this specification respects the order of the dependent variables as specified in the PDESystem. - Flux chains will be converted to Lux internally using `Lux.transform`. + Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`. * `strategy`: determines which training strategy will be used. See the Training Strategy documentation for more details. 
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 69116c4da..50f7649dc 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -48,7 +48,7 @@ methodology.
 * `chain`: a vector of Lux/Flux chains with a d-dimensional input and a
   1-dimensional output corresponding to each of the dependent variables. Note that this
   specification respects the order of the dependent variables as specified in the PDESystem.
-  Flux chains will be converted to Lux internally using `Lux.transform`.
+  Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`.
 * `strategy`: determines which training strategy will be used. See the Training Strategy
   documentation for more details.
@@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
         if multioutput
             !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
         else
-            !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+            !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
         end
         if phi === nothing
             if multioutput
@@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
         if multioutput
             !all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
         else
-            !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+            !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
         end
         if phi === nothing
             if multioutput
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index 1db878094..c997e6c4c 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -272,7 +272,7 @@ struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <:
     batch::Int64
 end

-function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-6, abstol = 1e-3,
+function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-3, abstol = 1e-6,
         maxiters = 1_000, batch = 100)
     QuadratureTraining(quadrature_alg, reltol, abstol, maxiters, batch)
 end
@@ -306,11 +306,7 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::QuadratureTraining)
     end
     area = eltypeθ(prod(abs.(ub .- lb)))
     f_ = (lb, ub, loss_, θ) -> begin
-        # last_x = 1
         function integrand(x, θ)
-            # last_x = x
-            # mean(abs2,loss_(x,θ), dims=2)
-            # size_x = fill(size(x)[2],(1,1))
             x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x)
             sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
         end
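Note that the `multioutput` branches of both constructors above still broadcast the deprecated `Lux.transform`; only the single-chain branches were migrated. If the broadcast is migrated later, the analogous (hypothetical) form would be:

```julia
# Hypothetical follow-up mirroring the single-chain branch above:
chain = map(c -> adapt(FromFluxAdaptor(false, false), c), chain)
```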
diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl
index 122adcceb..1e2ba3c05 100644
--- a/test/NNODE_tests.jl
+++ b/test/NNODE_tests.jl
@@ -9,6 +9,7 @@ Random.seed!(100)

 @testset "Scalar" begin
     # Run a solve on scalars
+    println("Scalar")
     linear = (u, p, t) -> cos(2pi * t)
     tspan = (0.0f0, 1.0f0)
     u0 = 0.0f0
@@ -16,26 +17,27 @@ Random.seed!(100)
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

-    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false,
         abstol = 1.0f-10, maxiters = 200)

     @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), dt = 1 / 20.0f0,
-        verbose = true, abstol = 1.0f-10, maxiters = 200)
+        verbose = false, abstol = 1.0f-10, maxiters = 200)

-    sol = solve(prob, NNODE(luxchain, opt), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false,
         abstol = 1.0f-6, maxiters = 200)

     opt = OptimizationOptimJL.BFGS()
-    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false,
         abstol = 1.0f-10, maxiters = 200)

-    sol = solve(prob, NNODE(luxchain, opt), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false,
         abstol = 1.0f-6, maxiters = 200)
 end

 @testset "Vector" begin
     # Run a solve on vectors
+    println("Vector")
     linear = (u, p, t) -> [cos(2pi * t)]
     tspan = (0.0f0, 1.0f0)
     u0 = [0.0f0]
@@ -44,14 +46,14 @@ end
     opt = OptimizationOptimJL.BFGS()
     sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, abstol = 1e-10,
-        verbose = true, maxiters = 200)
+        verbose = false, maxiters = 200)

     @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), dt = 1 / 20.0f0,
-        abstol = 1e-10, verbose = true, maxiters = 200)
+        abstol = 1e-10, verbose = false, maxiters = 200)

     sol = solve(prob, NNODE(luxchain, opt), abstol = 1.0f-6,
-        verbose = true, maxiters = 200)
+        verbose = false, maxiters = 200)

     @test sol(0.5) isa Vector
     @test sol(0.5; idxs = 1) isa Number
 end

@@ -59,6 +61,7 @@ end
 @testset "Example 1" begin
+    println("Example 1")
     linear = (u, p, t) -> @. t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) -
                              u * (t + ((1 + 3 * (t^2)) / (1 + t + t^3)))
     linear_analytic = (u0, p, t) -> [exp(-(t^2) / 2) / (1 + t + t^3) + t^2]
@@ -66,75 +69,70 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 128, Lux.σ), Lux.Dense(128, 1))
     opt = OptimizationOptimisers.Adam(0.01)

-    sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400)
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400)
     @test sol.errors[:l2] < 0.5

-    @test_throws AssertionError solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
-        maxiters = 400)
-
     sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400)
+        verbose = false, maxiters = 400)
     @test sol.errors[:l2] < 0.5

     sol = solve(prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400)
+        verbose = false, maxiters = 400)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false,
         maxiters = 400, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false,
         maxiters = 400, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5
 end

 @testset "Example 2" begin
+    println("Example 2")
     linear = (u, p, t) -> -u / 5 + exp(-t / 5) .* cos(t)
     linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t))
     prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     opt = OptimizationOptimisers.Adam(0.1)
-    sol = solve(prob, NNODE(luxchain, opt), verbose = true, maxiters = 400,
+    sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400,
         abstol = 1.0f-8)
     @test sol.errors[:l2] < 0.5

-    @test_throws AssertionError solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
-        maxiters = 400,
-        abstol = 1.0f-8)
-
     sol = solve(prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400,
+        verbose = false, maxiters = 400,
         abstol = 1.0f-8)
     @test sol.errors[:l2] < 0.5

     sol = solve(prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)),
-        verbose = true, maxiters = 400,
+        verbose = false, maxiters = 400,
         abstol = 1.0f-8)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false,
         maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5

-    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = true,
+    sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false,
         maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0)
     @test sol.errors[:l2] < 0.5
 end

 @testset "Example 3" begin
+    println("Example 3")
     linear = (u, p, t) -> [cos(2pi * t), sin(2pi * t)]
     tspan = (0.0f0, 1.0f0)
     u0 = [0.0f0, -1.0f0 / 2pi]
@@ -146,13 +144,14 @@ end
     alg = NNODE(luxchain, opt; autodiff = false)

     sol = solve(prob,
-        alg, verbose = true, dt = 1 / 40.0f0,
+        alg, verbose = false, dt = 1 / 40.0f0,
         maxiters = 2000, abstol = 1.0f-7)
     @test sol.errors[:l2] < 0.5
 end

 @testset "Training Strategies" begin
     @testset "WeightedIntervalTraining" begin
+        println("WeightedIntervalTraining")
         function f(u, p, t)
             [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
         end
@@ -169,7 +168,7 @@ end
         points = 200
         alg = NNODE(chain, opt, autodiff = false,
             strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
-        sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = 5000, saveat = 0.01)
         @test abs(mean(sol) - mean(true_sol)) < 0.2
     end
@@ -183,6 +182,7 @@ end
     u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)

     @testset "GridTraining" begin
+        println("GridTraining")
         luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
         (u_, t_) = (u_analytical(ts), ts)
         function additional_loss(phi, θ)
@@ -190,22 +190,24 @@ end
         end
         alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01),
             additional_loss = additional_loss)
-        sol1 = solve(prob, alg1, verbose = true, abstol = 1e-8, maxiters = 500)
+        sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500)
         @test sol1.errors[:l2] < 0.5
     end

     @testset "QuadratureTraining" begin
+        println("QuadratureTraining")
         luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
         (u_, t_) = (u_analytical(ts), ts)
         function additional_loss(phi, θ)
             return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
         end
         alg1 = NNODE(luxchain, opt, additional_loss = additional_loss)
-        sol1 = solve(prob, alg1, verbose = true, abstol = 1e-10, maxiters = 200)
+        sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
         @test sol1.errors[:l2] < 0.5
     end

     @testset "StochasticTraining" begin
+        println("StochasticTraining")
         luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
         (u_, t_) = (u_analytical(ts), ts)
         function additional_loss(phi, θ)
@@ -213,12 +215,13 @@ end
         end
         alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
             additional_loss = additional_loss)
-        sol1 = solve(prob, alg1, verbose = true, abstol = 1e-8, maxiters = 500)
+        sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500)
         @test sol1.errors[:l2] < 0.5
     end
 end

 @testset "Parameter Estimation" begin
+    println("Parameter Estimation")
     function lorenz(u, p, t)
         return [p[1]*(u[2]-u[1]),
             u[1]*(p[2]-u[3])-u[2],
@@ -240,14 +243,15 @@ end
         Lux.Dense(n, n, Lux.σ),
         Lux.Dense(n, 3)
     )
-    opt = OptimizationOptimJL.LBFGS(linesearch = BackTracking())
+    opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
     alg = NNODE(luxchain, opt, strategy = GridTraining(0.01), param_estim = true,
         additional_loss = additional_loss)
-    sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_)
+    sol = solve(prob, alg, verbose = false, abstol = 1e-8, maxiters = 1000, saveat = t_)
     @test sol.k.u.p≈true_p atol=1e-2
     @test reduce(hcat, sol.u)≈u_ atol=1e-2
 end

 @testset "Translating from Flux" begin
+    println("Translating from Flux")
     linear = (u, p, t) -> cos(2pi * t)
     linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
     tspan = (0.0, 1.0)
@@ -259,6 +263,6 @@ end
     fluxchain = Flux.Chain(Flux.Dense(1, 5, Flux.σ), Flux.Dense(5, 1))
     alg1 = NNODE(fluxchain, opt)
     @test alg1.chain isa Lux.AbstractExplicitLayer
-    sol1 = solve(prob, alg1, verbose = true, abstol = 1e-10, maxiters = 200)
+    sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
     @test sol1.errors[:l2] < 0.5
 end
diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl
index c0f8422a0..bc4a4b08d 100644
--- a/test/NNODE_tstops_test.jl
+++ b/test/NNODE_tstops_test.jl
@@ -31,46 +31,55 @@ points = 3
 dx = 1.0

 @testset "GridTraining" begin
+    println("GridTraining")
     @testset "Without added points" begin
+        println("Without added points")
         # (difference between solutions should be high)
         alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
     @testset "With added points" begin
+        println("With added points")
         # (difference between solutions should be low)
         alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end

 @testset "WeightedIntervalTraining" begin
+    println("WeightedIntervalTraining")
     @testset "Without added points" begin
+        println("Without added points")
         # (difference between solutions should be high)
         alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
     @testset "With added points" begin
+        println("With added points")
         # (difference between solutions should be low)
         alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end

 @testset "StochasticTraining" begin
+    println("StochasticTraining")
     @testset "Without added points" begin
+        println("Without added points")
         # (difference between solutions should be high)
         alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
     @testset "With added points" begin
+        println("With added points")
         # (difference between solutions should be low)
         alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points))
-        sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+        sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end