Skip to content

Commit

Permalink
update multiple parameters task
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jun 18, 2024
1 parent 188ceec commit eb005c6
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 61 deletions.
102 changes: 65 additions & 37 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,26 @@ end
function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ, prob::ODEProblem) where {C <: DeepONet, T}
p, t = x
f = prob.f
# if any(in(keys(bounds)), (:u0,))
# branch_left, branch_right = f.(0, zeros(size(p)), t .+ sqrt(eps(eltype(p)))),
# f.(0, zeros(size(p)), t)
# else
branch_left, branch_right = f.(0, p, t .+ sqrt(eps(eltype(p)))), f.(0, p, t)
# end
branch_left, branch_right = f.(0, p, t .+ sqrt(eps(eltype(p)))), f.(0, p, t)
trunk_left, trunk_right = t .+ sqrt(eps(eltype(t))), t
x_left = (branch = branch_left, trunk = trunk_left)
x_right = (branch = branch_right, trunk = trunk_right)
(phi(x_left, θ) .- phi(x_right, θ)) / sqrt(eps(eltype(t)))
end

# function physics_loss(
# phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
# p, t = x
# f = prob.f
# du = vec(dfdx(phi, x, θ, prob))
# f_ = f.(0, p, t)
# tuple = (branch = f_, trunk = t)
# out = phi(tuple, θ)
# f_ = vec(f.(out, p, t))
# norm = prod(size(out))
# sum(abs2, du .- f_) / norm
# end

function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
p, t = x
Expand All @@ -101,50 +109,69 @@ function physics_loss(
# tuple = (branch = p, trunk = t)
# out = phi(tuple, θ)
# f.(out, p, t) ?
#TODO if DeepONet else err
du = vec(dfdx(phi, x, θ, prob))
# if any(in(keys(bounds)), (:u0,))
# f_ = f.(0, zeros(size(p)), t)
# else
f_ = f.(0, p, t)
# end
#TODO if DeepONet else error
tuple = (branch = f_, trunk = t)
fs_ = hcat([hcat([f(0, p[:, i, :], t[j]) for i in axes(p, 2)]) for j in axes(t, 3)]...)
fs = reshape(fs_, (1, size(fs_)...))
tuple = (branch = p, trunk = t)
out = phi(tuple, θ)
f_ = vec(f.(out, p, t))

tuple = (branch = fs, trunk = t) #TODO -> tuple = (branch = p, trunk = t)
out = phi(tuple, θ)
# f_ = vec(f.(out, p, t))
norm = prod(size(out))
sum(abs2, du .- f_) / norm
end

function inital_condition_loss(
# function initial_condition_loss(
# phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
# p, t = x
# f = prob.f
# t0 = t[:, :, [1]]
# f_0 = f.(0, p, t0)
# tuple = (branch = f_0, trunk = t0)
# out = phi(tuple, θ)
# u = vec(out)
# u0_ = fill(prob.u0, size(out))
# u0 = vec(u0_)

# norm = prod(size(u0))
# sum(abs2, u .- u0) / norm
# end

function initial_condition_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
p, t = x
f = prob.f
t0 = t[:, :, [1]]
#TODO
# if any(in(keys(bounds)), (:u0,))
# u0_ = collect(p)
# u0 = vec(u0_)
# #TODO f_0 = f.(0, p, t0) and p call as u0
# f_0 = f.(0, zeros(size(u0_)), t0)
# else
u0_ = fill(prob.u0, size(out))
u0 = vec(u0_)
f_0 = f.(0, p, t0)
# end
tuple = (branch = f_0, trunk = t0)
t0 = t[:, :, [5]]
fs_0 = hcat([f(0, p[:, i, :], t[j]) for i in axes(p, 2)]...)
fs0 = reshape(fs_0, (1, size(fs_0)...))
tuple = (branch = fs0, trunk = t0)
out = phi(tuple, θ)
u = vec(out)
u0_ = fill(prob.u0, size(out))
u0 = vec(u0_)

norm = prod(size(u0))
sum(abs2, u .- u0) / norm
end

# function get_trainset(strategy::GridTraining, bounds, tspan)
# db, dt = strategy.dx
# v = values(bounds)[1]
# #TODO for all v
# p_ = v[1]:db:v[2]
# p = reshape(p_, 1, size(p_)[1], 1)
# t_ = collect(tspan[1]:dt:tspan[2])
# t = reshape(t_, 1, 1, size(t_)[1])
# (p, t)
# end

function get_trainset(strategy::GridTraining, bounds, tspan)
db, dt = strategy.dx
v = values(bounds)[1]
#TODO for all v
p_ = v[1]:db:v[2]
p = reshape(p_, 1, size(p_)[1], 1)
dt = strategy.dx
size_of_p = 50
p_ = [range(start = b[1], length = size_of_p, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1], 1)) for p_i in p_]...)
t_ = collect(tspan[1]:dt:tspan[2])
t = reshape(t_, 1, 1, size(t_)[1])
(p, t)
Expand Down Expand Up @@ -174,9 +201,10 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
# if !any(in(keys(bounds)), (:p, :u0))
# error("bounds should contain p only")
# end
if !any(in(keys(bounds)), (:p,))
error("bounds should contain p only")
end
#TODO new p
# if !any(in(keys(bounds)), (:p,))
# error("bounds should contain p only")
# end
phi, init_params = generate_pino_phi_θ(chain, init_params)

isinplace(prob) &&
Expand All @@ -194,7 +222,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
end

if strategy isa GridTraining
if length(strategy.dx) !== 2
if length(strategy.dx) !== 2 #TODO ?
throw(ArgumentError("The discretization should have two elements dx= [db,dt],
steps for branch and trunk bounds"))
end
Expand Down
57 changes: 33 additions & 24 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using NeuralPDE
c = deeponet(x, θ, st)[1]

bounds = (p = [0.1f0, pi],)
# bounds = [0.1f0, pi]
db = (bounds.p[2] - bounds.p[1]) / 50
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining([db, dt])
Expand Down Expand Up @@ -141,25 +142,14 @@ end
@test ground_solutionsol.u rtol=0.005
end

plot(sol.u[1, :, :], linetype = :contourf)
plot!(ground_solution[1, :, :], linetype = :contourf)

function plot_()
# Animate
anim = @animate for (i) in 1:41
plot(ground_solution[1, i, :], label = "Ground")
# plot(equation_[1, i, :], label = "equation")
plot!(sol.u[1, i, :], label = "Predicted")
end
gif(anim, "pino.gif", fps = 15)
end

plot_()

#vector outputs and multiple parameters
@testset "Example du = cos(p * t)" begin
# equation = (u, p, t) -> cos(p1 * t) + p2
equation = (u, p, t) -> cos(p[1] * t) + p[2]
function equation1(u, p, t)
p1, p2 = p[1], p[2]
cos(p1 * t) + p2
end

equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)
Expand All @@ -173,20 +163,39 @@ plot_()
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = NeuralPDE.DeepONet(branch, trunk; linear = nothing)

bounds = (p1 = [0.1f0, pi], p2 = [0.1f0, 2.0f0])
db = (bounds.u0[2] - bounds.u0[1]) / 50
deeponet = DeepONet(branch, trunk; linear = nothing)
# p1 = [0.1f0, pi]; p2 = [0.1f0, 2.0f0]
# bounds = (p = [p1, p2],)
#TODO add size_of_p = 50
bounds = [[0.1f0, pi], [0.1f0, 2.0f0]]
# db = 0.025f0
dt = (tspan[2] - tspan[1]) / 40
strategy = NeuralPDE.GridTraining([db, dt])
strategy = GridTraining(dt)
opt = OptimizationOptimisers.Adam(0.03)
alg = NeuralPDE.PINOODE(deeponet, opt, bounds; strategy = strategy)
alg = PINOODE(deeponet, opt, bounds; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 2000)
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)

ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p_ = bounds.p[1]:strategy.dx[1]:bounds.p[2]
p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

@test ground_solutionsol.u rtol=0.01
end



plot(sol.u[1, :, :], linetype = :contourf)
plot!(ground_solution[1, :, :], linetype = :contourf)

function plot_()
# Animate
anim = @animate for (i) in 1:41
plot(ground_solution[1, i, :], label = "Ground")
# plot(equation_[1, i, :], label = "equation")
plot!(sol.u[1, i, :], label = "Predicted")
end
gif(anim, "pino.gif", fps = 15)
end

plot_()

0 comments on commit eb005c6

Please sign in to comment.