Skip to content

Commit

Permalink
fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Mar 21, 2024
1 parent 10169cb commit 2f2be69
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")

export NNODE, NNDAE, PINOODE, TRAINSET
export NNODE, NNDAE, PINOODE, TRAINSET, EquationSolving, OperatorLearning
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
Expand Down
133 changes: 92 additions & 41 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,43 @@ function TRAINSET(input_data, output_data; isu0 = false)
TRAINSET(input_data, output_data, isu0)
end

# Callable wrapper bundling a Lux neural network with the data needed to
# evaluate it as a PINO ansatz for an ODE solution:
#   chain — the Lux.AbstractExplicitLayer network (the learned operator)
#   t0    — initial time of the problem's tspan
#   u0    — initial condition (NOTE(review): may be re-assigned to an array
#           batch in __solve when train_set.isu0 is set — confirm)
#   st    — Lux layer state as returned by Lux.setup
# NOTE(review): declared mutable, presumably so u0/st can be updated after
# construction — verify against callers.
mutable struct PINOPhi{C, T, U, S}
chain::C
t0::T
u0::U
st::S
# Inner constructor pins the type parameters to the concrete argument types
# and restricts the network to Lux explicit layers.
function PINOPhi(chain::Lux.AbstractExplicitLayer, t0, u0, st)
new{typeof(chain), typeof(t0), typeof(u0), typeof(st)}(chain, t0, u0, st)
end
end

"""
    PINOsolution(predict, res, phi, input_data_set)

Result container for a PINO solve.

# Fields
- `predict`: network predictions evaluated on `input_data_set` with the
  optimized parameters.
- `res`: the `SciMLBase.OptimizationSolution` returned by the optimizer
  (`res.u` holds the trained parameters).
- `phi`: the trained `PINOPhi` (the learned operator), reusable as a frozen
  surrogate in the `EquationSolving` phase.
- `input_data_set`: the input array the training/prediction was performed on.

The fields are parametrized on their concrete types (rather than the previous
abstract `::Array` / `::PINOPhi` annotations with a stray empty `{}` parameter
list) so the struct is concretely typed; positional construction is unchanged.
"""
struct PINOsolution{P <: Array, R <: SciMLBase.OptimizationSolution,
    F <: PINOPhi, D <: Array}
    predict::P
    res::R
    phi::F
    input_data_set::D
end

"""
Abstract supertype for the two phases of the PINO workflow:
[`OperatorLearning`](@ref) (learn the solution operator over a family of ODE
problems) and `EquationSolving` (fine-tune on one concrete problem).
"""
abstract type PINOPhases end

"""
    OperatorLearning(; is_data_loss = true, is_physics_loss = true)

Operator-learning phase configuration. The two boolean flags select which loss
terms are active during training:

- `is_data_loss`: include the supervised data-matching loss against the
  training set's output data.
- `is_physics_loss`: include the physics (ODE residual) loss.

Throws an `ArgumentError` when both flags are `false`, since that would leave
the training with no objective.
"""
struct OperatorLearning <: PINOPhases
    is_data_loss::Bool
    is_physics_loss::Bool
end
function OperatorLearning(; is_data_loss = true, is_physics_loss = true)
    # At least one loss term must be enabled, otherwise the loss closure built
    # from this phase has nothing to minimize.
    is_data_loss || is_physics_loss ||
        throw(ArgumentError("at least one of is_data_loss / is_physics_loss must be true"))
    OperatorLearning(is_data_loss, is_physics_loss)
end
# Equation-solving (fine-tuning) phase: wraps the PINOsolution produced by a
# previous OperatorLearning solve. Its trained phi/parameters serve as a frozen
# surrogate target in finetune_loss while solving one concrete ODE problem.
struct EquationSolving <: PINOPhases
pino_solution::PINOsolution
end

"""
PINOODE(chain,
OptimizationOptimisers.Adam(0.1),
train_set,
is_data_loss =true,
is_physics_loss =true,
init_params,
#TODO update docstring
kwargs...)
The method combines training data and physics constraints
Expand All @@ -41,37 +71,20 @@ struct PINOODE{C, O, P, K} <: DiffEqBase.AbstractODEAlgorithm
chain::C
opt::O
train_set::TRAINSET
is_data_loss::Bool
is_physics_loss::Bool
pino_phase::PINOPhases
init_params::P
kwargs::K
end

function PINOODE(chain,
opt,
train_set;
is_data_loss = true,
is_physics_loss = true,
train_set,
pino_phase;
init_params = nothing,
kwargs...)
#TODO fnn transform
#TODO fnn transform check
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
PINOODE(chain, opt, train_set, is_data_loss, is_physics_loss, init_params, kwargs)
end

mutable struct PINOPhi{C, T, U, S}
chain::C
t0::T
u0::U
st::S
function PINOPhi(chain::Lux.AbstractExplicitLayer, t0, u0, st)
new{typeof(chain), typeof(t0), typeof(u0), typeof(st)}(chain, t0, u0, st)
end
end
struct PINOsolution{}
predict::Array
res::SciMLBase.OptimizationSolution
phi::PINOPhi
PINOODE(chain, opt, train_set, pino_phase, init_params, kwargs)
end

function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer,
Expand Down Expand Up @@ -125,10 +138,10 @@ function physics_loss(phi::PINOPhi{C, T, U},
input_data_set) where {C, T, U}
prob_set, _ = train_set.input_data, train_set.output_data
f = prob_set[1].f
p = prob_set[1].p
out_ = phi(input_data_set, θ)
ts = adapt(parameterless_type(ComponentArrays.getdata(θ)), ts)
if train_set.isu0 == true
p = prob_set[1].p
fs = f.f.(out_, p, ts)
else
ps = [prob.p for prob in prob_set]
Expand Down Expand Up @@ -157,15 +170,15 @@ function data_loss(phi::PINOPhi{C, T, U},
l₂loss(phi(input_data_set, θ), output_data)
end

function generate_data(ts, prob_set, isu0)
function generate_data(ts, prob_set::Vector{ODEProblem}, isu0)
batch_size = size(prob_set)[1]
instances_size = size(ts)[2]
dims = isu0 ? length(prob_set[1].u0) + 1 : length(prob_set[1].p) + 1
input_data_set = Array{Float32, 3}(undef, dims, instances_size, batch_size)
for (i, prob) in enumerate(prob_set)
u0 = prob.u0
p = prob.p
f = prob.f
# f = prob.f
if isu0 == true
in_ = reduce(vcat, [ts, fill(u0, 1, size(ts)[2], 1)])
else
Expand All @@ -185,8 +198,9 @@ end

function generate_loss(
phi::PINOPhi{C, T, U}, train_set::TRAINSET, input_data_set, ts,
is_data_loss, is_physics_loss) where {
pino_phase::OperatorLearning) where {
C, T, U}
is_data_loss, is_physics_loss = pino_phase.is_data_loss, pino_phase.is_physics_loss
function loss(θ, _)
if is_data_loss
data_loss(phi, θ, train_set, input_data_set)
Expand All @@ -202,6 +216,31 @@ function generate_loss(
return loss
end

"""
    finetune_loss(phi, θ, train_set, input_data_set, pino_phase::EquationSolving)

Distillation loss for the fine-tuning (`EquationSolving`) phase: evaluate the
previously trained operator stored in `pino_phase.pino_solution` on
`input_data_set` and return the l₂ distance between the current network's
prediction `phi(input_data_set, θ)` and that frozen surrogate target.

`train_set` is kept in the signature for symmetry with `data_loss` but is not
used here. (The original body also extracted and `adapt`ed
`train_set.output_data` without ever using it; that dead computation has been
removed.)
"""
function finetune_loss(phi::PINOPhi{C, T, U},
        θ,
        train_set::TRAINSET,
        input_data_set,
        pino_phase::EquationSolving) where {C, T, U}
    pino_solution = pino_phase.pino_solution
    learned_operator = pino_solution.phi
    # Surrogate target: prediction of the frozen, pre-trained operator using
    # its own optimized parameters (res.u), not the current θ.
    predict = learned_operator(input_data_set, pino_solution.res.u)
    l₂loss(phi(input_data_set, θ), predict)
end

"""
    generate_loss(phi, train_set, input_data_set, ts, pino_phase::EquationSolving)

Build the total-loss closure for the fine-tuning phase: the physics (ODE
residual) loss plus a small fixed-weight distillation term that keeps the
network close to the pre-trained operator wrapped in `pino_phase`.
"""
function generate_loss(
        phi::PINOPhi{C, T, U}, train_set::TRAINSET, input_data_set, ts,
        pino_phase::EquationSolving) where {
        C, T, U}
    # Fixed relative weighting of the fine-tune (distillation) term.
    finetune_weight = 1 / 100
    function loss(θ, _)
        physics_term = physics_loss(phi, θ, ts, train_set, input_data_set)
        finetune_term = finetune_loss(phi, θ, train_set, input_data_set, pino_phase)
        return physics_term + finetune_weight * finetune_term
    end
    return loss
end

function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
alg::PINOODE,
args...;
Expand All @@ -212,42 +251,43 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
saveat = nothing,
maxiters = nothing)
tspan = prob.tspan
t0 = tspan[1]
t0, t_end = tspan[1], tspan[2]
u0 = prob.u0
p = prob.p
# f = prob.f
# p = prob.p
# param_estim = alg.param_estim

chain = alg.chain
opt = alg.opt
init_params = alg.init_params
is_data_loss = alg.is_data_loss
is_physics_loss = alg.is_physics_loss

pino_phase = alg.pino_phase
# mapping between functional space of some variable 'a' of equation (for example initial
# condition {u(t0 x)} or parameter p) and solution of equation u(t)
train_set = alg.train_set

!(chain isa Lux.AbstractExplicitLayer) &&
error("Only Lux.AbstractExplicitLayer neural networks are supported")

t0 = tspan[1]
t_end = tspan[2]
instances_size = size(train_set.output_data)[2]
range_ = range(t0, stop = t_end, length = instances_size)
ts = reshape(collect(range_), 1, instances_size)
prob_set, _ = train_set.input_data, train_set.output_data
prob_set, output_set = train_set.input_data, train_set.output_data
isu0 = train_set.isu0
input_data_set = generate_data(ts, prob_set, isu0)
# input_data_set = if pino_phase == EquationSolving
# generate_data(ts, [prob], isu0)
# elseif pino_phase == OperatorLearning
# generate_data(ts, prob_set, isu0)
# else
# error("pino_phase should be EquationSolving or OperatorLearning")
# end

if isu0
if isu0 #TODO remove the block
u0 = input_data_set[2:end, :, :]
phi, init_params = generate_pino_phi_θ(chain, t0, u0, init_params)
else
u0 = prob.u0
phi, init_params = generate_pino_phi_θ(chain, t0, u0, init_params)
end

phi, init_params = generate_pino_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArrays.ComponentArray(init_params)

isinplace(prob) &&
Expand All @@ -263,8 +303,19 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
end
end

total_loss = generate_loss(
phi, train_set, input_data_set, ts, is_data_loss, is_physics_loss)
if pino_phase isa EquationSolving
#TODO bad code rewrite,the parameter must uniquely match the index
#TODO doesn't need TRAINSET for EquationSolving
find(as, a) = findfirst(x -> isapprox(x.p, a.p), as)
index = find(prob_set, prob)
input_data_set = input_data_set[:, :, [index]]
train_set = TRAINSET(prob_set[index:index], output_set[:, :, [index]], isu0)
total_loss = generate_loss(phi, train_set, input_data_set, ts, pino_phase)
elseif pino_phase isa OperatorLearning
total_loss = generate_loss(phi, train_set, input_data_set, ts, pino_phase)
else
error("pino_phase should be EquationSolving or OperatorLearning")
end

# Optimization Algo for Training Strategies
opt_algo = Optimization.AutoZygote()
Expand All @@ -282,5 +333,5 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
optprob = OptimizationProblem(optf, init_params)
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)
predict = phi(input_data_set, res.u)
PINOsolution(predict, res, phi)
PINOsolution(predict, res, phi, input_data_set)
end
28 changes: 22 additions & 6 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Test
using OrdinaryDiffEq, OptimizationOptimisers
using Lux
using Statistics, Random
#using NeuralOperators
# using NeuralOperators
using NeuralPDE

@testset "Example p" begin
Expand Down Expand Up @@ -33,8 +33,9 @@ using NeuralPDE
* input data: set of parameters 'a'
* output data: set of solutions u(t){a} corresponding parameter 'a'.
"""
train_set = TRAINSET(prob_set, u_output_);
prob = ODEProblem(linear, u0, tspan, 0)
train_set = TRAINSET(prob_set, u_output_)
p = pi / 2
prob = ODEProblem(linear, u0, tspan, p)
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Expand All @@ -44,14 +45,28 @@ using NeuralPDE
# flat_no = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,),
# σ = gelu)
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(
chain, opt, train_set; is_data_loss = true, is_physics_loss = true)
# pino_phase = OperatorLearning(train_set, is_data_loss = true, is_physics_loss = true)
pino_phase = OperatorLearning(
is_data_loss = true, is_physics_loss = true)
alg = PINOODE(chain, opt, train_set, pino_phase)
# pino_solution = learn()
pino_solution = solve(prob, alg, verbose = false, maxiters = 2000)
predict = pino_solution.predict
ground = u_output_
@test ground≈predict atol=1

pino_phase = EquationSolving(pino_solution)
alg = PINOODE(chain, opt, train_set, pino_phase)
pino_solution = solve(prob, alg, verbose = true, maxiters = 2000)

find(as, a) = findfirst(x -> isapprox(x.p, a.p), as)
index = find(prob_set, prob)
predict = pino_solution.predict
ground = u_output_[:,:, [index]]
@test ground≈predict atol=0.1
end


@testset "Example u0" begin
linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
linear = (u, p, t) -> cos(p * t)
Expand Down Expand Up @@ -91,7 +106,8 @@ end
Lux.Dense(16, 32, Lux.σ),
Lux.Dense(32, 1))
opt = OptimizationOptimisers.Adam(0.001)
alg = PINOODE(chain, opt, train_set)
pino_phase = OperatorLearning()
alg = PINOODE(chain, opt, train_set, pino_phase)
pino_solution = solve(prob, alg, verbose = false, maxiters = 2000)
predict = pino_solution.predict
ground = u_output_
Expand Down
6 changes: 4 additions & 2 deletions test/PINO_ode_tests_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ const gpud = gpu_device()
Lux.Dense(inner, 1))
ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray |> gpud
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(chain, opt, train_set; init_params = ps)
pino_phase = OperatorLearning()
alg = PINOODE(chain, opt, train_set, pino_phase; init_params = ps)
pino_solution = solve(prob, alg, verbose = false, maxiters = 2000)
predict = pino_solution.predict |> cpu
ground = u_output_ |> cpu
Expand Down Expand Up @@ -108,8 +109,9 @@ end
ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray |> gpud

opt = OptimizationOptimisers.Adam(0.001)
pino_phase = OperatorLearning()
alg = PINOODE(
chain, opt, train_set; init_params = ps, is_data_loss = true, is_physics_loss = true)
chain, opt, train_set, pino_phase; init_params = ps)
pino_solution = solve(prob, alg, verbose = false, maxiters = 4000)
predict = pino_solution.predict |> cpu
ground = u_output_
Expand Down

0 comments on commit 2f2be69

Please sign in to comment.