Merge pull request #839 from sathvikbhagavan/sb/complex
feat: allow complex for NNODE
ChrisRackauckas authored Mar 29, 2024
2 parents 0cb2e06 + dad8815 commit 0856525
Showing 5 changed files with 144 additions and 5 deletions.
3 changes: 2 additions & 1 deletion docs/pages.jl
@@ -22,7 +22,8 @@ pages = ["index.md",
"examples/heterogeneous.md",
"examples/linear_parabolic.md",
"examples/nonlinear_elliptic.md",
"examples/nonlinear_hyperbolic.md"],
"examples/nonlinear_hyperbolic.md",
"examples/complex.md"],
"Manual" => Any["manual/ode.md",
"manual/dae.md",
"manual/pinns.md",
95 changes: 95 additions & 0 deletions docs/src/examples/complex.md
@@ -0,0 +1,95 @@
# Complex Equations with PINNs

NeuralPDE supports training PINNs with complex differential equations. This example demonstrates how to do so with [`NNODE`](@ref) on a system of [Bloch equations](https://en.wikipedia.org/wiki/Bloch_equations). Note that [`QuadratureTraining`](@ref) cannot be used with complex equations due to current limitations in computing quadratures.
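
For reference, the system implemented in the code below (with ``\gamma = \Gamma / 2``) is

```math
\begin{aligned}
\dot{\rho}_{11} &= i\Omega(\rho_{12} - \rho_{21}) + \Gamma\rho_{22}, \\
\dot{\rho}_{22} &= -i\Omega(\rho_{12} - \rho_{21}) - \Gamma\rho_{22}, \\
\dot{\rho}_{12} &= -(\gamma + i\Delta)\rho_{12} - i\Omega(\rho_{22} - \rho_{11}), \\
\dot{\rho}_{21} &= \overline{\dot{\rho}_{12}}.
\end{aligned}
```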

Since the input to this neural network is time, which is real, the parameters of the network must be initialized with complex values so that it outputs, and is trained with, complex values.

```@example complex
using Random, NeuralPDE
using OrdinaryDiffEq
using Lux, OptimizationOptimisers
using Plots
rng = Random.default_rng()
Random.seed!(100)
function bloch_equations(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u
    d̢ρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
          -im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
          -(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
          conj(-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
    return d̢ρ
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1.0
time_span = (0.0, 2.0)
parameters = [2.0, 0.0, 1.0]
problem = ODEProblem(bloch_equations, u0, time_span, parameters)

chain = Lux.Chain(
    Lux.Dense(1, 16, tanh;
        init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)),
    Lux.Dense(16, 4;
        init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...))
)
ps, st = Lux.setup(rng, chain)

opt = OptimizationOptimisers.Adam(0.01)
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
alg = NNODE(chain, opt, ps; strategy = StochasticTraining(500))
sol = solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
```
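
Both `sol.u` and `ground_truth.u` are vectors of length-4 complex state vectors, so `reduce(hcat, ...)` stacks them into a `4 × N` complex matrix whose rows (`ρ₁₁`, `ρ₂₂`, `ρ₁₂`, `ρ₂₁`) are what gets plotted below. As a minimal sanity check (the variable names here are illustrative and not part of the example above), the trace `ρ₁₁ + ρ₂₂` is conserved by the equations and should stay near 1:

```julia
u_pinn = reduce(hcat, sol.u)            # 4 × N matrix of the PINN prediction
u_ref  = reduce(hcat, ground_truth.u)   # 4 × N matrix of the Tsit5 reference

# Trace of the density matrix: ≈ 1 up to solver tolerance for the reference,
# approximately 1 for the trained network.
maximum(abs, real.(u_ref[1, :] .+ u_ref[2, :]) .- 1)
maximum(abs, real.(u_pinn[1, :] .+ u_pinn[2, :]) .- 1)
```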

Now, let's plot the predictions and compare them with the `Tsit5` ground truth.

`u1`:

```@example complex
plot(sol.t, real.(reduce(hcat, sol.u)[1, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[1, :]))
```

```@example complex
plot(sol.t, imag.(reduce(hcat, sol.u)[1, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[1, :]))
```

`u2`:

```@example complex
plot(sol.t, real.(reduce(hcat, sol.u)[2, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[2, :]))
```

```@example complex
plot(sol.t, imag.(reduce(hcat, sol.u)[2, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[2, :]))
```

`u3`:

```@example complex
plot(sol.t, real.(reduce(hcat, sol.u)[3, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[3, :]))
```

```@example complex
plot(sol.t, imag.(reduce(hcat, sol.u)[3, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[3, :]))
```

`u4`:

```@example complex
plot(sol.t, real.(reduce(hcat, sol.u)[4, :]));
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[4, :]))
```

```@example complex
plot(sol.t, imag.(reduce(hcat, sol.u)[4, :]));
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[4, :]))
```

We can see that the network is able to learn the real parts of `u1` and `u2` and the imaginary parts of `u3` and `u4`.
8 changes: 4 additions & 4 deletions docs/src/tutorials/neural_adapter.md
@@ -69,7 +69,7 @@ function loss(cord, θ)
ch2 .- phi(cord, res.u)
end
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3)
prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
@@ -173,7 +173,7 @@ for i in 1:count_decomp
bcs_ = create_bcs(domains_[1].domain, phi_bound)
@named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)])
push!(pde_system_map, pde_system_)
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3)
discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy;
init_params = init_params[i])
@@ -243,10 +243,10 @@ callback = function (p, l)
end
prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
NeuralPDE.QuadratureTraining(; reltol = 1e-6))
NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3))
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
NeuralPDE.QuadratureTraining(; reltol = 1e-6))
NeuralPDE.QuadratureTraining(; reltol = 1e-6, abstol = 1e-3))
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi
3 changes: 3 additions & 0 deletions src/ode_solve.jl
@@ -326,6 +326,7 @@ function (f::NNODEInterpolation)(t::Vector, idxs, ::Type{Val{0}}, p, continuity)
end

SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
SciMLBase.allowscomplex(::NNODE) = true

function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
alg::NNODE,
@@ -357,6 +358,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,

!(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
((eltype(eltype(init_params).types[1]) <: Complex || eltype(eltype(init_params).types[2]) <: Complex) && alg.strategy isa QuadratureTraining) &&
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

init_params = if alg.param_estim
ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
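
The `SciMLBase.allowscomplex` trait lets `solve` accept an `ODEProblem` with complex `u0` when the algorithm is an `NNODE`, while the added check rejects the unsupported complex + `QuadratureTraining` combination up front. A minimal sketch of the resulting behavior, reusing `problem`, `chain`, `ps`, and `opt` from the docs example above:

```julia
# Complex u0 with a sampling-based strategy trains as in the docs example.
alg_ok = NNODE(chain, opt, ps; strategy = StochasticTraining(500))
sol = solve(problem, alg_ok, verbose = false, maxiters = 5000, saveat = 0.01)

# Complex parameters with QuadratureTraining now error immediately:
# "QuadratureTraining cannot be used with complex parameters. Use other strategies."
alg_bad = NNODE(chain, opt, ps; strategy = QuadratureTraining())
# solve(problem, alg_bad)  # throws ErrorException
```
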
40 changes: 40 additions & 0 deletions test/NNODE_tests.jl
@@ -5,6 +5,7 @@ import Lux, OptimizationOptimisers, OptimizationOptimJL
using Flux
using LineSearches

rng = Random.default_rng()
Random.seed!(100)

@testset "Scalar" begin
@@ -250,6 +251,45 @@ end
@test reduce(hcat, sol.u) ≈ u_ atol=1e-2
end

@testset "Complex Numbers" begin
function bloch_equations(u, p, t)
Ω, Δ, Γ = p
γ = Γ / 2
ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u
d̢ρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
-im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
-+ im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
conj(-+ im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
return d̢ρ
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1
time_span = (0.0, 2.0)
parameters = [2.0, 0.0, 1.0]

problem = ODEProblem(bloch_equations, u0, time_span, parameters)

chain = Lux.Chain(
Lux.Dense(1, 16, tanh; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)) ,
Lux.Dense(16, 4; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...))
)
ps, st = Lux.setup(rng, chain)

opt = OptimizationOptimisers.Adam(0.01)
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
strategies = [StochasticTraining(500), GridTraining(0.01), WeightedIntervalTraining([0.1, 0.4, 0.4, 0.1], 500)]

@testset "$(nameof(typeof(strategy)))" for strategy in strategies
alg = NNODE(chain, opt, ps; strategy)
sol = solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
@test sol.u ground_truth.u rtol=1e-1
end

alg = NNODE(chain, opt, ps; strategy = QuadratureTraining())
@test_throws ErrorException solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
end

@testset "Translating from Flux" begin
println("Translating from Flux")
linear = (u, p, t) -> cos(2pi * t)
