Skip to content

Commit

Permalink
Created phi_to_net and changed parameters to res.u.depvar instead…
Browse files Browse the repository at this point in the history
… of `res.u`.
  • Loading branch information
nicholaskl97 committed Apr 3, 2024
1 parent 2cd4720 commit d25f534
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 27 deletions.
48 changes: 38 additions & 10 deletions src/numerical_lyapunov_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ functions representing the Lyapunov function and its time derivative: ``V(x), V
These functions can operate on a state vector or columnwise on a matrix of state vectors.
# Arguments
- `phi`, `θ`: `phi` is the neural network with parameters `θ`.
- `phi`: the neural network, represented as `phi(x, θ)` if the neural network has a single
output, or a `Vector` of the same with one entry per neural network output.
- `θ`: the parameters of the neural network; `θ[:φ1]` should be the parameters of the first
neural network output (even if there is only one), `θ[:φ2]` the parameters of the
second (if there are multiple), and so on.
- `structure`: a [`NeuralLyapunovStructure`](@ref) representing the structure of the neural
Lyapunov function.
- `dynamics`: the system dynamics, as a function to be used in conjunction with
Expand Down Expand Up @@ -39,15 +43,7 @@ function get_numerical_lyapunov_function(
J_net = nothing
)::Tuple{Function, Function}
# network_func is the numerical form of neural network output
output_dim = structure.network_dim
network_func = let φ = phi, _θ = θ, dim = output_dim
function (x)
reduce(
vcat,
Array(φ[i](x, _θ.depvar[Symbol(, i)])) for i in 1:dim
)
end
end
network_func = phi_to_net(phi, θ)

# V is the numerical form of Lyapunov function
V = let V_structure = structure.V, net = network_func, x0 = fixed_point
Expand Down Expand Up @@ -91,3 +87,35 @@ function get_numerical_lyapunov_function(
end
end
end

"""
phi_to_net(phi, θ[; idx])
Return the network as a function of state alone.
# Arguments
- `phi`: the neural network, represented as `phi(state, θ)` if the neural network has a
single output, or a `Vector` of the same with one entry per neural network output.
- `θ`: the parameters of the neural network; `θ[:φ1]` should be the parameters of the first
neural network output (even if there is only one), `θ[:φ2]` the parameters of the
second (if there are multiple), and so on.
- `idx`: the neural network outputs to include in the returned function; defaults to all and
only applicable when `phi isa Vector`.
"""
function phi_to_net(phi, θ)
let= θ, φ = phi
return (state) -> φ(state, _θ[:φ1])

Check warning on line 108 in src/numerical_lyapunov_functions.jl

View check run for this annotation

Codecov / codecov/patch

src/numerical_lyapunov_functions.jl#L106-L108

Added lines #L106 - L108 were not covered by tests
end
end

function phi_to_net(phi::Vector, θ; idx = eachindex(phi))
let= θ, φ = phi, _idx = idx
return function (x)
reduce(
vcat,
Array(φ[i](x, _θ[Symbol(, i)])) for i in _idx
)
end
end
end
17 changes: 7 additions & 10 deletions src/policy_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,14 @@ function get_policy(
control_dim::Integer;
control_structure::Function = identity
)
function policy(state::AbstractVector)
control_structure(
reduce(
vcat,
Array(phi[i](state, θ.depvar[Symbol(, i)]))
for i in (network_dim - control_dim + 1):network_dim
)
)
end
network_func = phi_to_net(phi, θ; idx = (network_dim - control_dim + 1):network_dim)

policy(state::AbstractMatrix) = mapslices(policy, state, dims = [1])
policy(state::AbstractVector) = control_structure(network_func(state))
policy(states::AbstractMatrix) = mapslices(

Check warning on line 73 in src/policy_search.jl

View check run for this annotation

Codecov / codecov/patch

src/policy_search.jl#L73

Added line #L73 was not covered by tests
control_structure,
network_func(states),
dims = [1]
)

return policy
end
2 changes: 1 addition & 1 deletion test/damped_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300)

V_func, V̇_func = get_numerical_lyapunov_function(
discretization.phi,
res.u,
res.u.depvar,
structure,
ODEFunction(dynamics),
zeros(length(bounds));
Expand Down
2 changes: 1 addition & 1 deletion test/damped_sho.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ res = Optimization.solve(prob, BFGS(); maxiters = 500)
###################### Get numerical numerical functions ######################
V_func, V̇_func = get_numerical_lyapunov_function(
discretization.phi,
res.u,
res.u.depvar,
structure,
f,
zeros(2);
Expand Down
4 changes: 2 additions & 2 deletions test/inverted_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300)

V_func, V̇_func = get_numerical_lyapunov_function(
discretization.phi,
res.u,
res.u.depvar,
structure,
open_loop_pendulum_dynamics,
upright_equilibrium;
p = p
)

u = get_policy(discretization.phi, res.u, dim_output, dim_u)
u = get_policy(discretization.phi, res.u.depvar, dim_output, dim_u)

################################## Simulate ###################################

Expand Down
4 changes: 2 additions & 2 deletions test/inverted_pendulum_ODESystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,14 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300)

V_func, V̇_func = get_numerical_lyapunov_function(
discretization.phi,
res.u,
res.u.depvar,
structure,
open_loop_pendulum_dynamics,
upright_equilibrium;
p = p
)

u = get_policy(discretization.phi, res.u, dim_output, dim_u)
u = get_policy(discretization.phi, res.u.depvar, dim_output, dim_u)

################################## Simulate ###################################

Expand Down
2 changes: 1 addition & 1 deletion test/roa_estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300)
###################### Get numerical numerical functions ######################
V_func, V̇_func = get_numerical_lyapunov_function(
discretization.phi,
res.u,
res.u.depvar,
structure,
f,
zeros(length(lb))
Expand Down

0 comments on commit d25f534

Please sign in to comment.