Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Mar 19, 2024
1 parent c058f5e commit 48c762b
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 54 deletions.
63 changes: 62 additions & 1 deletion docs/src/tutorials/pino_ode.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,64 @@
# Physics informed Neural Operator ODEs Solvers

## some example TODO
This tutorial is an introduction to using physics-informed neural operator (PINOs) for solving family of parametric ordinary diferential equations (ODEs).


## Solving a family of parametric ODE.

```@example pino
using Test
using OrdinaryDiffEq, OptimizationOptimisers
using Lux
using Statistics, Random
using NeuralOperators
using NeuralPDE
linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
linear = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 2.0f0)
u0 = 0.0f0
```

Generate a dataset for learning a given family of ODEs where the parameter 'a' is varied. The dataset is generated by solving the ODE for different values of 'a' and storing the solutions. The dataset is then used to train the PINO model:
* input data: set of parameters 'a',
* output data: set of solutions u(t){a} corresponding parameter 'a'.

```@example pino
t0, t_end = tspan
instances_size = 50
range_ = range(t0, stop = t_end, length = instances_size)
ts = reshape(collect(range_), 1, instances_size)
batch_size = 50
as = [Float32(i) for i in range(0.1, stop = pi / 2, length = batch_size)]
u_output_ = zeros(Float32, 1, instances_size, batch_size)
prob_set = []
for (i, a_i) in enumerate(as)
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, a_i)
sol1 = solve(prob, Tsit5(); saveat = 0.0204)
reshape_sol = Float32.(reshape(sol1(range_).u', 1, instances_size, 1))
push!(prob_set, prob)
u_output_[:, :, i] = reshape_sol
end
train_set = TRAINSET(prob_set, u_output_);
```

Here it used the PINO method to train the given family of parametric ODEs.

```@example pino
prob = ODEProblem(linear, u0, tspan, 0)
flat_no = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,),
σ = gelu)
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(flat_no, opt, train_set; is_data_loss = true, is_physics_loss = true)
pino_solution = solve(prob, alg, verbose = false, maxiters = 1000)
predict = pino_solution.predict
ground = u_output_
```

Now let's compare the predictions from the learned operator with the ground truth solution which is obtained early by numerically solving the parametric ODE. Where 'i' is the index of the parameter 'a' in the dataset.

```@example pino
plot(predict[1, :, i], label = "Predicted")
plot!(ground[1, :, i], label = "Ground truth")
```
82 changes: 39 additions & 43 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,32 @@
"""
PINOODE(chain,
OptimizationOptimisers.Adam(0.1),
train_set
init_params = nothing;
train_set,
is_data_loss =true,
is_physics_loss =true,
init_params,
kwargs...)
## Positional Arguments
The method is that combine training data and physics constraints
to learn the solution operator of a given family of parametric Ordinary Differential Equations (ODE).
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractExplicitLayer`.
## Positional Arguments
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
`Flux.Chain` will be converted to `Lux` using `Lux.transform`.
* `opt`: The optimizer to train the neural network.
* `train_set`:
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.
* `train_set`: Contains 'input data' - sr of parameters 'a' and output data - set of solutions
u(t){a} corresponding initial conditions 'u0'.
## Keyword Arguments
* `minibatch`:
## Examples
```julia
```
* `is_data_loss` Includes or off a loss function for training on the data set.
* `is_physics_loss`: Includes or off loss function training on physics-informed approach.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.
* `kwargs`: Extra keyword arguments are splatted to the Optimization.jl `solve` call.
## References
Zongyi Li "Physics-Informed Neural Operator for Learning Partial Differential Equations"
"""
struct TRAINSET{}
input_data::Vector{ODEProblem}
output_data::Array
isu0::Bool
end

function TRAINSET(input_data, output_data; isu0 = false)
TRAINSET(input_data, output_data, isu0)
end

struct PINOODE{C, O, P, K} <: DiffEqBase.AbstractODEAlgorithm
chain::C
opt::O
Expand All @@ -57,10 +49,16 @@ function PINOODE(chain,
PINOODE(chain, opt, train_set, is_data_loss, is_physics_loss, init_params, kwargs)
end

"""
PINOPhi(chain::Lux.AbstractExplicitLayer, t0,u0, st)
TODO
"""
struct TRAINSET{}
input_data::Vector{ODEProblem}
output_data::Array
isu0::Bool
end

function TRAINSET(input_data, output_data; isu0 = false)
TRAINSET(input_data, output_data, isu0)
end

mutable struct PINOPhi{C, T, U, S}
chain::C
t0::T
Expand Down Expand Up @@ -91,7 +89,6 @@ end

function (f::PINOPhi{C, T, U})(t::AbstractArray,
θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# Batch via data as row vectors
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t), θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
ts = adapt(parameterless_type(ComponentArrays.getdata(θ)), t[1:size(y)[1], :, :])
Expand Down Expand Up @@ -119,9 +116,8 @@ function physics_loss(phi::PINOPhi{C, T, U},
ts::AbstractArray,
train_set::TRAINSET,
input_data_set) where {C, T, U}
prob_set, output_data = train_set.input_data, train_set.output_data #TODO
f = prob_set[1].f #TODO one f for all
p = prob_set[1].p
prob_set, _ = train_set.input_data, train_set.output_data
f = prob_set[1].f
out_ = phi(input_data_set, θ)
ts = adapt(parameterless_type(ComponentArrays.getdata(θ)), ts)
if train_set.isu0 == true
Expand All @@ -132,23 +128,24 @@ function physics_loss(phi::PINOPhi{C, T, U},
if p isa Number
fs = cat(
[f.f.(out_[:, :, [i]], p, ts) for (i, p) in enumerate(ps)]..., dims = 3)
else
elseif p isa Vector
fs = cat(
[reduce(
hcat, [f.f(out_[:, j, [i]], p, ts) for j in axes(out_[:, :, [i]], 2)])
for (i, p) in enumerate(ps)]...,
dims = 3)
else
error("p should be a number or a vector")
end
end
NeuralOperators.l₂loss(dfdx(phi, input_data_set, θ), fs)
end

function data_loss(phi::PINOPhi{C, T, U},
θ,
ts::AbstractArray, #TODO remove unessasry
train_set::TRAINSET,
input_data_set) where {C, T, U}
prob_set, output_data = train_set.input_data, train_set.output_data
_, output_data = train_set.input_data, train_set.output_data
output_data = adapt(parameterless_type(ComponentArrays.getdata(θ)), output_data)
NeuralOperators.l₂loss(phi(input_data_set, θ), output_data)
end
Expand All @@ -164,8 +161,6 @@ function generate_data(ts, prob_set, isu0)
f = prob.f
if isu0 == true
in_ = reduce(vcat, [ts, fill(u0, 1, size(ts)[2], 1)])
#TODO for all case p and u0
# in_ = reduce(vcat, [ts, reduce(hcat, fill(u0, 1, size(ts)[2], 1))])
else
if p isa Number
in_ = reduce(vcat, [ts, fill(p, 1, size(ts)[2], 1)])
Expand All @@ -187,11 +182,11 @@ function generate_loss(
C, T, U}
function loss(θ, _)
if is_data_loss
data_loss(phi, θ, ts, train_set, input_data_set)
data_loss(phi, θ, train_set, input_data_set)
elseif is_physics_loss
physics_loss(phi, θ, ts, train_set, input_data_set)
elseif is_data_loss && is_physics_loss
data_loss(phi, θ, ts, train_set, input_data_set) +
data_loss(phi, θ, train_set, input_data_set) +
physics_loss(phi, θ, ts, train_set, input_data_set)
else
error("data loss or physics loss should be true")
Expand All @@ -211,9 +206,9 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
maxiters = nothing)
tspan = prob.tspan
t0 = tspan[1]
u0 = prob.u0
# f = prob.f
# p = prob.p
u0 = prob.u0
# param_estim = alg.param_estim

chain = alg.chain
Expand All @@ -234,7 +229,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
instances_size = size(train_set.output_data)[2]
range_ = range(t0, stop = t_end, length = instances_size)
ts = reshape(collect(range_), 1, instances_size)
prob_set, output_data = train_set.input_data, train_set.output_data
prob_set, _ = train_set.input_data, train_set.output_data
isu0 = train_set.isu0
input_data_set = generate_data(ts, prob_set, isu0)

Expand All @@ -255,13 +250,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
phi(input_data_set, init_params)
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match")) #TODO change message
throw(DimensionMismatch("Dimensions of input data and chain should match"))
else
throw(err)
end
end

total_loss = generate_loss(phi, train_set, input_data_set, ts, is_data_loss, is_physics_loss)
total_loss = generate_loss(
phi, train_set, input_data_set, ts, is_data_loss, is_physics_loss)

# Optimization Algo for Training Strategies
opt_algo = Optimization.AutoZygote()
Expand Down
11 changes: 5 additions & 6 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ using NeuralPDE

"""
Set of training data:
* input data: set of parameters 'a':
* output data: set of solutions u(t){a} corresponding parameter 'a'
* input data: set of parameters 'a'
* output data: set of solutions u(t){a} corresponding parameter 'a'.
"""
train_set = TRAINSET(prob_set, u_output_);
#TODO u0 ?
prob = ODEProblem(linear, u0, tspan, 0)
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Expand Down Expand Up @@ -78,11 +77,11 @@ begin

"""
Set of training data:
* input data: set of initial conditions 'u0':
* output data: set of solutions u(t){u0} corresponding initial conditions 'u0'
* input data: set of initial conditions 'u0'
* output data: set of solutions u(t){u0} corresponding initial conditions 'u0'.
"""
train_set = TRAINSET(prob_set, u_output_; isu0 = true)
#TODO u0 ?
#TODO we argument u0 but dont actualy use u0 because we use only set of u0 for generate train set from prob_set

Check warning on line 84 in test/PINO_ode_tests.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"actualy" should be "actually".
prob = ODEProblem(linear, 0.0f0, tspan, p)
fno = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,), σ = gelu)
opt = OptimizationOptimisers.Adam(0.001)
Expand Down
6 changes: 2 additions & 4 deletions test/PINO_ode_tests_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,11 @@ const gpud = gpu_device()

"""
Set of training data:
* input data: set of parameters 'a':
* output data: set of solutions u(t){a} corresponding parameter 'a'
* input data: set of parameters 'a',
* output data: set of solutions u(t){a} corresponding parameter 'a'.
"""
train_set = TRAINSET(prob_set, u_output_)

#TODO u0 ?
prob = ODEProblem(linear, u0, tspan, 0)
inner = 50
chain = Lux.Chain(Lux.Dense(2, inner, Lux.σ),
Expand Down Expand Up @@ -94,7 +93,6 @@ end
end

train_set = TRAINSET(prob_set, u_output_)
#TODO u0 ?
prob = ODEProblem(lotka_volterra, u0, tspan, p)
flat_no = FourierNeuralOperator(ch = (5, 64, 64, 64, 64, 64, 128, 2), modes = (16,),
σ = gelu)
Expand Down

0 comments on commit 48c762b

Please sign in to comment.