diff --git a/docs/Project.toml b/docs/Project.toml
index 6bb032c02..296f45e14 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -9,16 +9,21 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
 OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
@@ -28,20 +33,24 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
 [compat]
 CSV = "0.10"
 ComponentArrays = "0.13"
-DataFrames = "1"
 DataDeps = "0.7"
+DataFrames = "1"
 DiffEqFlux = "2"
 DifferentialEquations = "7.6.0"
 Distances = "0.10.7"
 Distributions = "0.25.78"
 Documenter = "0.27"
 Flux = "0.13.7"
+ForwardDiff = "0.10"
+IterTools = "1"
 Lux = "0.4.34"
 MLDataUtils = "0.5"
 MLDatasets = "0.7"
+Optimisers = "0.2"
 Optimization = "3.9"
 OptimizationFlux = "0.1"
 OptimizationOptimJL = "0.1"
+OptimizationOptimisers = "0.1"
 OptimizationPolyalgorithms = "0.1"
 OrdinaryDiffEq = "6.31"
 Plots = "1.36"
diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md
index 05999e4e0..d49a6a370 100644
--- a/docs/src/examples/hamiltonian_nn.md
+++ b/docs/src/examples/hamiltonian_nn.md
@@ -9,50 +9,47 @@ m\ddot x + kx = 0
 Now we make some simplifying assumptions, and assign ``m = 1`` and ``k = 1``. Analytically solving this equation, we get ``x = sin(t)``. Hence, ``q = sin(t)``, and ``p = cos(t)``. Using these solutions, we generate our dataset and fit the `NeuralHamiltonianDE` to learn the dynamics of this system.

 ```@example hamiltonian_cp
-using Flux, DiffEqFlux, DifferentialEquations, Statistics, Plots, ReverseDiff
+using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
+    ComponentArrays, Optimization, OptimizationOptimisers, IterTools

-t = range(0.0f0, 1.0f0, length = 1024)
+t = range(0.0f0, 1.0f0, length=1024)
 π_32 = Float32(π)
 q_t = reshape(sin.(2π_32 * t), 1, :)
 p_t = reshape(cos.(2π_32 * t), 1, :)
 dqdt = 2π_32 .* p_t
 dpdt = -2π_32 .* q_t

-data = cat(q_t, p_t, dims = 1)
-target = cat(dqdt, dpdt, dims = 1)
-dataloader = Flux.Data.DataLoader((data, target); batchsize=256, shuffle=true)
+data = vcat(q_t, p_t)
+target = vcat(dqdt, dpdt)
+B = 256
+NEPOCHS = 500
+dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
+    selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
+    for i in 1:(size(data, 2) ÷ B)), NEPOCHS)

-hnn = HamiltonianNN(
-    Flux.Chain(Flux.Dense(2, 64, relu), Flux.Dense(64, 1))
-)
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps_c = ps |> ComponentArray

-p = hnn.p
+opt = ADAM(0.01f0)

-opt = ADAM(0.01)
+function loss_function(ps, data, target)
+    pred, st_ = hnn(data, ps, st)
+    return mean(abs2, pred .- target), pred
+end

-loss(x, y, p) = mean((hnn(x, p) .- y) .^ 2)
+opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
+    Optimization.AutoForwardDiff())
+opt_prob = OptimizationProblem(opt_func, ps_c)

-callback() = println("Loss Neural Hamiltonian DE = $(loss(data, target, p))")
+res = Optimization.solve(opt_prob, opt, dataloader)

-epochs = 500
-for epoch in 1:epochs
-    for (x, y) in dataloader
-        gs = ReverseDiff.gradient(p -> loss(x, y, p), p)
-        Flux.Optimise.update!(opt, p, gs)
-    end
-    if epoch % 100 == 1
-        callback()
-    end
-end
-callback()
+ps_trained = res.u

-model = NeuralHamiltonianDE(
-    hnn, (0.0f0, 1.0f0),
-    Tsit5(), save_everystep = false,
-    save_start = true, saveat = t
-)
+model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false,
+    save_start=true, saveat=t)

-pred = Array(model(data[:, 1]))
+pred = Array(first(model(data[:, 1], ps_trained, st)))
 plot(data[1, :], data[2, :], lw=4, label="Original")
 plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
 xlabel!("Position (q)")
@@ -77,37 +74,41 @@ dpdt = -2π_32 .* q_t
 data = cat(q_t, p_t, dims = 1)
 target = cat(dqdt, dpdt, dims = 1)
-dataloader = Flux.Data.DataLoader((data, target); batchsize=256, shuffle=true)
+B = 256
+NEPOCHS = 500
+dataloader = ncycle(((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
+    selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
+    for i in 1:(size(data, 2) ÷ B)), NEPOCHS)
 ```

 ### Training the HamiltonianNN

-We parameterize the HamiltonianNN with a small MultiLayered Perceptron (HNN also works with the Fast* Layers provided in DiffEqFlux). HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ReverseDiff in the training loop to compute the gradients of the HNN Layer for Optimization.
+We parameterize the HamiltonianNN with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.

 ```@example hamiltonian
-hnn = HamiltonianNN(
-    Flux.Chain(Flux.Dense(2, 64, relu), Flux.Dense(64, 1))
-)
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 64, relu), Lux.Dense(64, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps_c = ps |> ComponentArray

-p = hnn.p
+opt = ADAM(0.01f0)

-opt = ADAM(0.01)
+function loss_function(ps, data, target)
+    pred, st_ = hnn(data, ps, st)
+    return mean(abs2, pred .- target), pred
+end

-loss(x, y, p) = mean((hnn(x, p) .- y) .^ 2)
+function callback(ps, loss, pred)
+    println("Loss: ", loss)
+    return false
+end

-callback() = println("Loss Neural Hamiltonian DE = $(loss(data, target, p))")
+opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
+    Optimization.AutoForwardDiff())
+opt_prob = OptimizationProblem(opt_func, ps_c)

-epochs = 500
-for epoch in 1:epochs
-    for (x, y) in dataloader
-        gs = ReverseDiff.gradient(p -> loss(x, y, p), p)
-        Flux.Optimise.update!(opt, p, gs)
-    end
-    if epoch % 100 == 1
-        callback()
-    end
-end
-callback()
+res = solve(opt_prob, opt, dataloader; callback)
+
+ps_trained = res.u
 ```

 ### Solving the ODE using trained HNN
@@ -115,13 +116,10 @@ callback()
 In order to visualize the learned trajectories, we need to solve the ODE. We will use the `NeuralHamiltonianDE` layer, which is essentially a wrapper over `HamiltonianNN` layer, and solves the ODE.

 ```@example hamiltonian
-model = NeuralHamiltonianDE(
-    hnn, (0.0f0, 1.0f0),
-    Tsit5(), save_everystep = false,
-    save_start = true, saveat = t
-)
+model = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false,
+    save_start=true, saveat=t)

-pred = Array(model(data[:, 1]))
+pred = Array(first(model(data[:, 1], ps_trained, st)))
 plot(data[1, :], data[2, :], lw=4, label="Original")
 plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
 xlabel!("Position (q)")
@@ -133,12 +131,15 @@ ylabel!("Momentum (p)")
 ## Expected Output

 ```julia
-Loss Neural Hamiltonian DE = 18.768814
-Loss Neural Hamiltonian DE = 0.022630047
-Loss Neural Hamiltonian DE = 0.015060622
-Loss Neural Hamiltonian DE = 0.013170851
-Loss Neural Hamiltonian DE = 0.011898238
-Loss Neural Hamiltonian DE = 0.009806873
+Loss: 19.865715
+Loss: 18.196068
+Loss: 19.179213
+Loss: 19.58956
+⋮
+Loss: 0.02267044
+Loss: 0.019175647
+Loss: 0.02218909
+Loss: 0.018870523
 ```

 ## References
diff --git a/src/hnn.jl b/src/hnn.jl
index 71b8ed0cd..1887393bc 100644
--- a/src/hnn.jl
+++ b/src/hnn.jl
@@ -10,12 +10,11 @@ particles. It then returns the time derivatives for position and momentum.
 for such applications.

 !!! note
-    This layer currently doesn't support GPU. The support will be added in future
-    with some AD fixes.
+    To compute the gradients for this layer, it is recommended to use ForwardDiff.jl

-To obtain the gradients to train this network, ReverseDiff.gradient is supposed to
-be used. This prevents the usage of `DiffEqFlux.sciml_train` or `Flux.train`. Follow
-this [tutorial](https://docs.sciml.ai/DiffEqFlux/stable/examples/hamiltonian_nn/) to see how
+To obtain the gradients to train this network, ForwardDiff.gradient is supposed to
+be used. Follow this
+[tutorial](https://docs.sciml.ai/DiffEqFlux/stable/examples/hamiltonian_nn/) to see how
 to define a training loop to circumvent this issue.

 ```julia
@@ -32,28 +31,42 @@ References:
 [1] Greydanus, Samuel, Misko Dzamba, and Jason Yosinski. "Hamiltonian Neural Networks." Advances in Neural Information Processing Systems 32 (2019): 15379-15389.
 """
-struct HamiltonianNN{M, R, P}
+struct HamiltonianNN{M,R,P} <: LuxCore.AbstractExplicitContainerLayer{(:model,)}
     model::M
     re::R
     p::P
+end

-    function HamiltonianNN(model; p = nothing)
-        _p, re = Flux.destructure(model)
-        if p === nothing
-            p = _p
-        end
-        return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
-    end
+function HamiltonianNN(model; p=nothing)
+    _p, re = Flux.destructure(model)
+    p === nothing && (p = _p)
+    return HamiltonianNN{typeof(model),typeof(re),typeof(p)}(model, re, p)
+end
+
+function HamiltonianNN(model::LuxCore.AbstractExplicitLayer; p=nothing)
+    @assert p === nothing
+    return HamiltonianNN{typeof(model),Nothing,Nothing}(model, nothing, nothing)
 end

 Flux.trainable(hnn::HamiltonianNN) = (hnn.p,)

 function _hamiltonian_forward(re, p, x)
-    H = Flux.gradient(x -> sum(re(p)(x)), x)[1]
+    H = only(Zygote.gradient(x -> sum(re(p)(x)), x))
+    n = size(x, 1) ÷ 2
+    return vcat(selectdim(H, 1, (n+1):2n), -selectdim(H, 1, 1:n))
+end
+
+(hnn::HamiltonianNN)(x, p=hnn.p) = _hamiltonian_forward(hnn.re, p, x)
+
+function (hnn::HamiltonianNN{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
+    (_, st), pb_f = Zygote.pullback(x) do x
+        y, st_ = hnn.model(x, ps, st)
+        return sum(y), st_
+    end
+    H = only(pb_f((one(eltype(x)), nothing)))
     n = size(x, 1) ÷ 2
-    return cat(H[(n + 1):2n, :], -H[1:n, :], dims=1)
+    return vcat(selectdim(H, 1, (n+1):2n), -selectdim(H, 1, 1:n)), st
 end

-(hnn::HamiltonianNN)(x, p = hnn.p) = _hamiltonian_forward(hnn.re, p, x)

 """
 Contructs a Neural Hamiltonian DE Layer for solving Hamiltonian Problems
@@ -78,27 +91,41 @@ struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer
     tspan::T
     args::A
     kwargs::K
+end

-    function NeuralHamiltonianDE(model, tspan, args...; p = nothing, kwargs...)
-        hnn = HamiltonianNN(model, p=p)
-        new{typeof(hnn.model), typeof(hnn.p), typeof(hnn.re),
-            typeof(tspan), typeof(args), typeof(kwargs)}(
-            hnn, hnn.p, tspan, args, kwargs)
-    end
+# TODO: Make sensealg an argument
+function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...)
+    hnn = HamiltonianNN(model, p=p)
+    return NeuralHamiltonianDE{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re),
+        typeof(tspan),typeof(args),typeof(kwargs)}(
+        hnn, hnn.p, tspan, args, kwargs)
+end

-    function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
-                                 p = hnn.p, kwargs...) where {M,RE,P}
-        new{M, P, RE, typeof(tspan), typeof(args),
-            typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
-    end
+function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
+                             p=hnn.p, kwargs...) where {M,RE,P}
+    return NeuralHamiltonianDE{M,P,RE,typeof(tspan),typeof(args),
+        typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
 end

-function (nhde::NeuralHamiltonianDE)(x, p = nhde.p)
+function (nhde::NeuralHamiltonianDE)(x, p=nhde.p)
     function neural_hamiltonian!(du, u, p, t)
         du .= reshape(nhde.model(u, p), size(du))
     end
-    prob = ODEProblem(neural_hamiltonian!, x, nhde.tspan, p)
-    # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP
-    sense = InterpolatingAdjoint(autojacvec = false)
-    solve(prob, nhde.args...; sensealg = sense, nhde.kwargs...)
+    prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, p)
+    # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
+    # ForwardDiff.jl internally.
+    sensealg = InterpolatingAdjoint(; autojacvec=true)
+    return solve(prob, nhde.args...; sensealg, nhde.kwargs...)
+end
+
+function (nhde::NeuralHamiltonianDE{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
+    function neural_hamiltonian!(du, u, p, t)
+        y, st = nhde.model(u, p, st)
+        du .= reshape(y, size(du))
+    end
+    prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, ps)
+    # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
+    # ForwardDiff.jl internally.
+    sensealg = InterpolatingAdjoint(; autojacvec=true)
+    return solve(prob, nhde.args...; sensealg, nhde.kwargs...), st
 end
diff --git a/test/Project.toml b/test/Project.toml
index ebee382a1..1b077bb6c 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -10,6 +10,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -17,6 +18,7 @@ MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
diff --git a/test/hamiltonian_nn.jl b/test/hamiltonian_nn.jl
index f5436950d..0164efb4f 100644
--- a/test/hamiltonian_nn.jl
+++ b/test/hamiltonian_nn.jl
@@ -1,63 +1,67 @@
-using DiffEqFlux, Zygote, OrdinaryDiffEq, ReverseDiff, Test
+using DiffEqFlux, Zygote, OrdinaryDiffEq, ForwardDiff, Test, Optimisers, Random, Lux, ComponentArrays, Statistics

 # Checks for Shapes and Non-Zero Gradients
 u0 = rand(Float32, 6, 1)

-hnn = HamiltonianNN(Flux.Chain(Flux.Dense(6, 12, relu), Flux.Dense(12, 1)))
-p = hnn.p
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps = ps |> ComponentArray

-@test size(hnn(u0)) == (6, 1)
+@test size(first(hnn(u0, ps, st))) == (6, 1)

-@test ! iszero(ReverseDiff.gradient(p -> sum(hnn(u0, p)), p))
+@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))

-hnn = HamiltonianNN(Flux.Chain(Flux.Dense(6, 12, relu), Flux.Dense(12, 1)))
-p = hnn.p
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps = ps |> ComponentArray

-@test size(hnn(u0)) == (6, 1)
+@test size(first(hnn(u0, ps, st))) == (6, 1)

-@test ! iszero(ReverseDiff.gradient(p -> sum(hnn(u0, p)), p))
+@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))

 # Test Convergence on a toy problem
-t = range(0.0f0, 1.0f0, length = 64)
+t = range(0.0f0, 1.0f0, length=64)
 π_32 = Float32(π)
 q_t = reshape(sin.(2π_32 * t), 1, :)
 p_t = reshape(cos.(2π_32 * t), 1, :)
 dqdt = 2π_32 .* p_t
 dpdt = -2π_32 .* q_t

-data = cat(q_t, p_t, dims = 1)
-target = cat(dqdt, dpdt, dims = 1)
+data = vcat(q_t, p_t)
+target = vcat(dqdt, dpdt)

-hnn = HamiltonianNN(Flux.Chain(Flux.Dense(2, 16, relu), Flux.Dense(16, 1)))
-p = hnn.p
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 16, relu), Lux.Dense(16, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps = ps |> ComponentArray

 opt = ADAM(0.01)
-loss(x, y, p) = sum((hnn(x, p) .- y) .^ 2)
+st_opt = Optimisers.setup(opt, ps)
+loss(data, target, ps) = mean(abs2, first(hnn(data, ps, st)) .- target)

-initial_loss = loss(data, target, p)
+initial_loss = loss(data, target, ps)

-epochs = 100
-for epoch in 1:epochs
-    gs = ReverseDiff.gradient(p -> loss(data, target, p), p)
-    Flux.Optimise.update!(opt, p, gs)
+for epoch in 1:100
+    global ps, st_opt
+    # Forward Mode over Reverse Mode for Training
+    gs = ForwardDiff.gradient(ps -> loss(data, target, ps), ps)
+    st_opt, ps = Optimisers.update!(st_opt, ps, gs)
 end

-final_loss = loss(data, target, p)
+final_loss = loss(data, target, ps)

-@test initial_loss > final_loss
+@test initial_loss > 5 * final_loss

 # Test output and gradient of NeuralHamiltonianDE Layer
 tspan = (0.0f0, 1.0f0)
 model = NeuralHamiltonianDE(
     hnn, tspan, Tsit5(),
-    save_everystep = false, save_start = true,
-    saveat = range(tspan[1], tspan[2], length=10)
+    save_everystep=false, save_start=true,
+    saveat=range(tspan[1], tspan[2], length=10)
 )

-sol = Array(model(data[:, 1]))
+sol = Array(first(model(data[:, 1], ps, st)))
 @test size(sol) == (2, 10)

-ps = Flux.params(model)
-gs = Flux.gradient(() -> sum(Array(model(data[:, 1]))), ps)
+gs = only(Zygote.gradient(ps -> sum(Array(first(model(data[:, 1], ps, st)))), ps))

-@test ! iszero(gs[model.p])
+@test !iszero(gs)