diff --git a/src/numerical_lyapunov_functions.jl b/src/numerical_lyapunov_functions.jl
index d449b2b..0285e39 100644
--- a/src/numerical_lyapunov_functions.jl
+++ b/src/numerical_lyapunov_functions.jl
@@ -8,7 +8,11 @@ functions representing the Lyapunov function and its time derivative: ``V(x), V̇(x)``
 These functions can operate on a state vector or columnwise on a matrix of state vectors.
 
 # Arguments
-- `phi`, `θ`: `phi` is the neural network with parameters `θ`.
+- `phi`: the neural network, represented as `phi(x, θ)` if the neural network has a single
+  output, or a `Vector` of the same with one entry per neural network output.
+- `θ`: the parameters of the neural network; `θ[:φ1]` should be the parameters of the first
+  neural network output (even if there is only one), `θ[:φ2]` the parameters of the
+  second (if there are multiple), and so on.
 - `structure`: a [`NeuralLyapunovStructure`](@ref) representing the structure of the neural
   Lyapunov function.
 - `dynamics`: the system dynamics, as a function to be used in conjunction with
@@ -39,15 +43,7 @@ function get_numerical_lyapunov_function(
         J_net = nothing
 )::Tuple{Function, Function}
     # network_func is the numerical form of neural network output
-    output_dim = structure.network_dim
-    network_func = let φ = phi, _θ = θ, dim = output_dim
-        function (x)
-            reduce(
-                vcat,
-                Array(φ[i](x, _θ.depvar[Symbol(:φ, i)])) for i in 1:dim
-            )
-        end
-    end
+    network_func = phi_to_net(phi, θ)
 
     # V is the numerical form of Lyapunov function
     V = let V_structure = structure.V, net = network_func, x0 = fixed_point
@@ -91,3 +87,35 @@ function get_numerical_lyapunov_function(
         end
     end
 end
+
+"""
+    phi_to_net(phi, θ[; idx])
+
+Return the network as a function of state alone.
+
+# Arguments
+
+- `phi`: the neural network, represented as `phi(state, θ)` if the neural network has a
+  single output, or a `Vector` of the same with one entry per neural network output.
+- `θ`: the parameters of the neural network; `θ[:φ1]` should be the parameters of the first
+  neural network output (even if there is only one), `θ[:φ2]` the parameters of the
+  second (if there are multiple), and so on.
+- `idx`: the neural network outputs to include in the returned function; defaults to all
+  outputs. Only applicable when `phi isa Vector`.
+""" +function phi_to_net(phi, θ) + let _θ = θ, φ = phi + return (state) -> φ(state, _θ[:φ1]) + end +end + +function phi_to_net(phi::Vector, θ; idx = eachindex(phi)) + let _θ = θ, φ = phi, _idx = idx + return function (x) + reduce( + vcat, + Array(φ[i](x, _θ[Symbol(:φ, i)])) for i in _idx + ) + end + end +end diff --git a/src/policy_search.jl b/src/policy_search.jl index 2d519f3..eb099b4 100644 --- a/src/policy_search.jl +++ b/src/policy_search.jl @@ -67,17 +67,14 @@ function get_policy( control_dim::Integer; control_structure::Function = identity ) - function policy(state::AbstractVector) - control_structure( - reduce( - vcat, - Array(phi[i](state, θ.depvar[Symbol(:φ, i)])) - for i in (network_dim - control_dim + 1):network_dim - ) - ) - end + network_func = phi_to_net(phi, θ; idx = (network_dim - control_dim + 1):network_dim) - policy(state::AbstractMatrix) = mapslices(policy, state, dims = [1]) + policy(state::AbstractVector) = control_structure(network_func(state)) + policy(states::AbstractMatrix) = mapslices( + control_structure, + network_func(states), + dims = [1] + ) return policy end diff --git a/test/damped_pendulum.jl b/test/damped_pendulum.jl index 20c0df1..b5ef2dc 100644 --- a/test/damped_pendulum.jl +++ b/test/damped_pendulum.jl @@ -109,7 +109,7 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300) V_func, V̇_func = get_numerical_lyapunov_function( discretization.phi, - res.u, + res.u.depvar, structure, ODEFunction(dynamics), zeros(length(bounds)); diff --git a/test/damped_sho.jl b/test/damped_sho.jl index e2e5044..429d810 100644 --- a/test/damped_sho.jl +++ b/test/damped_sho.jl @@ -81,7 +81,7 @@ res = Optimization.solve(prob, BFGS(); maxiters = 500) ###################### Get numerical numerical functions ###################### V_func, V̇_func = get_numerical_lyapunov_function( discretization.phi, - res.u, + res.u.depvar, structure, f, zeros(2); diff --git a/test/inverted_pendulum.jl b/test/inverted_pendulum.jl index 7d46784..6433995 100644 --- a/test/inverted_pendulum.jl +++ b/test/inverted_pendulum.jl @@ -104,14 +104,14 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300) V_func, V̇_func = get_numerical_lyapunov_function( discretization.phi, - res.u, + res.u.depvar, structure, open_loop_pendulum_dynamics, upright_equilibrium; p = p ) -u = get_policy(discretization.phi, res.u, dim_output, dim_u) +u = get_policy(discretization.phi, res.u.depvar, dim_output, dim_u) ################################## Simulate ################################### diff --git a/test/inverted_pendulum_ODESystem.jl b/test/inverted_pendulum_ODESystem.jl index 48ccab7..caff5e4 100644 --- a/test/inverted_pendulum_ODESystem.jl +++ b/test/inverted_pendulum_ODESystem.jl @@ -112,14 +112,14 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300) V_func, V̇_func = get_numerical_lyapunov_function( discretization.phi, - res.u, + res.u.depvar, structure, open_loop_pendulum_dynamics, upright_equilibrium; p = p ) -u = get_policy(discretization.phi, res.u, dim_output, dim_u) +u = get_policy(discretization.phi, res.u.depvar, dim_output, dim_u) ################################## Simulate ################################### diff --git a/test/roa_estimation.jl b/test/roa_estimation.jl index 347d58c..c3d2418 100644 --- a/test/roa_estimation.jl +++ b/test/roa_estimation.jl @@ -68,7 +68,7 @@ res = Optimization.solve(prob, BFGS(); maxiters = 300) ###################### Get numerical numerical functions ###################### V_func, V̇_func = get_numerical_lyapunov_function( discretization.phi, - res.u, + 
+    res.u.depvar,
     structure,
     f,
     zeros(length(lb))
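
The `phi_to_net` helper added above is the piece that `get_numerical_lyapunov_function` and `get_policy` now share. A minimal sketch of its contract, assuming a trained two-output network with `discretization` and `res` as in the test files; the 2-dimensional state is illustrative, and note that this diff does not export `phi_to_net`:

```julia
θ = res.u.depvar    # per-output parameters, keyed :φ1, :φ2, ...

net = phi_to_net(discretization.phi, θ)               # all outputs, vcat-stacked
ctrl = phi_to_net(discretization.phi, θ; idx = 2:2)   # only the second output

x = zeros(2)        # illustrative state vector
net(x)              # vcat of outputs 1 and 2 evaluated at x
ctrl(x)             # output 2 alone
```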
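On the `get_policy` side, the returned policy applies `control_structure` columnwise when given a matrix of states, mirroring the "state vector or columnwise on a matrix" behavior documented for the Lyapunov functions. A usage sketch under the same assumptions (names taken from `test/inverted_pendulum.jl`; the sample inputs are illustrative):

```julia
u = get_policy(discretization.phi, res.u.depvar, dim_output, dim_u)

u([π, 0.0])       # control input at a single state vector
u(rand(2, 10))    # columnwise: one control input per column of states
```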