Skip to content

Commit

Permalink
clear code, rm ParametricFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jun 21, 2024
1 parent 214b178 commit 7d81063
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 124 deletions.
97 changes: 23 additions & 74 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
struct ParametricFunction{}
function_::Union{Nothing, Function}
bounds::Any
end

"""
PINOODE(chain,
OptimizationOptimisers.Adam(0.1),
Expand Down Expand Up @@ -35,12 +30,13 @@ in which will be train the prediction of parametric ODE.
* Sifan Wang "Learning the solution operator of parametric partial differential equations with physics-informed DeepOnets"
* Zongyi Li "Physics-Informed Neural Operator for Learning Partial Differential Equations"
"""
struct PINOODE{C, O, I, S <: Union{Nothing, AbstractTrainingStrategy},
struct PINOODE{C, O, B, I, S <: Union{Nothing, AbstractTrainingStrategy},
AL <: Union{Nothing, Function}, K} <:
SciMLBase.AbstractODEAlgorithm
chain::C
opt::O
parametric_function::ParametricFunction
bounds::B
number_of_parameters::Int
init_params::I
strategy::S
additional_loss::AL
Expand All @@ -49,13 +45,14 @@ end

function PINOODE(chain,
opt,
parametric_function;
bounds,
number_of_parameters;
init_params = nothing,
strategy = nothing,
additional_loss = nothing,
kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
PINOODE(chain, opt, parametric_function, init_params, strategy, additional_loss, kwargs)
PINOODE(chain, opt, bounds,number_of_parameters, init_params, strategy, additional_loss, kwargs)
end

mutable struct PINOPhi{C, S}
Expand All @@ -82,64 +79,34 @@ function (f::PINOPhi{C, T})(x::NamedTuple, θ) where {C <: NeuralOperator, T}
y
end

function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ, prob::ODEProblem) where {C <: DeepONet, T}
# @unpack function_, bounds = parametric_function
# branch_left, branch_right = function_.(p, t), function_.(p, t .+ sqrt(eps(eltype(t))))
pfs, p, t = x
function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: DeepONet, T}
p, t = x
branch_left, branch_right = p, p
trunk_left, trunk_right = t .+ sqrt(eps(eltype(t))), t
x_left = (branch = branch_left, trunk = trunk_left)
x_right = (branch = branch_right, trunk = trunk_right)
(phi(x_left, θ) .- phi(x_right, θ)) ./ sqrt(eps(eltype(t)))
end

# function physics_loss(
# phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
# p, t = x
# f = prob.f
# du = vec(dfdx(phi, x, θ, prob))
# f_ = f.(0, p, t)
# tuple = (branch = f_, trunk = t)
# out = phi(tuple, θ)
# f_ = vec(f.(out, p, t))
# norm = prod(size(out))
# sum(abs2, du .- f_) / norm
# end
# function initial_condition_loss(
# phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
# p, t = x
# f = prob.f
# t0 = t[:, :, [1]]
# f_0 = f.(0, p, t0)
# tuple = (branch = f_0, trunk = t0)
# out = phi(tuple, θ)
# u = vec(out)
# u0_ = fill(prob.u0, size(out))
# u0 = vec(u0_)
# norm = prod(size(u0))
# sum(abs2, u .- u0) / norm
# end

function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
pfs, p, t = x
p, t = x
f = prob.f
tuple = (branch = p, trunk = t)
out = phi(tuple, θ)
if size(p)[1] == 1
fs = f.(out, p, t)
f_vec= vec(fs)
# out_ = vec(out)
else
f_vec = reduce(vcat,[[f(out[i], p[:, i, 1], t[j]) for i in axes(p, 2)] for j in axes(t, 3)])
end
du = vec(dfdx(phi, x, θ, prob))
du = vec(dfdx(phi, x, θ))
norm = prod(size(du))
sum(abs2, du .- f_vec) / norm
end

function initial_condition_loss(phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
pfs, p, t = x
p, t = x
t0 = t[:, :, [1]]
# pfs0 = pfs[:, :, [1]]
tuple = (branch = p, trunk = t0)
Expand All @@ -150,40 +117,26 @@ function initial_condition_loss(phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) whe
sum(abs2, u .- u0) / norm
end

# function get_trainset(strategy::GridTraining, bounds, tspan)
# db, dt = strategy.dx
# v = values(bounds)[1]
# #TODO for all v
# p_ = v[1]:db:v[2]
# p = reshape(p_, 1, size(p_)[1], 1)
# t_ = collect(tspan[1]:dt:tspan[2])
# t = reshape(t_, 1, 1, size(t_)[1])
# (p, t)
# end

function get_trainset(
strategy::GridTraining, parametric_function::ParametricFunction, tspan)
@unpack function_, bounds = parametric_function
function get_trainset(strategy::GridTraining, bounds, number_of_parameters, tspan)
dt = strategy.dx
#TODO
size_of_p = 50
if bounds isa Tuple
p_ = range(start = bounds[1], length = size_of_p, stop = bounds[2])
if size(bounds)[1] == 1
bound = bounds[1]
p_ = range(start = bound[1], length = number_of_parameters, stop = bound[2])
p = collect(reshape(p_, 1, size(p_)[1], 1))
else
p_ = [range(start = b[1], length = size_of_p, stop = b[2]) for b in bounds]
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2])
for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1], 1)) for p_i in p_]...)
end

t_ = collect(tspan[1]:dt:tspan[2])
t = reshape(t_, 1, 1, size(t_)[1])
pfs = function_.(p,t)
(pfs, p, t)
(p, t)
end

function generate_loss(
strategy::GridTraining, prob::ODEProblem, phi, parametric_function::ParametricFunction, tspan)
x = get_trainset(strategy, parametric_function, tspan)
strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan)
x = get_trainset(strategy, bounds, number_of_parameters, tspan)
function loss(θ, _)
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
Expand All @@ -198,7 +151,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
saveat = nothing,
maxiters = nothing)
@unpack tspan, u0, f = prob
@unpack chain, opt, parametric_function, init_params, strategy, additional_loss = alg
@unpack chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss=alg

if !isa(chain, DeepONet)
error("Only DeepONet neural networks are supported")
Expand Down Expand Up @@ -231,7 +184,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
throw(ArgumentError("Only GridTraining strategy is supported"))
end

inner_f = generate_loss(strategy, prob, phi, parametric_function, tspan)
inner_f = generate_loss(strategy, prob, phi, bounds, number_of_parameters, tspan)

function total_loss(θ, _)
L2_loss = inner_f(θ, nothing)
Expand All @@ -241,10 +194,6 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
L2_loss
end

# TODO delete
# total_loss(θ, 0)
# Zygote.gradient(θ -> total_loss(θ, 0), θ)

# Optimization Algo for Training Strategies
opt_algo = Optimization.AutoZygote()

Expand All @@ -261,7 +210,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
optprob = OptimizationProblem(optf, init_params)
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

pfs, p, t = get_trainset(strategy, parametric_function, tspan)
p, t = get_trainset(strategy, bounds, number_of_parameters, tspan)
tuple = (branch = p, trunk = t)
u = phi(tuple, res.u)

Expand Down
75 changes: 25 additions & 50 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test
using OrdinaryDiffEq, OptimizationOptimisers
using OptimizationOptimisers
using Lux
using Statistics, Random
using NeuralPDE
Expand Down Expand Up @@ -27,24 +27,22 @@ using NeuralPDE
θ, st = Lux.setup(Random.default_rng(), deeponet)

c = deeponet(x, θ, st)[1]
function_(p, t) = cos(p*t)
bounds = (pi, 2pi)
parametric_function = ParametricFunction(function_, bounds)
bounds = [(pi, 2pi)]
number_of_parameters = 50
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, parametric_function; strategy = strategy)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 3000)

phi(tuple, sol.original.u)
sol.original.objective
# TODO intrepretation output another mesh

Check warning on line 39 in test/PINO_ode_tests.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"intrepretation" should be "interpretation".
# x = (branch = p, trunk = t)
# phi(sol.original.u)
# sol.
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
size_of_p = 50
p_ = range(start = bounds[1], length = size_of_p, stop = bounds[2])
#TDOD another number_of_parameters

Check warning on line 44 in test/PINO_ode_tests.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"TDOD" should be "TODO".
p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_)[1], 1))
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

Expand All @@ -66,23 +64,21 @@ end
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = DeepONet(branch, trunk)
function_(p, t) = cos(p * t)
bounds = (0.1f0, 2.f0)
parametric_function = ParametricFunction(function_, bounds)
sol.original.objective
bounds = [(0.1f0, 2.f0)]
number_of_parameters = 40
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)

opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, parametric_function; strategy = strategy)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)

sol = solve(prob, alg, verbose = false, maxiters = 5000)
sol = solve(prob, alg, verbose = false, maxiters = 3000)
sol.original.objective
#if u0 == 1
ground_analytic_(u0, p, t) = (p * sin(p * t) - cos(p * t) + (p^2 + 2) * exp(t)) /
(p^2 + 1)
size_of_p = 50
p_ = range(start = bounds[1], length = size_of_p, stop = bounds[2])

p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_)[1], 1))
ground_solution = ground_analytic_.(u0, p, sol.t.trunk)

Expand All @@ -106,11 +102,8 @@ end
linear = Lux.Chain(Lux.Dense(10, 1))
deeponet = DeepONet(branch, trunk; linear = linear)

function_(p, t) = cos(p * t)
bounds = (0.0f0, 10.0f0)
parametric_function = ParametricFunction(function_, bounds)

# db = (bounds.p[2] - bounds.p[1]) / 50
bounds = [(0.0f0, 10.0f0)]
number_of_parameters = 60
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)

Expand All @@ -133,24 +126,23 @@ end
end

branch_size, trunk_size = 50, 40
p,t = get_trainset(branch_size, trunk_size, bounds, tspan)
data, tuple = get_data()
p,t = get_trainset(branch_size, trunk_size, bounds[1], tspan)
data, tuple_ = get_data()

function additional_loss_(phi, θ)
u = phi(tuple, θ)
u = phi(tuple_, θ)
norm = prod(size(u))
sum(abs2, u .- data) / norm
end
alg = PINOODE(
deeponet, opt, parametric_function; strategy = strategy, additional_loss = additional_loss_)
deeponet, opt, bounds, number_of_parameters; strategy = strategy,
additional_loss = additional_loss_)
sol = solve(prob, alg, verbose = true, maxiters = 2000)

size_of_p = 50
p_ = range(start = bounds[1], length = size_of_p, stop = bounds[2])
p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

@test ground_solutionsol.u rtol=0.005
@test ground_solutionsol.u rtol=0.01
end

#vector outputs and multiple parameters
Expand All @@ -175,38 +167,21 @@ end
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = DeepONet(branch, trunk)

#TODO add size_of_p = 50
function_(p, t) = cos(p * t)
bounds = [(0.1f0, pi), (1.0f0, 2.0f0)]
parametric_function = ParametricFunction(function_, bounds)
number_of_parameters = 50
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(deeponet, opt, parametric_function; strategy = strategy)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 3000)

ga = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t)
p_ = [range(start = b[1], length = size_of_p, stop = b[2]) for b in bounds]
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1], 1)) for p_i in p_]...)
t = sol.t.trunk
ground_solution_ = f_vec = reduce(
hcat, [reduce(
vcat, [ga(u0, p[:, i, 1], t[j]) for i in axes(p, 2)]) for j in axes(t, 3)])
ground_solution = reshape(ground_solution_, 1, size(ground_solution_)...)
@test ground_solutionsol.u rtol=0.01
end

# plot(sol.u[1, :, :], linetype = :contourf)
# plot!(ground_solution[1, :, :], linetype = :contourf)

# function plot_()
# # Animate
# anim = @animate for (i) in 1:41
# plot(ground_solution[1, i, :], label = "Ground")
# # plot(equation_[1, i, :], label = "equation")
# plot!(sol.u[1, i, :], label = "Predicted")
# end
# gif(anim, "pino.gif", fps = 10)
# end

# plot_()

0 comments on commit 7d81063

Please sign in to comment.