From 05e850afb3e828486737efff7c6d19b86960e2e6 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 23 May 2023 11:51:56 -0400
Subject: [PATCH] Fix the tests to use Lux and ForwardDiff

---
 docs/src/examples/hamiltonian_nn.md |  4 +-
 src/hnn.jl                          | 33 ++++++++--------
 test/Project.toml                   |  2 +
 test/hamiltonian_nn.jl              | 59 +++++++++++++++--------
 4 files changed, 52 insertions(+), 46 deletions(-)

diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md
index 85f3466aa..67fb1abc8 100644
--- a/docs/src/examples/hamiltonian_nn.md
+++ b/docs/src/examples/hamiltonian_nn.md
@@ -48,7 +48,7 @@ model = NeuralHamiltonianDE(
     save_start = true,
     saveat = t
 )
-pred = Array(model(data[:, 1], ps_c, st))
+pred = Array(first(model(data[:, 1], ps_c, st)))
 plot(data[1, :], data[2, :], lw=4, label="Original")
 plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
 xlabel!("Position (q)")
@@ -112,7 +112,7 @@ model = NeuralHamiltonianDE(
     save_start = true,
     saveat = t
 )
-pred = Array(model(data[:, 1], ps_c, st))
+pred = Array(first(model(data[:, 1], ps_c, st)))
 plot(data[1, :], data[2, :], lw=4, label="Original")
 plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
 xlabel!("Position (q)")
diff --git a/src/hnn.jl b/src/hnn.jl
index 9d01f388a..1887393bc 100644
--- a/src/hnn.jl
+++ b/src/hnn.jl
@@ -91,19 +91,20 @@ struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer
     tspan::T
     args::A
     kwargs::K
+end
 
-    function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...)
-        hnn = HamiltonianNN(model, p=p)
-        new{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re),
-            typeof(tspan),typeof(args),typeof(kwargs)}(
-            hnn, hnn.p, tspan, args, kwargs)
-    end
+# TODO: Make sensealg an argument
+function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...)
+    hnn = HamiltonianNN(model, p=p)
+    return NeuralHamiltonianDE{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re),
+        typeof(tspan),typeof(args),typeof(kwargs)}(
+        hnn, hnn.p, tspan, args, kwargs)
+end
 
-    function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
-                                 p=hnn.p, kwargs...) where {M,RE,P}
-        new{M,P,RE,typeof(tspan),typeof(args),
-            typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
-    end
+function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
+        p=hnn.p, kwargs...) where {M,RE,P}
+    return NeuralHamiltonianDE{M,P,RE,typeof(tspan),typeof(args),
+        typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
 end
 
 function (nhde::NeuralHamiltonianDE)(x, p=nhde.p)
@@ -113,18 +114,18 @@ function (nhde::NeuralHamiltonianDE)(x, p=nhde.p)
     prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, p)
     # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
     # ForwardDiff.jl internally.
-    sense = InterpolatingAdjoint(autojacvec=true)
-    return solve(prob, nhde.args...; sensealg=sense, nhde.kwargs...)
+    sensealg = InterpolatingAdjoint(; autojacvec=true)
+    return solve(prob, nhde.args...; sensealg, nhde.kwargs...)
 end
 
 function (nhde::NeuralHamiltonianDE{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
     function neural_hamiltonian!(du, u, p, t)
-        y, st = nhde.model(u, ps, st)
+        y, st = nhde.model(u, p, st)
         du .= reshape(y, size(du))
     end
     prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, ps)
     # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
     # ForwardDiff.jl internally.
-    sense = InterpolatingAdjoint(autojacvec=true)
-    return solve(prob, nhde.args...; sensealg=sense, nhde.kwargs...)
+    sensealg = InterpolatingAdjoint(; autojacvec=true)
+    return solve(prob, nhde.args...; sensealg, nhde.kwargs...), st
 end
diff --git a/test/Project.toml b/test/Project.toml
index ebee382a1..1b077bb6c 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -10,6 +10,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -17,6 +18,7 @@ MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
diff --git a/test/hamiltonian_nn.jl b/test/hamiltonian_nn.jl
index f5436950d..4c35c0813 100644
--- a/test/hamiltonian_nn.jl
+++ b/test/hamiltonian_nn.jl
@@ -1,63 +1,66 @@
-using DiffEqFlux, Zygote, OrdinaryDiffEq, ReverseDiff, Test
+using DiffEqFlux, Zygote, OrdinaryDiffEq, ForwardDiff, Test, Optimisers, Random, Lux, ComponentArrays
 
 # Checks for Shapes and Non-Zero Gradients
 u0 = rand(Float32, 6, 1)
-hnn = HamiltonianNN(Flux.Chain(Flux.Dense(6, 12, relu), Flux.Dense(12, 1)))
-p = hnn.p
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps = ps |> ComponentArray
 
-@test size(hnn(u0)) == (6, 1)
+@test size(first(hnn(u0, ps, st))) == (6, 1)
 
-@test ! iszero(ReverseDiff.gradient(p -> sum(hnn(u0, p)), p))
+@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))
 
-hnn = HamiltonianNN(Flux.Chain(Flux.Dense(6, 12, relu), Flux.Dense(12, 1)))
-p = hnn.p
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps = ps |> ComponentArray
 
-@test size(hnn(u0)) == (6, 1)
+@test size(first(hnn(u0, ps, st))) == (6, 1)
 
-@test ! iszero(ReverseDiff.gradient(p -> sum(hnn(u0, p)), p))
+@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))
 
 # Test Convergence on a toy problem
-t = range(0.0f0, 1.0f0, length = 64)
+t = range(0.0f0, 1.0f0, length=64)
 π_32 = Float32(π)
 q_t = reshape(sin.(2π_32 * t), 1, :)
 p_t = reshape(cos.(2π_32 * t), 1, :)
 dqdt = 2π_32 .* p_t
 dpdt = -2π_32 .* q_t
 
-data = cat(q_t, p_t, dims = 1)
-target = cat(dqdt, dpdt, dims = 1)
+data = vcat(q_t, p_t)
+target = vcat(dqdt, dpdt)
 
-hnn = HamiltonianNN(Flux.Chain(Flux.Dense(2, 16, relu), Flux.Dense(16, 1)))
-p = hnn.p
+hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 16, relu), Lux.Dense(16, 1)))
+ps, st = Lux.setup(Random.default_rng(), hnn)
+ps = ps |> ComponentArray
 
 opt = ADAM(0.01)
-loss(x, y, p) = sum((hnn(x, p) .- y) .^ 2)
+st_opt = Optimisers.setup(opt, ps)
+loss(data, target, ps) = mean(abs2, first(hnn(data, ps, st)) .- target)
 
-initial_loss = loss(data, target, p)
+initial_loss = loss(data, target, ps)
 
-epochs = 100
-for epoch in 1:epochs
-    gs = ReverseDiff.gradient(p -> loss(data, target, p), p)
-    Flux.Optimise.update!(opt, p, gs)
+for epoch in 1:100
+    # Forward Mode over Reverse Mode for Training
+    gs = ForwardDiff.gradient(ps -> loss(data, target, ps), ps)
+    st_opt, ps = Optimisers.update!(st_opt, ps, gs)
 end
 
-final_loss = loss(data, target, p)
+final_loss = loss(data, target, ps)
 
-@test initial_loss > final_loss
+@test initial_loss > 5 * final_loss
 
 # Test output and gradient of NeuralHamiltonianDE Layer
 tspan = (0.0f0, 1.0f0)
 
 model = NeuralHamiltonianDE(
     hnn, tspan, Tsit5(),
-    save_everystep = false, save_start = true,
-    saveat = range(tspan[1], tspan[2], length=10)
+    save_everystep=false, save_start=true,
+    saveat=range(tspan[1], tspan[2], length=10)
 )
 
-sol = Array(model(data[:, 1]))
+sol = Array(first(model(data[:, 1], ps, st)))
 @test size(sol) == (2, 10)
 
-ps = Flux.params(model)
-gs = Flux.gradient(() -> sum(Array(model(data[:, 1]))), ps)
+gs = only(Zygote.gradient(ps -> sum(Array(first(model(data[:, 1], ps, st)))), ps))
 
-@test ! iszero(gs[model.p])
+@test !iszero(gs)
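
Note for reviewers: a minimal sketch of how the explicit-parameter API is driven after this patch. It mirrors the new test file rather than adding anything beyond it; the names x, y, and g are illustrative only and do not appear in the patch.

using DiffEqFlux, Lux, Random, ComponentArrays, ForwardDiff

# Construct the Hamiltonian NN from Lux layers; parameters and state are explicit.
hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 16, relu), Lux.Dense(16, 1)))
ps, st = Lux.setup(Random.default_rng(), hnn)
ps = ps |> ComponentArray   # flatten so ForwardDiff can differentiate w.r.t. ps

x = rand(Float32, 2, 1)     # illustrative (q, p) input
y, _ = hnn(x, ps, st)       # the layer now returns an (output, state) tuple

# Nested Zygote is not supported here, so gradients w.r.t. ps use ForwardDiff.
g = ForwardDiff.gradient(p -> sum(first(hnn(x, p, st))), ps)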