-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9e06f1c
commit c49364c
Showing
8 changed files
with
353 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
add_policy_search(lyapunov_structure, new_dims, control_structure) | ||
Adds dependence on the neural network to the dynamics in a `NeuralLyapunovStructure` | ||
Adds `new_dims` outputs to the neural network and feeds them through `control_structure` to | ||
calculatethe contribution of the neural network to the dynamics. | ||
The existing `lyapunov_structure.network_dim` dimensions are used as in `lyapunov_structure` | ||
to calculate the Lyapunov function. | ||
`lyapunov_structure` should assume in its `V̇` that the dynamics take a form `f(x, p, t)`. | ||
The returned `NeuralLyapunovStructure` will assume instead `f(x, u, p, t)`, where `u` is the | ||
contribution from the neural network. Therefore, this structure cannot be used with a | ||
`NeuralLyapunovPDESystem` method that requires an `ODEFunction`, `ODESystem`, or | ||
`ODEProblem`. | ||
""" | ||
function add_policy_search( | ||
lyapunov_structure::NeuralLyapunovStructure, | ||
new_dims::Integer; | ||
control_structure::Function = identity | ||
)::NeuralLyapunovStructure | ||
let V = lyapunov_structure.V, ∇V = lyapunov_structure.∇V, V̇ = lyapunov_structure.V̇, | ||
V_dim = lyapunov_structure.network_dim, nd = new_dims, u = control_structure | ||
|
||
NeuralLyapunovStructure( | ||
function (net, state, fixed_point) | ||
if length(size(state)) == 1 | ||
if V_dim == 1 | ||
V(st -> net(st)[1], state, fixed_point) | ||
else | ||
V(st -> net(st)[1:V_dim], state, fixed_point) | ||
end | ||
else | ||
V(st -> net(st)[1:V_dim, :], state, fixed_point) | ||
end | ||
end, | ||
function (net, J_net, state, fixed_point) | ||
∇V(st -> net(st)[1:V_dim], st -> J_net(st)[1:V_dim, :], state, fixed_point) | ||
end, | ||
function (net, J_net, f, state, params, t, fixed_point) | ||
V̇(st -> net(st)[1:V_dim], st -> J_net(st)[1:V_dim, :], | ||
(st, p, t) -> f(st, u(net(st)[(V_dim + 1):end]), p, t), state, params, | ||
t, fixed_point) | ||
end, | ||
(f, net, state, p, t) -> f(state, u(net(state)[(V_dim + 1):end]), p, t), | ||
V_dim + nd | ||
) | ||
end | ||
end | ||
|
||
function get_policy( | ||
phi, | ||
θ, | ||
network_func::Function, | ||
dim_u::Integer; | ||
u_func::Function = identity | ||
) | ||
function policy(state::AbstractVector) | ||
u_func(network_func(phi, θ, state)[(end - dim_u + 1):end]) | ||
end | ||
|
||
policy(state::AbstractMatrix) = mapslices(policy, state, dims = [1]) | ||
|
||
return policy | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.