Commit
Merge branch 'master' into ap/hnn
ChrisRackauckas authored Jun 1, 2023
2 parents d4598a2 + a5249a8 commit 18e69c2
Showing 5 changed files with 219 additions and 162 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@ deps/deps.jl
 Manifest.toml
 docs/build
 *.DS_Store
-wip
+.vscode
+wip/
4 changes: 2 additions & 2 deletions docs/src/examples/neural_ode.md
@@ -12,7 +12,7 @@ Before getting to the explanation, here's some code to start with. We will
 follow a full explanation of the definition and training process:

 ```@example neuralode_cp
-using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
+using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, OptimizationFlux, Random, Plots
 rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
@@ -88,7 +88,7 @@ callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
 Let's get a time series array from a spiral ODE to train against.

 ```@example neuralode
-using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
+using ComponentArrays, Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, OptimizationFlux, Random, Plots
 rng = Random.default_rng()
 u0 = Float32[2.0; 0.0]
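
The only change to this tutorial is the extra `OptimizationFlux` import. For context, here is a minimal sketch of what that package provides, assuming (as the tutorial's later training step does) that a Flux optimiser rule such as `ADAM` is passed straight to `Optimization.solve`; the toy objective below is illustrative, not from the docs:

```julia
# Minimal sketch: OptimizationFlux bridges Flux optimiser rules into the
# Optimization.jl solve interface. The quadratic objective is a stand-in
# for the tutorial's neural-ODE loss.
using Optimization, OptimizationFlux, Flux

loss(u, p) = sum(abs2, u .- p)
optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 2.0])
sol = solve(prob, Flux.ADAM(0.05); maxiters = 300)  # ADAM rule via OptimizationFlux
```

Stochastic rules like `ADAM` need an explicit `maxiters`; the BFGS refinement stage still comes from `OptimizationOptimJL`, which stays in the import list.
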
108 changes: 57 additions & 51 deletions test/mnist_gpu.jl
@@ -1,104 +1,110 @@
-using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test
-using Flux.Losses: logitcrossentropy
+using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics,
+    ComponentArrays, Random, Optimization, OptimizationOptimisers
 using MLDatasets: MNIST
 using MLDataUtils: LabelEnc, convertlabel, stratifiedobs

 CUDA.allowscalar(false)
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true

-function loadmnist(batchsize = bs)
-    # Use MLDataUtils LabelEnc for natural onehot conversion
-    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
-    # Load MNIST
-    mnist = MNIST(split = :train)
+logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
+
+function loadmnist(batchsize=bs)
+    # Use MLDataUtils LabelEnc for natural onehot conversion
+    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
+    # Load MNIST
+    mnist = MNIST(split=:train)
     imgs, labels_raw = mnist.features, mnist.targets
-    # Process images into (H,W,C,BS) batches
-    x_train = Float32.(reshape(imgs,size(imgs,1),size(imgs,2),1,size(imgs,3))) |> Flux.gpu
-    x_train = batchview(x_train,batchsize)
-    # Onehot and batch the labels
-    y_train = onehot(labels_raw) |> Flux.gpu
-    y_train = batchview(y_train,batchsize)
-    return x_train, y_train
+    # Process images into (H,W,C,BS) batches
+    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> Lux.gpu
+    x_train = batchview(x_train, batchsize)
+    # Onehot and batch the labels
+    y_train = onehot(labels_raw) |> Lux.gpu
+    y_train = batchview(y_train, batchsize)
+    return x_train, y_train
 end

 # Main
 const bs = 128
 x_train, y_train = loadmnist(bs)

-down = Flux.Chain(x->reshape(x,(28*28,:)),
-                  Flux.Dense(784,20,tanh)
-                  ) |> Flux.gpu
-nfe = 0
-nn = Flux.Chain(
-                Flux.Dense(20,10,tanh),
-                Flux.Dense(10,10,tanh),
-                Flux.Dense(10,20,tanh)
-                ) |> Flux.gpu
-fc = Flux.Chain(Flux.Dense(20,10)) |> Flux.gpu
-
-nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(),
-                   save_everystep = false,
-                   reltol = 1e-3, abstol = 1e-3,
-                   save_start = false) |> Flux.gpu
+down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh))
+nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh),
+    Lux.Dense(10, 20, tanh))
+fc = Lux.Dense(20, 10)
+
+nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false, reltol=1e-3,
+    abstol=1e-3, save_start=false)

 """
     DiffEqArray_to_Array(x)

 Cheap conversion of a `DiffEqArray` instance to a Matrix.
 """
 function DiffEqArray_to_Array(x)
-    xarr = Flux.gpu(x)
+    xarr = Lux.gpu(x)
     return reshape(xarr, size(xarr)[1:2])
 end

 #Build our over-all model topology
-m = Flux.Chain(down,
-               nn_ode,
-               DiffEqArray_to_Array,
-               fc) |> Flux.gpu
+m = Lux.Chain(; down, nn_ode, convert=Lux.WrappedFunction(DiffEqArray_to_Array), fc)
+ps, st = Lux.setup(Random.default_rng(), m)
+ps = ComponentArray(ps) |> Lux.gpu
+st = st |> Lux.gpu

 #We can also build the model topology without a NN-ODE
-m_no_ode = Flux.Chain(down, nn, fc) |> Flux.gpu
+m_no_ode = Lux.Chain(; down, nn, fc)
+ps_no_ode, st_no_ode = Lux.setup(Random.default_rng(), m_no_ode)
+ps_no_ode = ComponentArray(ps_no_ode) |> Lux.gpu
+st_no_ode = st_no_ode |> Lux.gpu

 #To understand the intermediate NN-ODE layer, we can examine it's dimensionality
-x_d = down(x_train[1])
+x_d = first(down(x_train[1], ps.down, st.down))

 # We can see that we can compute the forward pass through the NN topology featuring an NNODE layer.
-x_m = m(x_train[1])
+x_m = first(m(x_train[1], ps, st))
 #Or without the NN-ODE layer.
-x_m = m_no_ode(x_train[1])
+x_m = first(m_no_ode(x_train[1], ps_no_ode, st_no_ode))

 classify(x) = argmax.(eachcol(x))

-function accuracy(model,data; n_batches=100)
+function accuracy(model, data, ps, st; n_batches=100)
     total_correct = 0
     total = 0
-    for (x,y) in collect(data)[1:n_batches]
-        target_class = classify(Flux.cpu(y))
-        predicted_class = classify(Flux.cpu(model(x)))
+    st = Lux.testmode(st)
+    for (x, y) in collect(data)[1:n_batches]
+        target_class = classify(Lux.cpu(y))
+        predicted_class = classify(Lux.cpu(first(model(x, ps, st))))
         total_correct += sum(target_class .== predicted_class)
         total += length(target_class)
     end
-    return total_correct/total
+    return total_correct / total
 end
 #burn in accuracy
-accuracy(m, zip(x_train,y_train))
+accuracy(m, zip(x_train, y_train), ps, st)

-loss(x,y) = logitcrossentropy(m(x),y)
+function loss_function(ps, x, y)
+    pred, st_ = m(x, ps, st)
+    return logitcrossentropy(pred, y), pred
+end
+
 #burn in loss
-loss(x_train[1],y_train[1])
+loss_function(ps, x_train[1], y_train[1])

 opt = ADAM(0.05)
 iter = 0

-cb() = begin
+opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y),
+    Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps)
+
+function callback(ps, l, pred)
     global iter += 1
     #Monitor that the weights do infact update
     #Every 10 training iterations show accuracy
-    (iter%10 == 0) && @show accuracy(m, zip(x_train,y_train))
-    global nfe=0
+    (iter % 10 == 0) && @show accuracy(m, zip(x_train, y_train), ps, st)
+    return false
 end

 # Train the NN-ODE and monitor the loss and weights.
-Flux.train!(loss, Flux.params(down, nn_ode.p, fc), zip(x_train, y_train), opt, cb = cb)
-@test accuracy(m, zip(x_train,y_train)) > 0.8
+res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback)
+@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8
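
The substance of this rewrite is the move from Flux's implicit-parameter style (`Flux.params`, layers that own their weights) to Lux's explicit style, where `Lux.setup` hands back the parameters and state that every call must thread through. A minimal CPU-only sketch of that convention, with illustrative sizes rather than anything from the commit:

```julia
# Minimal sketch of the Lux calling convention used by the rewritten test:
# layers are immutable; parameters and state live outside the model.
using Lux, ComponentArrays, Random

model = Lux.Chain(Lux.Dense(784, 20, tanh), Lux.Dense(20, 10))
ps, st = Lux.setup(Random.default_rng(), model)  # explicit parameters and state
ps = ComponentArray(ps)  # flat view that optimizers can treat as one vector

x = randn(Float32, 784, 8)    # dummy batch of 8 flattened images
y, st_new = model(x, ps, st)  # forward pass returns (output, updated state)
```

Flattening `ps` into a `ComponentArray` is what lets the test hand the whole parameter tree to `OptimizationProblem` as a single vector while still indexing it by layer name (e.g. `ps.down`).
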
174 changes: 96 additions & 78 deletions test/neural_de_gpu.jl
@@ -1,86 +1,104 @@
-using DiffEqFlux, Flux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test
+using DiffEqFlux, Lux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random,
+    ComponentArrays

 CUDA.allowscalar(false)

-mp = Flux.Chain(Flux.Dense(2,2)) |> Flux.gpu
-x = Float32[2.; 0.] |> Flux.gpu
-xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) |> Flux.gpu
-tspan = (0.0f0,25.0f0)
-dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2)) |> Flux.gpu
-dudt_tracker = Flux.Chain(Flux.Dense(2,50,CUDA.tanh),Flux.Dense(50,2)) |> Flux.gpu
-
-NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(x)
-
-NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(xs)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(xs)
-NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(xs)
-
-node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)
-grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-@test ! iszero(grads[x])
-@test ! iszero(grads[node.p])
-
-grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[node.p])
-
-node = NeuralODE(dudt_tracker,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
-grads = Zygote.gradient(()->sum(Array(node(x))),Flux.params(x,node))
-@test ! iszero(grads[x])
-@test ! iszero(grads[node.p])
-
-grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[node.p])
-
-node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint())
-grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-@test ! iszero(grads[x])
-@test ! iszero(grads[node.p])
-
-grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[node.p])
+x = Float32[2.0; 0.0] |> Lux.gpu
+xs = hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]) |> Lux.gpu
+tspan = (0.0f0, 25.0f0)
+
+mp = Lux.Chain(Lux.Dense(2, 2))
+
+dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
+ps_dudt, st_dudt = Lux.setup(Random.default_rng(), dudt)
+ps_dudt = ComponentArray(ps_dudt) |> Lux.gpu
+st_dudt = st_dudt |> Lux.gpu
+
+NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(x, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(x, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(x, ps_dudt, st_dudt)
+
+NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(xs, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(xs, ps_dudt, st_dudt)
+NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(xs, ps_dudt, st_dudt)
+
+node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)
+ps_node, st_node = Lux.setup(Random.default_rng(), node)
+ps_node = ComponentArray(ps_node) |> Lux.gpu
+st_node = st_node |> Lux.gpu
+grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false,
+    sensealg=TrackerAdjoint())
+grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false,
+    sensealg=BacksolveAdjoint())
+grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+@test !iszero(grads[1])
+@test !iszero(grads[2])

 # Adjoint
 @testset "adjoint mode" begin
-    node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)
-    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-    @test ! iszero(grads[x])
-    @test ! iszero(grads[node.p])
-
-    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-    @test ! iszero(grads[xs])
-    @test ! iszero(grads[node.p])
-
-    node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:10.0)
-    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-    @test ! iszero(grads[x])
-    @test ! iszero(grads[node.p])
-
-    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-    @test ! iszero(grads[xs])
-    @test ! iszero(grads[node.p])
-
-    node = NeuralODE(dudt,tspan,Tsit5(),saveat=1f-1)
-    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
-    @test ! iszero(grads[x])
-    @test ! iszero(grads[node.p])
-
-    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
-    @test ! iszero(grads[xs])
-    @test ! iszero(grads[node.p])
+    node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)
+    grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    node = NeuralODE(dudt, tspan, Tsit5(), saveat=0.0:0.1:10.0)
+    grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    node = NeuralODE(dudt, tspan, Tsit5(), saveat=1.0f-1)
+    grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+
+    grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
 end

-NeuralDSDE(dudt,mp,(0.0f0,2.0f0),SOSRI(),saveat=0.0:0.1:2.0)(x)
-sode = NeuralDSDE(dudt,mp,(0.0f0,2.0f0),SOSRI(),saveat=Float32.(0.0:0.1:2.0),dt=1f-1)
-grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
-
-@test ! iszero(grads[x])
-@test ! iszero(grads[sode.p])
-
-grads = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode))
-@test ! iszero(grads[xs])
-@test ! iszero(grads[sode.p])
+ndsde = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=0.0:0.1:2.0)
+ps_ndsde, st_ndsde = Lux.setup(Random.default_rng(), ndsde)
+ps_ndsde = ComponentArray(ps_ndsde) |> Lux.gpu
+st_ndsde = st_ndsde |> Lux.gpu
+ndsde(x, ps_ndsde, st_ndsde)
+
+sode = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=Float32.(0.0:0.1:2.0),
+    dt=1.0f-1, sensealg=TrackerAdjoint())
+ps_sode, st_sode = Lux.setup(Random.default_rng(), sode)
+ps_sode = ComponentArray(ps_sode) |> Lux.gpu
+st_sode = st_sode |> Lux.gpu
+grads = Zygote.gradient((x, ps) -> sum(first(sode(x, ps, st_sode))), x, ps_sode)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
+
+grads = Zygote.gradient((xs, ps) -> sum(first(sode(xs, ps, st_sode))), xs, ps_sode)
+@test !iszero(grads[1])
+@test !iszero(grads[2])
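
The same migration applies here, with one recurring pattern worth isolating: gradients are taken through an explicit closure over `(x, ps)` and arrive positionally, replacing the old `Flux.params`/`grads[node.p]` bookkeeping. A minimal CPU-only sketch, with illustrative layer sizes:

```julia
# Minimal sketch of the gradient pattern used throughout the rewritten tests.
using Lux, ComponentArrays, Random, Zygote

dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), dudt)
ps = ComponentArray(ps)
x = Float32[2.0; 0.0]

# first(...) drops the returned state; gradients match the closure's arguments.
gx, gps = Zygote.gradient((x, ps) -> sum(first(dudt(x, ps, st))), x, ps)
@assert !iszero(gx) && !iszero(gps)
```
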
