Skip to content

Commit

Permalink
refactor: store static bool in Hamiltonian Layer
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 14, 2024
1 parent 8ce1b10 commit 5abcf36
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

Expand Down Expand Up @@ -65,6 +66,7 @@ NNlib = "0.9.21"
Random = "1.10"
Reexport = "1.2.2"
ReverseDiff = "1.15"
Static = "1.1.1"
Statistics = "1.10"
Tracker = "0.2.34"
WeightInitializers = "1"
Expand Down
1 change: 1 addition & 0 deletions src/layers/Layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives
using Markdown: @doc_str
using Random: AbstractRNG
using Static: Static, True, False

using ForwardDiff: ForwardDiff

Expand Down
9 changes: 5 additions & 4 deletions src/layers/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ returns the time derivatives for position and momentum.
[Nested Autodiff](https://lux.csail.mit.edu/stable/manual/nested_autodiff) for more
information and known limitations.
"""
@concrete struct HamiltonianNN{FST} <: AbstractLuxWrapperLayer{:model}
@concrete struct HamiltonianNN <: AbstractLuxWrapperLayer{:model}
fixed_state_type
model
autodiff
end
Expand All @@ -55,7 +56,7 @@ function HamiltonianNN{FST}(model; autodiff=nothing) where {FST}
end
end

return HamiltonianNN{FST}(model, autodiff)
return HamiltonianNN(Static.static(FST), model, autodiff)
end

function LuxCore.initialstates(rng::AbstractRNG, hnn::HamiltonianNN)
Expand All @@ -69,8 +70,8 @@ function (hnn::HamiltonianNN)(x::AbstractVector, ps, st)
return vec(y), stₙ
end

function (hnn::HamiltonianNN{FST})(x::AbstractArray{T, N}, ps, st) where {FST, T, N}
model = StatefulLuxLayer{FST}(hnn.model, ps, st.model)
function (hnn::HamiltonianNN)(x::AbstractArray{T, N}, ps, st) where {T, N}
model = StatefulLuxLayer{Static.known(hnn.fixed_state_type)}(hnn.model, ps, st.model)

st.first_call && check_hamiltonian_layer(hnn.model, x, ps, st.model)

Expand Down

0 comments on commit 5abcf36

Please sign in to comment.