Skip to content

Commit

Permalink
fine tunnning update
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Mar 22, 2024
1 parent 2f2be69 commit c38a78f
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 162 deletions.
21 changes: 21 additions & 0 deletions docs/src/manual/pino_ode.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Physics-Informed Neural operator for solve ODEs

```@docs
PINOODE
```

```@docs
TRAINSET
```

```@docs
PINOsolution
```

```@docs
OperatorLearning
```

```@docs
EquationSolving
```
68 changes: 54 additions & 14 deletions docs/src/tutorials/pino_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

This tutorial is an introduction to using physics-informed neural operator (PINOs) for solving family of parametric ordinary diferential equations (ODEs).

#TODO two phase

## Solving a family of parametric ODE.
## Operator Learning for a family of parametric ODE.

```@example pino
using Test
using OrdinaryDiffEq, OptimizationOptimisers
using Lux
using Statistics, Random
using NeuralOperators
# using NeuralOperators
using NeuralPDE
linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
linear = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 2.0f0)
u0 = 0.0f0
p = pi / 2f0
prob = ODEProblem(linear, u0, tspan, p)
```

Generate a dataset for learning a given family of ODEs where the parameter 'a' is varied. The dataset is generated by solving the ODE for different values of 'a' and storing the solutions. The dataset is then used to train the PINO model:
Expand All @@ -34,32 +37,69 @@ as = [Float32(i) for i in range(0.1, stop = pi / 2, length = batch_size)]
u_output_ = zeros(Float32, 1, instances_size, batch_size)
prob_set = []
for (i, a_i) in enumerate(as)
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, a_i)
sol1 = solve(prob, Tsit5(); saveat = 0.0204)
prob_ = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, a_i)
sol1 = solve(prob_, Tsit5(); saveat = 0.0204)
reshape_sol = Float32.(reshape(sol1(range_).u', 1, instances_size, 1))
push!(prob_set, prob)
push!(prob_set, prob_)
u_output_[:, :, i] = reshape_sol
end
train_set = TRAINSET(prob_set, u_output_);
train_set = TRAINSET(prob_set, u_output_)
```

Here it used the PINO method to train the given family of parametric ODEs.
Here it used the PINO method to learning operator of the given family of parametric ODEs.

```@example pino
prob = ODEProblem(linear, u0, tspan, 0)
flat_no = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,),
σ = gelu)
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(flat_no, opt, train_set; is_data_loss = true, is_physics_loss = true)
pino_solution = solve(prob, alg, verbose = false, maxiters = 1000)
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 32, Lux.σ),
Lux.Dense(32, 32, Lux.σ),
Lux.Dense(32, 1))
# flat_no = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,),
# σ = gelu)
opt = OptimizationOptimisers.Adam(0.01)
pino_phase = OperatorLearning(train_set, is_data_loss = true, is_physics_loss = true)
alg = PINOODE(chain, opt, pino_phase)
pino_solution = solve(
prob, alg, verbose = false, maxiters = 3000)
predict = pino_solution.predict
ground = u_output_
```

Now let's compare the predictions from the learned operator with the ground truth solution which is obtained early by numerically solving the parametric ODE. Where 'i' is the index of the parameter 'a' in the dataset.

```@example pino
i=1
using Plots
i=45
plot(predict[1, :, i], label = "Predicted")
plot!(ground[1, :, i], label = "Ground truth")
```

Now to move on the stage of solving a certain equation using a trained operator and physics

## Solve ODE using learned operator family of parametric ODE for fine tuning.
```@example pino
dt = (t_end - t0) / instances_size
pino_phase = EquationSolving(dt, pino_solution)
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 32, Lux.σ),
Lux.Dense(32, 1))
alg = PINOODE(chain, opt, pino_phase)
fine_tune_solution = solve( prob, alg, verbose = false, maxiters = 2000)
fine_tune_predict = fine_tune_solution.predict
operator_predict = pino_solution.phi(
fine_tune_solution.input_data_set, pino_solution.res.u)
ground_fine_tune = linear_analytic.(u0, p, fine_tune_solution.input_data_set[[1], :, :])
```

Compare prediction with ground truth.

```@example pino
plot(operator_predict[1, :, 1], label = "operator_predict")
plot!(fine_tune_predict[1, :, 1], label = "fine_tune_predict")
plot!(ground_fine_tune[1, :, 1], label = "Ground truth")
```
Loading

0 comments on commit c38a78f

Please sign in to comment.