diff --git a/.gitignore b/.gitignore
index 289c574fd..a86c30781 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,5 @@ deps/deps.jl
 Manifest.toml
 docs/build
 *.DS_Store
-wip
\ No newline at end of file
+.vscode
+wip/
diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md
index c8d976cc3..61a6cfc68 100644
--- a/docs/src/examples/neural_ode.md
+++ b/docs/src/examples/neural_ode.md
@@ -12,7 +12,7 @@ Before getting to the explanation, here's some code to start with. We will
 follow a full explanation of the definition and training process:
 
 ```@example neuralode_cp
-using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
+using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, OptimizationFlux, Random, Plots
 
 rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
@@ -88,7 +88,7 @@ callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
 Let's get a time series array from a spiral ODE to train against.
 
 ```@example neuralode
-using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
+using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, OptimizationFlux, Random, Plots
 
 rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
diff --git a/test/mnist_gpu.jl b/test/mnist_gpu.jl
index 4823652ff..7febe6df7 100644
--- a/test/mnist_gpu.jl
+++ b/test/mnist_gpu.jl
@@ -1,45 +1,39 @@
-using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test
-using Flux.Losses: logitcrossentropy
+using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics,
+      ComponentArrays, Random, Optimization, OptimizationOptimisers
 using MLDatasets: MNIST
 using MLDataUtils: LabelEnc, convertlabel, stratifiedobs
 
 CUDA.allowscalar(false)
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 
-function loadmnist(batchsize = bs)
-    # Use MLDataUtils LabelEnc for natural onehot conversion
-    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
-    # Load MNIST
-    mnist = MNIST(split = :train)
+logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
+
+function loadmnist(batchsize=bs)
+    # Use MLDataUtils LabelEnc for natural onehot conversion
+    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
+    # Load MNIST
+    mnist = MNIST(split=:train)
     imgs, labels_raw = mnist.features, mnist.targets
-    # Process images into (H,W,C,BS) batches
-    x_train = Float32.(reshape(imgs,size(imgs,1),size(imgs,2),1,size(imgs,3))) |> Flux.gpu
-    x_train = batchview(x_train,batchsize)
-    # Onehot and batch the labels
-    y_train = onehot(labels_raw) |> Flux.gpu
-    y_train = batchview(y_train,batchsize)
-    return x_train, y_train
+    # Process images into (H,W,C,BS) batches
+    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> Lux.gpu
+    x_train = batchview(x_train, batchsize)
+    # Onehot and batch the labels
+    y_train = onehot(labels_raw) |> Lux.gpu
+    y_train = batchview(y_train, batchsize)
+    return x_train, y_train
 end
 
 # Main
 const bs = 128
 x_train, y_train = loadmnist(bs)
 
-down = Flux.Chain(x->reshape(x,(28*28,:)),
-                  Flux.Dense(784,20,tanh)
-                 ) |> Flux.gpu
-nfe = 0
-nn = Flux.Chain(
-                Flux.Dense(20,10,tanh),
-                Flux.Dense(10,10,tanh),
-                Flux.Dense(10,20,tanh)
-               ) |> Flux.gpu
-fc = Flux.Chain(Flux.Dense(20,10)) |> Flux.gpu
-
-nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(),
-                   save_everystep = false,
-                   reltol = 1e-3, abstol = 1e-3,
-                   save_start = false) |> Flux.gpu
+down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh))
+nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh),
+               Lux.Dense(10, 20, tanh))
+fc = Lux.Dense(20, 10)
+
+nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false, reltol=1e-3,
+                   abstol=1e-3, save_start=false)
 
 """
     DiffEqArray_to_Array(x)
@@ -47,58 +41,70 @@ nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(),
 Cheap conversion of a `DiffEqArray` instance to a Matrix.
 """
 function DiffEqArray_to_Array(x)
-    xarr = Flux.gpu(x)
+    xarr = Lux.gpu(x)
     return reshape(xarr, size(xarr)[1:2])
 end
 
 #Build our over-all model topology
-m = Flux.Chain(down,
-               nn_ode,
-               DiffEqArray_to_Array,
-               fc) |> Flux.gpu
+m = Lux.Chain(; down, nn_ode, convert=Lux.WrappedFunction(DiffEqArray_to_Array), fc)
+ps, st = Lux.setup(Random.default_rng(), m)
+ps = ComponentArray(ps) |> Lux.gpu
+st = st |> Lux.gpu
 
 #We can also build the model topology without a NN-ODE
-m_no_ode = Flux.Chain(down, nn, fc) |> Flux.gpu
+m_no_ode = Lux.Chain(; down, nn, fc)
+ps_no_ode, st_no_ode = Lux.setup(Random.default_rng(), m_no_ode)
+ps_no_ode = ComponentArray(ps_no_ode) |> Lux.gpu
+st_no_ode = st_no_ode |> Lux.gpu
 
 #To understand the intermediate NN-ODE layer, we can examine it's dimensionality
-x_d = down(x_train[1])
+x_d = first(down(x_train[1], ps.down, st.down))
 
 # We can see that we can compute the forward pass through the NN topology featuring an NNODE layer.
-x_m = m(x_train[1])
+x_m = first(m(x_train[1], ps, st))
 #Or without the NN-ODE layer.
-x_m = m_no_ode(x_train[1])
+x_m = first(m_no_ode(x_train[1], ps_no_ode, st_no_ode))
 
 classify(x) = argmax.(eachcol(x))
 
-function accuracy(model,data; n_batches=100)
+function accuracy(model, data, ps, st; n_batches=100)
     total_correct = 0
     total = 0
-    for (x,y) in collect(data)[1:n_batches]
-        target_class = classify(Flux.cpu(y))
-        predicted_class = classify(Flux.cpu(model(x)))
+    st = Lux.testmode(st)
+    for (x, y) in collect(data)[1:n_batches]
+        target_class = classify(Lux.cpu(y))
+        predicted_class = classify(Lux.cpu(first(model(x, ps, st))))
         total_correct += sum(target_class .== predicted_class)
         total += length(target_class)
     end
-    return total_correct/total
+    return total_correct / total
 end
 #burn in accuracy
-accuracy(m, zip(x_train,y_train))
+accuracy(m, zip(x_train, y_train), ps, st)
+
+function loss_function(ps, x, y)
+    pred, st_ = m(x, ps, st)
+    return logitcrossentropy(pred, y), pred
+end
 
-loss(x,y) = logitcrossentropy(m(x),y)
 #burn in loss
-loss(x_train[1],y_train[1])
+loss_function(ps, x_train[1], y_train[1])
 
 opt = ADAM(0.05)
 iter = 0
 
-cb() = begin
+opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y),
+                                Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps)
+
+function callback(ps, l, pred)
     global iter += 1
     #Monitor that the weights do infact update
     #Every 10 training iterations show accuracy
-    (iter%10 == 0) && @show accuracy(m, zip(x_train,y_train))
-    global nfe=0
+    (iter % 10 == 0) && @show accuracy(m, zip(x_train, y_train), ps, st)
+    return false
 end
 
 # Train the NN-ODE and monitor the loss and weights.
-Flux.train!(loss, Flux.params(down, nn_ode.p, fc), zip(x_train, y_train), opt, cb = cb)
-@test accuracy(m, zip(x_train,y_train)) > 0.8
+res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback)
+@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8
diff --git a/test/neural_de_gpu.jl b/test/neural_de_gpu.jl
index 85044881e..a4d1c03b1 100644
--- a/test/neural_de_gpu.jl
+++ b/test/neural_de_gpu.jl
@@ -1,86 +1,104 @@
-using DiffEqFlux, Flux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test
+using DiffEqFlux, Lux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random,
+      ComponentArrays
 
 CUDA.allowscalar(false)
 
-mp = Flux.Chain(Flux.Dense(2,2)) |> Flux.gpu
-x = Float32[2.; 0.] |> Flux.gpu
-xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) |> Flux.gpu
-tspan = (0.0f0,25.0f0)
-dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2)) |> Flux.gpu
-dudt_tracker = Flux.Chain(Flux.Dense(2,50,CUDA.tanh),Flux.Dense(50,2)) |> Flux.gpu
-
-NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(x)
-
-NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(xs)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(xs)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(xs)
-
-node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)
-grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-@test ! iszero(grads[x])
-@test ! iszero(grads[node.p])
-
-grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[node.p])
-
-node = NeuralODE(dudt_tracker,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
-grads = Zygote.gradient(()->sum(Array(node(x))),Flux.params(x,node))
-@test ! iszero(grads[x])
-@test ! iszero(grads[node.p])
-
-grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[node.p])
-
-node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint())
-grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-@test ! iszero(grads[x])
-@test ! iszero(grads[node.p])
-
-grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[node.p])
+x = Float32[2.0; 0.0] |> Lux.gpu
+xs = hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]) |> Lux.gpu
+tspan = (0.0f0, 25.0f0)
+
+mp = Lux.Chain(Lux.Dense(2, 2))
+
+dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
+ps_dudt, st_dudt = Lux.setup(Random.default_rng(), dudt)
+ps_dudt = ComponentArray(ps_dudt) |> Lux.gpu
+st_dudt = st_dudt |> Lux.gpu
+
+NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(x, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(x, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(x, ps_dudt, st_dudt)
+
+NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(xs, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(xs, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(xs, ps_dudt, st_dudt)
+
+node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)
+ps_node, st_node = Lux.setup(Random.default_rng(), node)
+ps_node = ComponentArray(ps_node) |> Lux.gpu
+st_node = st_node |> Lux.gpu
+grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false,
+                 sensealg=TrackerAdjoint())
+grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false,
+                 sensealg=BacksolveAdjoint())
+grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
 
 # Adjoint
 @testset "adjoint mode" begin
-    node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)
-    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-    @test ! iszero(grads[x])
-    @test ! iszero(grads[node.p])
-
-    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-    @test ! iszero(grads[xs])
-    @test ! iszero(grads[node.p])
-
-    node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:10.0)
-    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-    @test ! iszero(grads[x])
-    @test ! iszero(grads[node.p])
-
-    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-    @test ! iszero(grads[xs])
-    @test ! iszero(grads[node.p])
-
-    node = NeuralODE(dudt,tspan,Tsit5(),saveat=1f-1)
-    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-    @test ! iszero(grads[x])
-    @test ! iszero(grads[node.p])
-
-    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-    @test ! iszero(grads[xs])
-    @test ! iszero(grads[node.p])
+    node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)
+    grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    node = NeuralODE(dudt, tspan, Tsit5(), saveat=0.0:0.1:10.0)
+    grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    node = NeuralODE(dudt, tspan, Tsit5(), saveat=1.0f-1)
+    grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
 end
 
-NeuralDSDE(dudt,mp,(0.0f0,2.0f0),SOSRI(),saveat=0.0:0.1:2.0)(x)
-sode = NeuralDSDE(dudt,mp,(0.0f0,2.0f0),SOSRI(),saveat=Float32.(0.0:0.1:2.0),dt=1f-1)
-grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
-
-@test ! iszero(grads[x])
-@test ! iszero(grads[sode.p])
-
-grads = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[sode.p])
+ndsde = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=0.0:0.1:2.0)
+ps_ndsde, st_ndsde = Lux.setup(Random.default_rng(), ndsde)
+ps_ndsde = ComponentArray(ps_ndsde) |> Lux.gpu
+st_ndsde = st_ndsde |> Lux.gpu
+ndsde(x, ps_ndsde, st_ndsde)
+
+sode = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=Float32.(0.0:0.1:2.0),
+                  dt=1.0f-1, sensealg=TrackerAdjoint())
+ps_sode, st_sode = Lux.setup(Random.default_rng(), sode)
+ps_sode = ComponentArray(ps_sode) |> Lux.gpu
+st_sode = st_sode |> Lux.gpu
+grads = Zygote.gradient((x, ps) -> sum(first(sode(x, ps, st_sode))), x, ps_sode)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(sode(xs, ps, st_sode))), xs, ps_sode)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
diff --git a/test/runtests.jl b/test/runtests.jl
index 337229d5d..e76a92d59 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,40 +1,72 @@
 using DiffEqFlux, SafeTestsets, Test
 
 const GROUP = get(ENV, "GROUP", "All")
-const is_APPVEYOR = (Sys.iswindows() && haskey(ENV,"APPVEYOR"))
-const is_CI = haskey(ENV,"CI")
+const is_APPVEYOR = (Sys.iswindows() && haskey(ENV, "APPVEYOR"))
+const is_CI = haskey(ENV, "CI")
 
 @time begin
-if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "Layers"
-    @safetestset "Collocation Regression" begin include("collocation_regression.jl") end
-    @safetestset "Stiff Nested AD Tests" begin include("stiff_nested_ad.jl") end
-end
+    if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "Layers"
+        @safetestset "Collocation Regression" begin
+            include("collocation_regression.jl")
+        end
+        @safetestset "Stiff Nested AD Tests" begin
+            include("stiff_nested_ad.jl")
+        end
+    end
 
-if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "BasicNeuralDE"
-    @safetestset "Neural DE Tests with Lux" begin include("neural_de_lux.jl") end
-    @safetestset "Neural DE Tests" begin include("neural_de.jl") end
-    @safetestset "Augmented Neural DE Tests" begin include("augmented_nde.jl") end
-    #@safetestset "Neural Graph DE" begin include("neural_gde.jl") end
-
-    @safetestset "Neural ODE MM Tests" begin include("neural_ode_mm.jl") end
-    @safetestset "Tensor Product Layer" begin include("tensor_product_test.jl") end
-    @safetestset "Spline Layer" begin include("spline_layer_test.jl") end
-    @safetestset "Multiple shooting" begin include("multiple_shoot.jl") end
-end
+    if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "BasicNeuralDE"
+        @safetestset "Neural DE Tests with Lux" begin
+            include("neural_de_lux.jl")
+        end
+        @safetestset "Neural DE Tests" begin
+            include("neural_de.jl")
+        end
+        @safetestset "Augmented Neural DE Tests" begin
+            include("augmented_nde.jl")
+        end
+        #@safetestset "Neural Graph DE" begin include("neural_gde.jl") end
 
-if GROUP == "All" || GROUP == "AdvancedNeuralDE"
-    @safetestset "CNF Layer Tests" begin include("cnf_test.jl") end
-    @safetestset "Neural Second Order ODE Tests" begin include("second_order_ode.jl") end
-    @safetestset "Neural Hamiltonian ODE Tests" begin include("hamiltonian_nn.jl") end
-end
+        @safetestset "Neural ODE MM Tests" begin
+            include("neural_ode_mm.jl")
+        end
+        @safetestset "Tensor Product Layer" begin
+            include("tensor_product_test.jl")
+        end
+        @safetestset "Spline Layer" begin
+            include("spline_layer_test.jl")
+        end
+        @safetestset "Multiple shooting" begin
+            include("multiple_shoot.jl")
+        end
+    end
 
-if GROUP == "Newton"
-    @safetestset "Newton Neural ODE Tests" begin include("newton_neural_ode.jl") end
-end
+    if GROUP == "All" || GROUP == "AdvancedNeuralDE"
+        @safetestset "CNF Layer Tests" begin
+            include("cnf_test.jl")
+        end
+        @safetestset "Neural Second Order ODE Tests" begin
+            include("second_order_ode.jl")
+        end
+        @safetestset "Neural Hamiltonian ODE Tests" begin
+            include("hamiltonian_nn.jl")
+        end
+    end
 
-if !is_APPVEYOR && GROUP == "GPU"
-    @safetestset "Neural DE GPU Tests" begin include("neural_de_gpu.jl") end
-    @safetestset "MNIST GPU Tests: Fully Connected NN" begin include("mnist_gpu.jl") end
-    @safetestset "MNIST GPU Tests: Convolutional NN" begin include("mnist_conv_gpu.jl") end
-end
+    if GROUP == "Newton"
+        @safetestset "Newton Neural ODE Tests" begin
+            include("newton_neural_ode.jl")
+        end
+    end
+
+    if !is_APPVEYOR && GROUP == "GPU"
+        @safetestset "Neural DE GPU Tests" begin
+            include("neural_de_gpu.jl")
+        end
+        @safetestset "MNIST GPU Tests: Fully Connected NN" begin
+            include("mnist_gpu.jl")
+        end
+        @safetestset "MNIST GPU Tests: Convolutional NN" begin
+            include("mnist_conv_gpu.jl")
+        end
+    end
 end
\ No newline at end of file