
Commit

format
ChrisRackauckas committed Jun 14, 2024
1 parent 9075ed5 commit b2ae23f
Showing 2 changed files with 26 additions and 37 deletions.
35 changes: 14 additions & 21 deletions docs/src/examples/neural_gde.md
@@ -1,6 +1,7 @@
# Neural Graph Differential Equations

!!! warn

This tutorial has not been run or updated in a while.

This tutorial has been adapted from [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/examples/neural_ode_cora.jl).
@@ -33,7 +34,7 @@ onecold(y) = map(argmax, eachcol(y))
X = g.ndata.features
y = onehotbatch(g.ndata.targets, classes) # a dense matrix is not optimal, but we don't want to use Flux here

Ã = normalized_adjacency(g, add_self_loops = true) |> device
Ã = normalized_adjacency(g; add_self_loops = true) |> device

(; train_mask, val_mask, test_mask) = g.ndata
ytrain = y[:, train_mask]
@@ -70,9 +71,9 @@ initialstates(rng::AbstractRNG, d::ExplicitGCNConv) = (Ã = d.init_Ã(),)
function ExplicitGCNConv(Ã, ch::Pair{Int, Int}, activation = identity;
init_weight = glorot_normal, init_bias = zeros32)
init_Ã = () -> copy(Ã)
return ExplicitGCNConv{typeof(activation), typeof(init_Ã), typeof(init_weight),
typeof(init_bias)}(first(ch), last(ch), activation,
init_Ã, init_weight, init_bias)
return ExplicitGCNConv{
typeof(activation), typeof(init_Ã), typeof(init_weight), typeof(init_bias)}(
first(ch), last(ch), activation, init_Ã, init_weight, init_bias)
end

function (l::ExplicitGCNConv)(x::AbstractMatrix, ps, st::NamedTuple)
@@ -94,13 +95,11 @@ initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model
gnn = Chain(ExplicitGCNConv(Ã, nhidden => nhidden, relu),
ExplicitGCNConv(Ã, nhidden => nhidden, relu))

node = NeuralODE(gnn, (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
node = NeuralODE(gnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false)

model = Chain(ExplicitGCNConv(Ã, nin => nhidden, relu),
node,
diffeqsol_to_array,
Dense(nhidden, nout))
node, diffeqsol_to_array, Dense(nhidden, nout))

# Loss
logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ); dims = 1))
@@ -114,7 +113,7 @@ function eval_loss_accuracy(X, y, mask, model, ps, st)
ŷ, _ = model(X, ps, st)
l = logitcrossentropy(ŷ[:, mask], y[:, mask])
acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask]))
return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2))
return (loss = round(l; digits = 4), acc = round(acc * 100; digits = 2))
end

# Training
@@ -183,7 +182,7 @@ onecold(y) = map(argmax, eachcol(y))
X = g.ndata.features
y = onehotbatch(g.ndata.targets, classes) # a dense matrix is not optimal, but we don't want to use Flux here

Ã = normalized_adjacency(g, add_self_loops = true) |> device
Ã = normalized_adjacency(g; add_self_loops = true) |> device
```
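
As a side note (this formula is the standard symmetric GCN normalization and is not spelled out in the file, so treat the exact form computed by `normalized_adjacency` with `add_self_loops = true` as an assumption to check against GraphNeuralNetworks.jl):

```math
\tilde{A} = \hat{D}^{-1/2} (A + I) \hat{D}^{-1/2}, \qquad \hat{D}_{ii} = \sum_j (A + I)_{ij}
```

The self-loops and symmetric scaling keep the propagation matrix well-conditioned, which is what the graph convolution layers below rely on.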

### Training Data
@@ -233,12 +232,8 @@ end

function ExplicitGCNConv(Ã, ch::Pair{Int, Int}, activation = identity;
init_weight = glorot_normal, init_bias = zeros32)
return ExplicitGCNConv{typeof(activation), typeof(init_weight), typeof(init_bias)}(Ã,
first(ch),
last(ch),
activation,
init_weight,
init_bias)
return ExplicitGCNConv{typeof(activation), typeof(init_weight), typeof(init_bias)}(
Ã, first(ch), last(ch), activation, init_weight, init_bias)
end

function (l::ExplicitGCNConv)(x::AbstractMatrix, ps, st::NamedTuple)
@@ -260,13 +255,11 @@ diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims = 3)
gnn = Chain(ExplicitGCNConv(Ã, nhidden => nhidden, relu),
ExplicitGCNConv(Ã, nhidden => nhidden, relu))

node = NeuralODE(gnn, (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
node = NeuralODE(gnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false)

model = Chain(ExplicitGCNConv(Ã, nin => nhidden, relu),
node,
diffeqsol_to_array,
Dense(nhidden, nout))
node, diffeqsol_to_array, Dense(nhidden, nout))
```
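
A quick shape sketch (illustrative only, not part of the file) of why `diffeqsol_to_array` drops `dims = 3`: with `save_everystep = false` and `save_start = false`, the `NeuralODE` solution keeps exactly one saved state, so converting it to an array leaves a trailing singleton time dimension.

```julia
# Illustrative sketch: the sizes below are made up, not taken from the tutorial.
dummy_solution_array = rand(Float32, 16, 100, 1)  # (features, nodes, saved time points)
size(dropdims(dummy_solution_array; dims = 3))    # -> (16, 100)
```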

## Training Configuration
@@ -287,7 +280,7 @@ function eval_loss_accuracy(X, y, mask, model, ps, st)
ŷ, _ = model(X, ps, st)
l = logitcrossentropy(ŷ[:, mask], y[:, mask])
acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask]))
return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2))
return (loss = round(l; digits = 4), acc = round(acc * 100; digits = 2))
end
```
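
The training loop itself is collapsed in this view. A minimal sketch of one way to train the model with Optimisers.jl and Zygote.jl (an assumption about the setup, not the tutorial's actual code; it reuses `X`, `ytrain`, `train_mask`, and `logitcrossentropy` from above):

```julia
using Optimisers, Zygote

function train_sketch(model, ps, st; epochs = 100, lr = 0.01f0)
    opt_state = Optimisers.setup(Optimisers.Adam(lr), ps)
    for _ in 1:epochs
        # Differentiate the cross-entropy loss on the training nodes w.r.t. the parameters.
        gs = Zygote.gradient(ps) do p
            ŷ, _ = model(X, p, st)
            logitcrossentropy(ŷ[:, train_mask], ytrain)
        end
        opt_state, ps = Optimisers.update(opt_state, ps, gs[1])
    end
    return ps
end
```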

28 changes: 12 additions & 16 deletions docs/src/examples/physical_constraints.md
@@ -10,8 +10,7 @@ terms must add to one. An example of this is as follows:

```@example dae
using SciMLSensitivity
using Lux, ComponentArrays, Optimization, OptimizationOptimJL,
OrdinaryDiffEq, Plots
using Lux, ComponentArrays, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Plots
using Random
rng = Random.default_rng()
@@ -33,17 +32,16 @@ M = [1.0 0 0
tspan = (0.0, 1.0)
p = [0.04, 3e7, 1e4]
stiff_func = ODEFunction(f!, mass_matrix = M)
stiff_func = ODEFunction(f!; mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
sol_stiff = solve(prob_stiff, Rodas5(); saveat = 0.1)
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
Lux.Dense(64, 2))
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
tspan, M, Rodas5(autodiff = false), saveat = 0.1)
tspan, M, Rodas5(; autodiff = false); saveat = 0.1)
function predict_stiff_ndae(p)
return model_stiff_ndae(u₀, p, st)[1]
@@ -65,7 +63,7 @@ l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(), maxiters = 100)
result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(); maxiters = 100)
```
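
For orientation, the parameters `p = [0.04, 3e7, 1e4]` and the constraint `u[1] + u[2] + u[3] - 1` match the classic Robertson reaction system, so the DAE being fit is presumably (treat this reconstruction as an assumption; the body of `f!` is collapsed in this diff):

```math
\begin{aligned}
\dot{u}_1 &= -0.04\,u_1 + 10^{4}\,u_2 u_3, \\
\dot{u}_2 &= 0.04\,u_1 - 10^{4}\,u_2 u_3 - 3\times 10^{7}\,u_2^{2}, \\
0 &= u_1 + u_2 + u_3 - 1,
\end{aligned}
```

where the zero row of the singular mass matrix `M` enforces the conservation constraint algebraically.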

## Step-by-Step Description
@@ -74,8 +72,7 @@ result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(), maxiters

```@example dae2
using SciMLSensitivity
using Lux, ComponentArrays, Optimization, OptimizationOptimJL,
OrdinaryDiffEq, Plots
using Lux, ComponentArrays, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Plots
using Random
rng = Random.default_rng()
@@ -123,9 +120,9 @@ We define and solve our ODE problem to generate the “labeled” data which will
train our Neural Network.

```@example dae2
stiff_func = ODEFunction(f!, mass_matrix = M)
stiff_func = ODEFunction(f!; mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
sol_stiff = solve(prob_stiff, Rodas5(); saveat = 0.1)
```

Because this is a DAE, we need to make sure to use a **compatible solver**.
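
As a quick illustration of that point (a sketch, not part of the file): Rosenbrock methods such as `Rodas4`/`Rodas5` accept problems with a singular mass matrix, while explicit solvers such as `Tsit5` do not. Assuming the `prob_stiff` defined above:

```julia
# Works: Rosenbrock methods support the singular mass matrix in prob_stiff.
sol_ok = solve(prob_stiff, Rodas4(); saveat = 0.1)

# Would fail: Tsit5 is an explicit method with no mass-matrix support.
# solve(prob_stiff, Tsit5(); saveat = 0.1)
```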
@@ -138,13 +135,12 @@ is more suited to SciML applications (similarly for `Lux.Dense`). The input to o
will be the initial conditions fed in as `u₀`.

```@example dae2
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
Lux.Dense(64, 2))
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
tspan, M, Rodas5(autodiff = false), saveat = 0.1)
tspan, M, Rodas5(; autodiff = false); saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)
```

@@ -210,5 +206,5 @@ Finally, training with `Optimization.solve` by passing: *loss function*, *model
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(), maxiters = 100)
result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(); maxiters = 100)
```
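
After the solve, the returned object can be inspected in the usual Optimization.jl way; a brief usage sketch (assuming the setup above):

```julia
result_stiff.u          # optimized parameters (a ComponentArray with pinit's layout)
result_stiff.objective  # final value of the loss
```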
