From e253a4dff7f2aade5ab1d828f775689ad345f7b1 Mon Sep 17 00:00:00 2001
From: Samedh Desai
Date: Thu, 17 Aug 2023 12:32:37 -0400
Subject: [PATCH] Initial commit for GivenPointsTraining

---
 src/NeuralPDE.jl                   |  2 +-
 src/ode_solve.jl                   | 14 +++++++++
 src/training_strategies.jl         | 14 +++++++++
 test/dae_problem_test.jl           | 48 ------------------------------
 test/given_points_training_test.jl | 33 ++++++++++++++++++++
 5 files changed, 62 insertions(+), 49 deletions(-)
 delete mode 100644 test/dae_problem_test.jl
 create mode 100644 test/given_points_training_test.jl

diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 81a9ad15b..e661dc5a9 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -56,7 +56,7 @@ export NNODE, NNDAE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        KolmogorovParamDomain, NNParamKolmogorov,
        PhysicsInformedNN, discretize,
        GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
-       WeightedIntervalTraining,
+       WeightedIntervalTraining, GivenPointsTraining,
        build_loss_function, get_loss_function,
        generate_training_sets, get_variables, get_argument, get_bounds,
        get_phi, get_numeric_derivative, get_numeric_integral,
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 82f3bb029..90e95fc2a 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -382,6 +382,20 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, batch)
     return loss
 end
 
+function generate_loss(strategy::GivenPointsTraining, phi, f, autodiff::Bool, tspan, p, batch)
+    ts = strategy.given_points
+
+    # sum(abs2, inner_loss(t, θ) for t in ts), but Zygote generators are broken
+    function loss(θ, _)
+        if batch
+            sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
+        else
+            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
+        end
+    end
+    return loss
+end
+
 function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
     error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
 end
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index 6c6dacbb7..46ab6d6dd 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -332,3 +332,17 @@ function get_loss_function(loss_function, train_set, eltypeθ,
                            τ = nothing)
     loss = (θ) -> mean(abs2, loss_function(train_set, θ))
 end
+
+struct GivenPointsTraining{T} <: AbstractTrainingStrategy
+    given_points::Vector{T}
+end
+
+function GivenPointsTraining(given_points::AbstractVector{T}) where {T}
+    GivenPointsTraining{T}(collect(given_points))
+end
+
+function get_loss_function(loss_function, train_set, eltypeθ,
+                           strategy::GivenPointsTraining;
+                           τ = nothing)
+    loss = (θ) -> mean(abs2, loss_function(train_set, θ))
+end
diff --git a/test/dae_problem_test.jl b/test/dae_problem_test.jl
deleted file mode 100644
index ef9c64e28..000000000
--- a/test/dae_problem_test.jl
+++ /dev/null
@@ -1,48 +0,0 @@
-using Optimisers, OptimizationOptimisers, Sundials
-using Lux, Test, Statistics, Plots
-
-function fu(yp, y, p, tres)
-    [-0.04 * y[1] + 1.0e4 * y[2] * y[3] - yp[1],
-     -(-0.04 * y[1] + 1.0e4 * y[2] * y[3]) - 3.0e7 * y[2] * y[2] - yp[2],
-     y[1] + y[2] + y[3] - 1.0]
-end
-u0 = [1.0, 0, 0]
-du0 = [-0.04, 0.04, 0.0]
-p = [1.5, 1.0, 3.0, 1.0]
-
-"""
-The Robertson biochemical reactions in DAE form
-
-```math
-\frac{dy₁}{dt} = -k₁y₁+k₃y₂y₃
-```
-```math
-\frac{dy₂}{dt} = k₁y₁-k₂y₂^2-k₃y₂y₃
-```
-```math
-1 = y₁ + y₂ + y₃
-```
-where ``k₁=0.04``, ``k₂=3\times10^7``, ``k₃=10^4``. For details, see:
-Hairer Norsett Wanner Solving Ordinary Differential Equations I - Nonstiff Problems Page 129
-Usually solved on ``[0,1e11]``
-"""
-
-prob_oop = DAEProblem{false}(fu, du0, u0, (0.0, 100000.0), p)
-true_sol = solve(prob_oop, IDA(), saveat = 0.01)
-
-func = Lux.σ
-N = 12
-chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, length(u0)))
-
-opt = Optimisers.Adam(0.01)
-dx = 0.05
-alg = NeuralPDE.NNDAE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
-sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)
-
-# println(abs(mean(true_sol .- sol)))
-
-# using Plots
-
-# plot(sol)
-# plot!(true_sol)
-# # ylims!(0,8)
\ No newline at end of file
diff --git a/test/given_points_training_test.jl b/test/given_points_training_test.jl
new file mode 100644
index 000000000..95ef265e5
--- /dev/null
+++ b/test/given_points_training_test.jl
@@ -0,0 +1,33 @@
+using NeuralPDE
+using OrdinaryDiffEq, OptimizationPolyalgorithms, Lux, OptimizationOptimJL, Test, Statistics, Plots, Optimisers
+
+function fu(u, p, t)
+    [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
+end
+
+p = [1.5, 1.0, 3.0, 1.0]
+u0 = [1.0, 1.0]
+tspan = (0.0, 3.0)
+points1 = [rand() for i in 1:140]
+points2 = [rand() + 1 for i in 1:40]
+points3 = [rand() + 2 for i in 1:20]
+points = vcat(points1, points2, points3)
+
+prob_oop = ODEProblem{false}(fu, u0, tspan, p)
+true_sol = solve(prob_oop, Tsit5(), saveat = 0.01)
+func = Lux.σ
+N = 12
+chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
+                  Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))
+
+opt = Optimisers.Adam(0.01)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.GivenPointsTraining(points))
+sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)
+
+@test abs(mean(sol) - mean(true_sol)) < 0.2
+
+using Plots
+
+plot(sol)
+plot!(true_sol)
+ylims!(0, 8)
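
Reviewer note, not part of the patch: a minimal usage sketch of the new strategy, mirroring the added test. Only `GivenPointsTraining` and its interaction with `NNODE` come from this diff; the problem, network size, and optimiser settings below are illustrative assumptions.

```julia
# Sketch: train NNODE on Lotka-Volterra with user-chosen residual points,
# sampling densely near t = 0 where the dynamics change fastest.
using NeuralPDE, OrdinaryDiffEq, Lux, Optimisers

f(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2],
              -p[3] * u[2] + p[4] * u[1] * u[2]]
prob = ODEProblem{false}(f, [1.0, 1.0], (0.0, 3.0), [1.5, 1.0, 3.0, 1.0])

# 140 points in [0, 1), 40 in [1, 2), 20 in [2, 3): the non-uniform
# density is what GivenPointsTraining offers over GridTraining.
points = vcat(rand(140), rand(40) .+ 1, rand(20) .+ 2)

chain = Lux.Chain(Lux.Dense(1, 12, Lux.σ), Lux.Dense(12, 2))
alg = NNODE(chain, Optimisers.Adam(0.01);
            strategy = GivenPointsTraining(points))  # hyperparameters illustrative
sol = solve(prob, alg; verbose = true, maxiters = 2000, saveat = 0.01)
```

With the typed convenience constructor, `GivenPointsTraining` also accepts other `AbstractVector`s, e.g. `GivenPointsTraining(0.0:0.05:3.0)` recovers a uniform grid.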