Merge pull request #829 from SciML/ap/hnn
Updates to HNN
ChrisRackauckas authored Jun 1, 2023
2 parents a5249a8 + 18e69c2 commit aa4581e
Showing 5 changed files with 165 additions and 122 deletions.
11 changes: 10 additions & 1 deletion docs/Project.toml
@@ -9,16 +9,21 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
@@ -28,20 +33,24 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
[compat]
CSV = "0.10"
ComponentArrays = "0.13"
DataFrames = "1"
DataDeps = "0.7"
DataFrames = "1"
DiffEqFlux = "2"
DifferentialEquations = "7.6.0"
Distances = "0.10.7"
Distributions = "0.25.78"
Documenter = "0.27"
Flux = "0.13.7"
ForwardDiff = "0.10"
IterTools = "1"
Lux = "0.4.34"
MLDataUtils = "0.5"
MLDatasets = "0.7"
Optimisers = "0.2"
Optimization = "3.9"
OptimizationFlux = "0.1"
OptimizationOptimJL = "0.1"
OptimizationOptimisers = "0.1"
OptimizationPolyalgorithms = "0.1"
OrdinaryDiffEq = "6.31"
Plots = "1.36"
123 changes: 62 additions & 61 deletions docs/src/examples/hamiltonian_nn.md
@@ -9,50 +9,47 @@ m\ddot x + kx = 0
Now we make some simplifying assumptions, and assign ``m = 1`` and ``k = 1``. Analytically solving this equation, we get ``x = sin(t)``. Hence, ``q = sin(t)``, and ``p = cos(t)``. Using these solutions, we generate our dataset and fit the `NeuralHamiltonianDE` to learn the dynamics of this system.

```@example hamiltonian_cp
-using Flux, DiffEqFlux, DifferentialEquations, Statistics, Plots, ReverseDiff
+using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
+    ComponentArrays, Optimization, OptimizationOptimisers, IterTools
-t = range(0.0f0, 1.0f0, length = 1024)
+t = range(0.0f0, 1.0f0, length=1024)
π_32 = Float32(π)
q_t = reshape(sin.(2π_32 * t), 1, :)
p_t = reshape(cos.(2π_32 * t), 1, :)
dqdt = 2π_32 .* p_t
dpdt = -2π_32 .* q_t
-data = cat(q_t, p_t, dims = 1)
-target = cat(dqdt, dpdt, dims = 1)
-dataloader = Flux.Data.DataLoader((data, target); batchsize=256, shuffle=true)
+data = vcat(q_t, p_t)
+target = vcat(dqdt, dpdt)
+B = 256
+NEPOCHS = 500
+dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
+                     selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
+                    for i in 1:(size(data, 2) ÷ B)), NEPOCHS)
-hnn = HamiltonianNN(
-    Flux.Chain(Flux.Dense(2, 64, relu), Flux.Dense(64, 1))
-)
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps_c = ps |> ComponentArray
-p = hnn.p
-opt = ADAM(0.01f0)
+opt = ADAM(0.01)
+function loss_function(ps, data, target)
+    pred, st_ = hnn(data, ps, st)
+    return mean(abs2, pred .- target), pred
+end
-loss(x, y, p) = mean((hnn(x, p) .- y) .^ 2)
+opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
+                                Optimization.AutoForwardDiff())
+opt_prob = OptimizationProblem(opt_func, ps_c)
-callback() = println("Loss Neural Hamiltonian DE = $(loss(data, target, p))")
+res = Optimization.solve(opt_prob, opt, dataloader)
-epochs = 500
-for epoch in 1:epochs
-    for (x, y) in dataloader
-        gs = ReverseDiff.gradient(p -> loss(x, y, p), p)
-        Flux.Optimise.update!(opt, p, gs)
-    end
-    if epoch % 100 == 1
-        callback()
-    end
-end
-callback()
+ps_trained = res.u
-model = NeuralHamiltonianDE(
-    hnn, (0.0f0, 1.0f0),
-    Tsit5(), save_everystep = false,
-    save_start = true, saveat = t
-)
+model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false,
+                            save_start=true, saveat=t)
-pred = Array(model(data[:, 1]))
+pred = Array(first(model(data[:, 1], ps_trained, st)))
plot(data[1, :], data[2, :], lw=4, label="Original")
plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
xlabel!("Position (q)")
@@ -77,51 +74,52 @@ dpdt = -2π_32 .* q_t
data = cat(q_t, p_t, dims = 1)
target = cat(dqdt, dpdt, dims = 1)
-dataloader = Flux.Data.DataLoader((data, target); batchsize=256, shuffle=true)
+B = 256
+NEPOCHS = 500
+dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
+                     selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
+                    for i in 1:(size(data, 2) ÷ B)), NEPOCHS)
```
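For reference, the targets generated above are just Hamilton's equations evaluated along the sampled trajectory. Writing the oscillator's Hamiltonian for general ``m`` and ``k`` (a quick derivation sketch, not part of the diff):

```math
H(q, p) = \frac{p^2}{2m} + \frac{kq^2}{2}, \qquad
\dot{q} = \frac{\partial H}{\partial p} = \frac{p}{m}, \qquad
\dot{p} = -\frac{\partial H}{\partial q} = -kq
```

Since the dataset is sampled at frequency ``2\pi`` (``q = \sin(2\pi t)``, ``p = \cos(2\pi t)``), the trajectory satisfies ``\dot{q} = 2\pi p`` and ``\dot{p} = -2\pi q``, which is exactly what `dqdt` and `dpdt` encode.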

### Training the HamiltonianNN

-We parameterize the HamiltonianNN with a small MultiLayered Perceptron (HNN also works with the Fast* Layers provided in DiffEqFlux). HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ReverseDiff in the training loop to compute the gradients of the HNN Layer for Optimization.
+We parameterize the HamiltonianNN with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.

```@example hamiltonian
-hnn = HamiltonianNN(
-    Flux.Chain(Flux.Dense(2, 64, relu), Flux.Dense(64, 1))
-)
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps_c = ps |> ComponentArray
-p = hnn.p
-opt = ADAM(0.01f0)
+opt = ADAM(0.01)
+function loss_function(ps, data, target)
+    pred, st_ = hnn(data, ps, st)
+    return mean(abs2, pred .- target), pred
+end
-loss(x, y, p) = mean((hnn(x, p) .- y) .^ 2)
+function callback(ps, loss, pred)
+    println("Loss: ", loss)
+    return false
+end
-callback() = println("Loss Neural Hamiltonian DE = $(loss(data, target, p))")
+opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
+                                Optimization.AutoForwardDiff())
+opt_prob = OptimizationProblem(opt_func, ps_c)
-epochs = 500
-for epoch in 1:epochs
-    for (x, y) in dataloader
-        gs = ReverseDiff.gradient(p -> loss(x, y, p), p)
-        Flux.Optimise.update!(opt, p, gs)
-    end
-    if epoch % 100 == 1
-        callback()
-    end
-end
-callback()
+res = solve(opt_prob, opt, dataloader; callback)
+ps_trained = res.u
```

### Solving the ODE using trained HNN

In order to visualize the learned trajectories, we need to solve the ODE. We will use the `NeuralHamiltonianDE` layer, which is essentially a wrapper over `HamiltonianNN` layer, and solves the ODE.

```@example hamiltonian
-model = NeuralHamiltonianDE(
-    hnn, (0.0f0, 1.0f0),
-    Tsit5(), save_everystep = false,
-    save_start = true, saveat = t
-)
+model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false,
+                            save_start=true, saveat=t)
-pred = Array(model(data[:, 1]))
+pred = Array(first(model(data[:, 1], ps_trained, st)))
plot(data[1, :], data[2, :], lw=4, label="Original")
plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
xlabel!("Position (q)")
@@ -133,12 +131,15 @@ ylabel!("Momentum (p)")
## Expected Output

```julia
-Loss Neural Hamiltonian DE = 18.768814
-Loss Neural Hamiltonian DE = 0.022630047
-Loss Neural Hamiltonian DE = 0.015060622
-Loss Neural Hamiltonian DE = 0.013170851
-Loss Neural Hamiltonian DE = 0.011898238
-Loss Neural Hamiltonian DE = 0.009806873
+Loss: 19.865715
+Loss: 18.196068
+Loss: 19.179213
+Loss: 19.58956
+Loss: 0.02267044
+Loss: 0.019175647
+Loss: 0.02218909
+Loss: 0.018870523
```
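Because the new training pipeline is scattered across several hunks above, here is the migrated workflow condensed into a single block. This is a sketch assembled from this commit's added lines, so it assumes the DiffEqFlux/Lux API exactly as of this PR (`HamiltonianNN` wrapping a `Lux.Chain`, `Optimization.solve` consuming the `ncycle` iterator); later releases may rename or change these pieces.

```julia
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Zygote, ForwardDiff, Random,
      ComponentArrays, Optimization, OptimizationOptimisers, IterTools

# Data: q = sin(2πt), p = cos(2πt) sampled on [0, 1]
t = range(0.0f0, 1.0f0, length=1024)
π_32 = Float32(π)
data = vcat(reshape(sin.(2π_32 * t), 1, :), reshape(cos.(2π_32 * t), 1, :))
target = vcat(2π_32 .* data[2:2, :], -2π_32 .* data[1:1, :])

# Minibatch iterator: IterTools.ncycle replaces the old Flux.Data.DataLoader
B, NEPOCHS = 256, 500
dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
                      selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
                     for i in 1:(size(data, 2) ÷ B)), NEPOCHS)

# Lux keeps parameters explicit: `ps` is trained, `st` is threaded through calls
hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1)))
ps, st = Lux.setup(Random.default_rng(), hnn)

loss_function(ps, data, target) = mean(abs2, first(hnn(data, ps, st)) .- target)
opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
                                Optimization.AutoForwardDiff())
opt_prob = OptimizationProblem(opt_func, ps |> ComponentArray)
res = Optimization.solve(opt_prob, ADAM(0.01), dataloader)

# Roll the trained HNN out in time from the first data point
model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false,
                            save_start=true, saveat=t)
pred = Array(first(model(data[:, 1], res.u, st)))
```

The substantive change is the move from Flux's implicit parameters (`hnn.p`, `ReverseDiff.gradient`, manual `Flux.Optimise.update!`) to Lux's explicit `(ps, st)` pair optimized through `Optimization.AutoForwardDiff()`, since Zygote cannot differentiate its own input-gradient call.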

## References
91 changes: 59 additions & 32 deletions src/hnn.jl
@@ -10,12 +10,11 @@ particles. It then returns the time derivatives for position and momentum.
for such applications.
!!! note
-    This layer currently doesn't support GPU. The support will be added in future
-    with some AD fixes.
+    To compute the gradients for this layer, it is recommended to use ForwardDiff.jl
-To obtain the gradients to train this network, ReverseDiff.gradient is supposed to
-be used. This prevents the usage of `DiffEqFlux.sciml_train` or `Flux.train`. Follow
-this [tutorial](https://docs.sciml.ai/DiffEqFlux/stable/examples/hamiltonian_nn/) to see how
+To obtain the gradients to train this network, ForwardDiff.gradient is supposed to
+be used. Follow this
+[tutorial](https://docs.sciml.ai/DiffEqFlux/stable/examples/hamiltonian_nn/) to see how
to define a training loop to circumvent this issue.
```julia
@@ -32,28 +31,42 @@ References:
[1] Greydanus, Samuel, Misko Dzamba, and Jason Yosinski. "Hamiltonian Neural Networks." Advances in Neural Information Processing Systems 32 (2019): 15379-15389.
"""
-struct HamiltonianNN{M, R, P}
+struct HamiltonianNN{M,R,P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)}
    model::M
    re::R
    p::P
end

-function HamiltonianNN(model; p = nothing)
-    _p, re = Flux.destructure(model)
-    if p === nothing
-        p = _p
-    end
-    return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
-end
+function HamiltonianNN(model; p=nothing)
+    _p, re = Flux.destructure(model)
+    p === nothing && (p = _p)
+    return HamiltonianNN{typeof(model),typeof(re),typeof(p)}(model, re, p)
+end
+
+function HamiltonianNN(model::LuxCore.AbstractExplicitLayer; p=nothing)
+    @assert p === nothing
+    return HamiltonianNN{typeof(model),Nothing,Nothing}(model, nothing, nothing)
+end

Flux.trainable(hnn::HamiltonianNN) = (hnn.p,)

function _hamiltonian_forward(re, p, x)
-    H = Flux.gradient(x -> sum(re(p)(x)), x)[1]
+    H = only(Zygote.gradient(x -> sum(re(p)(x)), x))
    n = size(x, 1) ÷ 2
+    return vcat(selectdim(H, 1, (n+1):2n), -selectdim(H, 1, 1:n))
+end

+(hnn::HamiltonianNN)(x, p=hnn.p) = _hamiltonian_forward(hnn.re, p, x)

+function (hnn::HamiltonianNN{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
+    (_, st), pb_f = Zygote.pullback(x) do x
+        y, st_ = hnn.model(x, ps, st)
+        return sum(y), st_
+    end
+    H = only(pb_f((one(eltype(x)), nothing)))
+    n = size(x, 1) ÷ 2
-    return cat(H[(n + 1):2n, :], -H[1:n, :], dims=1)
+    return vcat(selectdim(H, 1, (n+1):2n), -selectdim(H, 1, 1:n)), st
end
-(hnn::HamiltonianNN)(x, p = hnn.p) = _hamiltonian_forward(hnn.re, p, x)

"""
Constructs a Neural Hamiltonian DE Layer for solving Hamiltonian Problems
@@ -78,27 +91,41 @@ struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer
    tspan::T
    args::A
    kwargs::K
end

-function NeuralHamiltonianDE(model, tspan, args...; p = nothing, kwargs...)
-    hnn = HamiltonianNN(model, p=p)
-    new{typeof(hnn.model), typeof(hnn.p), typeof(hnn.re),
-        typeof(tspan), typeof(args), typeof(kwargs)}(
-        hnn, hnn.p, tspan, args, kwargs)
-end
+# TODO: Make sensealg an argument
+function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...)
+    hnn = HamiltonianNN(model, p=p)
+    return NeuralHamiltonianDE{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re),
+        typeof(tspan),typeof(args),typeof(kwargs)}(
+        hnn, hnn.p, tspan, args, kwargs)
+end

-function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
-                             p = hnn.p, kwargs...) where {M,RE,P}
-    new{M, P, RE, typeof(tspan), typeof(args),
-        typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
-end
+function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
+                             p=hnn.p, kwargs...) where {M,RE,P}
+    return NeuralHamiltonianDE{M,P,RE,typeof(tspan),typeof(args),
+                               typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
+end

-function (nhde::NeuralHamiltonianDE)(x, p = nhde.p)
+function (nhde::NeuralHamiltonianDE)(x, p=nhde.p)
    function neural_hamiltonian!(du, u, p, t)
        du .= reshape(nhde.model(u, p), size(du))
    end
-    prob = ODEProblem(neural_hamiltonian!, x, nhde.tspan, p)
-    # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP
-    sense = InterpolatingAdjoint(autojacvec = false)
-    solve(prob, nhde.args...; sensealg = sense, nhde.kwargs...)
+    prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, p)
+    # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
+    # ForwardDiff.jl internally.
+    sensealg = InterpolatingAdjoint(; autojacvec=true)
+    return solve(prob, nhde.args...; sensealg, nhde.kwargs...)
end

+function (nhde::NeuralHamiltonianDE{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
+    function neural_hamiltonian!(du, u, p, t)
+        y, st = nhde.model(u, p, st)
+        du .= reshape(y, size(du))
+    end
+    prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, ps)
+    # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
+    # ForwardDiff.jl internally.
+    sensealg = InterpolatingAdjoint(; autojacvec=true)
+    return solve(prob, nhde.args...; sensealg, nhde.kwargs...), st
+end
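Both forward-pass methods above implement the same trick: take the gradient of the scalar network output with respect to the input, then swap the position/momentum blocks and negate one, so that (dq/dt, dp/dt) = (∂H/∂p, −∂H/∂q). A minimal standalone check of that rearrangement, with an analytic Hamiltonian standing in for the network (an editorial sketch, not part of the commit):

```julia
using Zygote

# Analytic stand-in for sum(model(x)): H(q, p) = (q² + p²) / 2
H(x) = sum(abs2, x) / 2

x = Float32[0.5; 0.25;;]  # one sample, rows stacked as [q; p]

# Same steps as `_hamiltonian_forward`: input gradient, then symplectic swap
∇H = only(Zygote.gradient(x -> sum(H(x)), x))
n = size(x, 1) ÷ 2
dxdt = vcat(selectdim(∇H, 1, (n+1):2n), -selectdim(∇H, 1, 1:n))

dxdt ≈ vcat(x[2:2, :], -x[1:1, :])  # dq/dt = p, dp/dt = -q  =>  true
```

For this Hamiltonian the input gradient is just `x` itself, so the swap returns (p, −q), matching Hamilton's equations for the unit-frequency oscillator.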
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -10,13 +10,15 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"