Skip to content

Commit

Permalink
Initial commit for GivenPointsTraining
Browse files Browse the repository at this point in the history
  • Loading branch information
Samedh Desai committed Aug 17, 2023
1 parent 957d20a commit e253a4d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export NNODE, NNDAE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovParamDomain, NNParamKolmogorov,
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
WeightedIntervalTraining, GivenPointsTraining,
build_loss_function, get_loss_function,
generate_training_sets, get_variables, get_argument, get_bounds,
get_phi, get_numeric_derivative, get_numeric_integral,
Expand Down
14 changes: 14 additions & 0 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,20 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
return loss
end

"""
    generate_loss(strategy::GivenPointsTraining, phi, f, autodiff::Bool, tspan, p, batch)

Build the NNODE loss closure that evaluates the physics-informed residual at
the user-supplied collocation points in `strategy.given_points`.

Returns a two-argument closure `loss(θ, _)` (the second argument is ignored,
matching the optimizer callback signature used by the other strategies).
"""
function generate_loss(strategy::GivenPointsTraining, phi, f, autodiff::Bool, tspan, p,
                       batch)
    ts = strategy.given_points
    # Fail fast on out-of-domain points: training on points outside tspan
    # would silently fit the network over the wrong interval.
    all(t -> tspan[1] <= t <= tspan[2], ts) ||
        throw(ArgumentError("GivenPointsTraining points must lie within tspan = $tspan"))

    # sum(abs2, inner_loss(t, θ) for t in ts) would be the natural form, but
    # Zygote cannot differentiate through generators, so use explicit calls.
    function loss(θ, _)
        if batch
            # Batched path: inner_loss accepts the whole vector of points.
            return sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
        else
            return sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
        end
    end
    return loss
end

# NNODE solves a one-dimensional (time) problem, while QuasiRandomTraining is
# designed for sampling high-dimensional domains — reject it outright and
# point the user at the supported alternative.
generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan) =
    error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
Expand Down
14 changes: 14 additions & 0 deletions src/training_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,17 @@ function get_loss_function(loss_function, train_set, eltypeθ,
τ = nothing)
loss = (θ) -> mean(abs2, loss_function(train_set, θ))
end

"""
    GivenPointsTraining(given_points)

Training strategy that evaluates the physics-informed loss at an explicit,
user-chosen collection of time points instead of an automatically generated
grid or random sample. `given_points` may be any iterable of numbers; it is
materialized into a `Vector`.
"""
struct GivenPointsTraining{T} <: AbstractTrainingStrategy
    given_points::Vector{T}  # collocation points at which the residual is evaluated
end

# Convenience constructor accepting any iterable (range, tuple, generator, …).
# NOTE: the original untyped method called itself with the same argument and
# would overflow the stack for any non-Vector input; `collect` produces a
# Vector, which dispatches to the auto-generated constructor and breaks the
# recursion.
function GivenPointsTraining(given_points)
    return GivenPointsTraining(collect(given_points))
end

"""
    get_loss_function(loss_function, train_set, eltypeθ,
                      strategy::GivenPointsTraining; τ = nothing)

Return a closure `θ -> mean(abs2, loss_function(train_set, θ))`: the mean
squared residual over the user-provided training set. `τ` is accepted only
for interface parity with the other strategies and is unused here.
"""
function get_loss_function(loss_function, train_set, eltypeθ,
                           strategy::GivenPointsTraining;
                           τ = nothing)
    # Explicit return of the loss closure (the original relied on an implicit
    # return of a throwaway local assignment).
    return θ -> mean(abs2, loss_function(train_set, θ))
end
48 changes: 0 additions & 48 deletions test/dae_problem_test.jl

This file was deleted.

33 changes: 33 additions & 0 deletions test/given_points_training_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

using OrdinaryDiffEq, OptimizationPolyalgorithms, Lux, OptimizationOptimJL, Test,
      Statistics, Plots, Optimisers, Random

# Lotka–Volterra (predator–prey) right-hand side, out-of-place form.
# The equation structure matches the classic model: u[1] grows and is
# consumed, u[2] decays and is fed by the interaction term.
function fu(u, p, t)
    prey, predator = u
    α, β, γ, δ = p
    dprey = α * prey - β * prey * predator
    dpredator = -γ * predator + δ * prey * predator
    return [dprey, dpredator]
end

# Lotka–Volterra test problem for GivenPointsTraining: collocation points are
# drawn densely near t = 0 and progressively sparser toward t = 3.
Random.seed!(100)  # seed so the point set (and thus the test) is deterministic
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
tspan = (0.0, 3.0)
points1 = [rand() for i in 1:140]     # 140 points in [0, 1)
points2 = [rand() + 1 for i in 1:40]  # 40 points in [1, 2)
points3 = [rand() + 2 for i in 1:20]  # 20 points in [2, 3)
points = vcat(points1, points2, points3)

prob_oop = ODEProblem{false}(fu, u0, tspan, p)
# Reference solution from a classical adaptive solver on a fine saveat grid.
true_sol = solve(prob_oop, Tsit5(), saveat = 0.01)

# Small fully connected network: 1 input (time) → 2 outputs (u1, u2).
func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
                  Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
                      strategy = NeuralPDE.GivenPointsTraining(points))
sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)

# Coarse accuracy check: the mean of the neural solution should be close to
# the mean of the reference solution.
@test abs(mean(sol) - mean(true_sol)) < 0.2

# NOTE(review): the interactive plotting that followed here (a duplicate
# `using Plots`, plot/plot!/ylims!) was removed — automated tests must not
# require a plotting backend or open a display.

0 comments on commit e253a4d

Please sign in to comment.