-
-
Notifications
You must be signed in to change notification settings - Fork 154
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
219 additions
and
162 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,5 @@ deps/deps.jl | |
Manifest.toml | ||
docs/build | ||
*.DS_Store | ||
wip | ||
.vscode | ||
wip/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,110 @@ | ||
using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test | ||
using Flux.Losses: logitcrossentropy | ||
using DiffEqFlux, CUDA, Zygote, MLDataUtils, NNlib, OrdinaryDiffEq, Test, Lux, Statistics, | ||
ComponentArrays, Random, Optimization, OptimizationOptimisers | ||
using MLDatasets: MNIST | ||
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs | ||
|
||
CUDA.allowscalar(false) | ||
ENV["DATADEPS_ALWAYS_ACCEPT"] = true | ||
|
||
function loadmnist(batchsize = bs) | ||
# Use MLDataUtils LabelEnc for natural onehot conversion | ||
onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) | ||
# Load MNIST | ||
mnist = MNIST(split = :train) | ||
logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) | ||
|
||
function loadmnist(batchsize=bs) | ||
# Use MLDataUtils LabelEnc for natural onehot conversion | ||
onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) | ||
# Load MNIST | ||
mnist = MNIST(split=:train) | ||
imgs, labels_raw = mnist.features, mnist.targets | ||
# Process images into (H,W,C,BS) batches | ||
x_train = Float32.(reshape(imgs,size(imgs,1),size(imgs,2),1,size(imgs,3))) |> Flux.gpu | ||
x_train = batchview(x_train,batchsize) | ||
# Onehot and batch the labels | ||
y_train = onehot(labels_raw) |> Flux.gpu | ||
y_train = batchview(y_train,batchsize) | ||
return x_train, y_train | ||
# Process images into (H,W,C,BS) batches | ||
x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> Lux.gpu | ||
x_train = batchview(x_train, batchsize) | ||
# Onehot and batch the labels | ||
y_train = onehot(labels_raw) |> Lux.gpu | ||
y_train = batchview(y_train, batchsize) | ||
return x_train, y_train | ||
end | ||
|
||
# Main | ||
const bs = 128 | ||
x_train, y_train = loadmnist(bs) | ||
|
||
down = Flux.Chain(x->reshape(x,(28*28,:)), | ||
Flux.Dense(784,20,tanh) | ||
) |> Flux.gpu | ||
nfe = 0 | ||
nn = Flux.Chain( | ||
Flux.Dense(20,10,tanh), | ||
Flux.Dense(10,10,tanh), | ||
Flux.Dense(10,20,tanh) | ||
) |> Flux.gpu | ||
fc = Flux.Chain(Flux.Dense(20,10)) |> Flux.gpu | ||
|
||
nn_ode = NeuralODE(nn, (0.f0, 1.f0), Tsit5(), | ||
save_everystep = false, | ||
reltol = 1e-3, abstol = 1e-3, | ||
save_start = false) |> Flux.gpu | ||
down = Lux.Chain(Lux.FlattenLayer(), Lux.Dense(784, 20, tanh)) | ||
nn = Lux.Chain(Lux.Dense(20, 10, tanh), Lux.Dense(10, 10, tanh), | ||
Lux.Dense(10, 20, tanh)) | ||
fc = Lux.Dense(20, 10) | ||
|
||
nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(), save_everystep=false, reltol=1e-3, | ||
abstol=1e-3, save_start=false) | ||
|
||
""" | ||
DiffEqArray_to_Array(x) | ||
Cheap conversion of a `DiffEqArray` instance to a Matrix. | ||
""" | ||
function DiffEqArray_to_Array(x) | ||
xarr = Flux.gpu(x) | ||
xarr = Lux.gpu(x) | ||
return reshape(xarr, size(xarr)[1:2]) | ||
end | ||
|
||
#Build our over-all model topology | ||
m = Flux.Chain(down, | ||
nn_ode, | ||
DiffEqArray_to_Array, | ||
fc) |> Flux.gpu | ||
m = Lux.Chain(; down, nn_ode, convert=Lux.WrappedFunction(DiffEqArray_to_Array), fc) | ||
ps, st = Lux.setup(Random.default_rng(), m) | ||
ps = ComponentArray(ps) |> Lux.gpu | ||
st = st |> Lux.gpu | ||
|
||
#We can also build the model topology without a NN-ODE | ||
m_no_ode = Flux.Chain(down, nn, fc) |> Flux.gpu | ||
m_no_ode = Lux.Chain(; down, nn, fc) | ||
ps_no_ode, st_no_ode = Lux.setup(Random.default_rng(), m_no_ode) | ||
ps_no_ode = ComponentArray(ps_no_ode) |> Lux.gpu | ||
st_no_ode = st_no_ode |> Lux.gpu | ||
|
||
#To understand the intermediate NN-ODE layer, we can examine it's dimensionality | ||
x_d = down(x_train[1]) | ||
x_d = first(down(x_train[1], ps.down, st.down)) | ||
|
||
# We can see that we can compute the forward pass through the NN topology featuring an NNODE layer. | ||
x_m = m(x_train[1]) | ||
x_m = first(m(x_train[1], ps, st)) | ||
#Or without the NN-ODE layer. | ||
x_m = m_no_ode(x_train[1]) | ||
x_m = first(m_no_ode(x_train[1], ps_no_ode, st_no_ode)) | ||
|
||
classify(x) = argmax.(eachcol(x)) | ||
|
||
function accuracy(model,data; n_batches=100) | ||
function accuracy(model, data, ps, st; n_batches=100) | ||
total_correct = 0 | ||
total = 0 | ||
for (x,y) in collect(data)[1:n_batches] | ||
target_class = classify(Flux.cpu(y)) | ||
predicted_class = classify(Flux.cpu(model(x))) | ||
st = Lux.testmode(st) | ||
for (x, y) in collect(data)[1:n_batches] | ||
target_class = classify(Lux.cpu(y)) | ||
predicted_class = classify(Lux.cpu(first(model(x, ps, st)))) | ||
total_correct += sum(target_class .== predicted_class) | ||
total += length(target_class) | ||
end | ||
return total_correct/total | ||
return total_correct / total | ||
end | ||
#burn in accuracy | ||
accuracy(m, zip(x_train,y_train)) | ||
accuracy(m, zip(x_train, y_train), ps, st) | ||
|
||
function loss_function(ps, x, y) | ||
pred, st_ = m(x, ps, st) | ||
return logitcrossentropy(pred, y), pred | ||
end | ||
|
||
loss(x,y) = logitcrossentropy(m(x),y) | ||
#burn in loss | ||
loss(x_train[1],y_train[1]) | ||
loss_function(ps, x_train[1], y_train[1]) | ||
|
||
opt = ADAM(0.05) | ||
iter = 0 | ||
|
||
cb() = begin | ||
opt_func = OptimizationFunction((ps, _, x, y) -> loss_function(ps, x, y), | ||
Optimization.AutoZygote()) | ||
opt_prob = OptimizationProblem(opt_func, ps) | ||
|
||
function callback(ps, l, pred) | ||
global iter += 1 | ||
#Monitor that the weights do infact update | ||
#Every 10 training iterations show accuracy | ||
(iter%10 == 0) && @show accuracy(m, zip(x_train,y_train)) | ||
global nfe=0 | ||
(iter % 10 == 0) && @show accuracy(m, zip(x_train, y_train), ps, st) | ||
return false | ||
end | ||
|
||
# Train the NN-ODE and monitor the loss and weights. | ||
Flux.train!(loss, Flux.params(down, nn_ode.p, fc), zip(x_train, y_train), opt, cb = cb) | ||
@test accuracy(m, zip(x_train,y_train)) > 0.8 | ||
res = Optimization.solve(opt_prob, opt, zip(x_train, y_train); callback) | ||
@test accuracy(m, zip(x_train, y_train), res.u, st) > 0.8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,104 @@ | ||
using DiffEqFlux, Flux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test | ||
using DiffEqFlux, Lux, CUDA, Zygote, OrdinaryDiffEq, StochasticDiffEq, Test, Random, | ||
ComponentArrays | ||
|
||
CUDA.allowscalar(false) | ||
|
||
mp = Flux.Chain(Flux.Dense(2,2)) |> Flux.gpu | ||
x = Float32[2.; 0.] |> Flux.gpu | ||
xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) |> Flux.gpu | ||
tspan = (0.0f0,25.0f0) | ||
dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2)) |> Flux.gpu | ||
dudt_tracker = Flux.Chain(Flux.Dense(2,50,CUDA.tanh),Flux.Dense(50,2)) |> Flux.gpu | ||
|
||
NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x) | ||
NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x) | ||
NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(x) | ||
|
||
NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(xs) | ||
NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(xs) | ||
NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(xs) | ||
|
||
node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false) | ||
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) | ||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
node = NeuralODE(dudt_tracker,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) | ||
grads = Zygote.gradient(()->sum(Array(node(x))),Flux.params(x,node)) | ||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint()) | ||
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) | ||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[node.p]) | ||
x = Float32[2.0; 0.0] |> Lux.gpu | ||
xs = hcat([0.0; 0.0], [1.0; 0.0], [2.0; 0.0]) |> Lux.gpu | ||
tspan = (0.0f0, 25.0f0) | ||
|
||
mp = Lux.Chain(Lux.Dense(2, 2)) | ||
|
||
dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)) | ||
ps_dudt, st_dudt = Lux.setup(Random.default_rng(), dudt) | ||
ps_dudt = ComponentArray(ps_dudt) |> Lux.gpu | ||
st_dudt = st_dudt |> Lux.gpu | ||
|
||
NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(x, ps_dudt, st_dudt) | ||
NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(x, ps_dudt, st_dudt) | ||
NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(x, ps_dudt, st_dudt) | ||
|
||
NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false)(xs, ps_dudt, st_dudt) | ||
NeuralODE(dudt, tspan, Tsit5(), saveat=0.1)(xs, ps_dudt, st_dudt) | ||
NeuralODE(dudt, tspan, Tsit5(), saveat=0.1, sensealg=TrackerAdjoint())(xs, ps_dudt, st_dudt) | ||
|
||
node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false) | ||
ps_node, st_node = Lux.setup(Random.default_rng(), node) | ||
ps_node = ComponentArray(ps_node) |> Lux.gpu | ||
st_node = st_node |> Lux.gpu | ||
grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false, | ||
sensealg=TrackerAdjoint()) | ||
grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false, | ||
sensealg=BacksolveAdjoint()) | ||
grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
# Adjoint | ||
@testset "adjoint mode" begin | ||
node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false) | ||
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) | ||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:10.0) | ||
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) | ||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
node = NeuralODE(dudt,tspan,Tsit5(),saveat=1f-1) | ||
grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) | ||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[node.p]) | ||
|
||
grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[node.p]) | ||
node = NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false) | ||
grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
node = NeuralODE(dudt, tspan, Tsit5(), saveat=0.0:0.1:10.0) | ||
grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
node = NeuralODE(dudt, tspan, Tsit5(), saveat=1.0f-1) | ||
grads = Zygote.gradient((x, ps) -> sum(first(node(x, ps, st_node))), x, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(node(xs, ps, st_node))), xs, ps_node) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
end | ||
|
||
NeuralDSDE(dudt,mp,(0.0f0,2.0f0),SOSRI(),saveat=0.0:0.1:2.0)(x) | ||
sode = NeuralDSDE(dudt,mp,(0.0f0,2.0f0),SOSRI(),saveat=Float32.(0.0:0.1:2.0),dt=1f-1) | ||
grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode)) | ||
|
||
@test ! iszero(grads[x]) | ||
@test ! iszero(grads[sode.p]) | ||
|
||
grads = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode)) | ||
@test ! iszero(grads[xs]) | ||
@test ! iszero(grads[sode.p]) | ||
ndsde = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=0.0:0.1:2.0) | ||
ps_ndsde, st_ndsde = Lux.setup(Random.default_rng(), ndsde) | ||
ps_ndsde = ComponentArray(ps_ndsde) |> Lux.gpu | ||
st_ndsde = st_ndsde |> Lux.gpu | ||
ndsde(x, ps_ndsde, st_ndsde) | ||
|
||
sode = NeuralDSDE(dudt, mp, (0.0f0, 2.0f0), SOSRI(), saveat=Float32.(0.0:0.1:2.0), | ||
dt=1.0f-1, sensealg=TrackerAdjoint()) | ||
ps_sode, st_sode = Lux.setup(Random.default_rng(), sode) | ||
ps_sode = ComponentArray(ps_sode) |> Lux.gpu | ||
st_sode = st_sode |> Lux.gpu | ||
grads = Zygote.gradient((x, ps) -> sum(first(sode(x, ps, st_sode))), x, ps_sode) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) | ||
|
||
grads = Zygote.gradient((xs, ps) -> sum(first(sode(xs, ps, st_sode))), xs, ps_sode) | ||
@test !iszero(grads[1]) | ||
@test !iszero(grads[2]) |
Oops, something went wrong.