Skip to content

Commit

Permalink
Merge pull request #27 from SciML/demos
Browse files Browse the repository at this point in the history
Updated demos to use new `Lux.PeriodicEmbedding`
  • Loading branch information
nicholaskl97 authored May 14, 2024
2 parents 96f2125 + 1ee202f commit f93a351
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
[compat]
ForwardDiff = "0.10"
JuMP = "1"
Lux = "0.5"
Lux = "0.5.45"
ModelingToolkit = "8, 9"
NLopt = "1"
NeuralPDE = "5.10"
Expand Down
6 changes: 1 addition & 5 deletions test/damped_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ dim_state = length(bounds)
dim_hidden = 15
dim_output = 2
chain = [Lux.Chain(
Lux.WrappedFunction(x -> vcat(
transpose(sin.(x[1, :])),
transpose(cos.(x[1, :])),
transpose(x[2, :])
)),
PeriodicEmbedding([1], [2π]),
Dense(3, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, 1, use_bias = false)
Expand Down
10 changes: 3 additions & 7 deletions test/inverted_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ dim_phi = 2
dim_u = 1
dim_output = dim_phi + dim_u
chain = [Lux.Chain(
Lux.WrappedFunction(x -> vcat(
transpose(sin.(x[1, :])),
transpose(cos.(x[1, :])),
transpose(x[2, :])
)),
PeriodicEmbedding([1], [2π]),
Dense(3, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, 1, use_bias = false)
Expand Down Expand Up @@ -151,7 +147,7 @@ closed_loop_dynamics = ODEFunction(

# Starting still at bottom
downward_equilibrium = zeros(2)
ode_prob = ODEProblem(closed_loop_dynamics, downward_equilibrium, [0.0, 35.0], p)
ode_prob = ODEProblem(closed_loop_dynamics, downward_equilibrium, [0.0, 75.0], p)
sol = solve(ode_prob, Tsit5())
# plot(sol)

Expand All @@ -162,7 +158,7 @@ x_end, y_end = sin(θ_end), -cos(θ_end)

# Starting at a random point
x0 = lb .+ rand(2) .* (ub .- lb)
ode_prob = ODEProblem(closed_loop_dynamics, x0, [0.0, 20.0], p)
ode_prob = ODEProblem(closed_loop_dynamics, x0, [0.0, 75.0], p)
sol = solve(ode_prob, Tsit5())
# plot(sol)

Expand Down
6 changes: 1 addition & 5 deletions test/inverted_pendulum_ODESystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ dim_phi = 2
dim_u = 1
dim_output = dim_phi + dim_u
chain = [Lux.Chain(
Lux.WrappedFunction(x -> vcat(
transpose(sin.(x[1, :])),
transpose(cos.(x[1, :])),
transpose(x[2, :])
)),
PeriodicEmbedding([1], [2π]),
Dense(3, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, 1, use_bias = false)
Expand Down

0 comments on commit f93a351

Please sign in to comment.