diff --git a/Project.toml b/Project.toml index 46bbd11..84814c5 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" @@ -65,6 +66,7 @@ NNlib = "0.9.21" Random = "1.10" Reexport = "1.2.2" ReverseDiff = "1.15" +Static = "1.1.1" Statistics = "1.10" Tracker = "0.2.34" WeightInitializers = "1" diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index fe35f04..aa1d2af 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -7,6 +7,7 @@ using ConcreteStructs: @concrete using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives using Markdown: @doc_str using Random: AbstractRNG +using Static: Static, True, False using ForwardDiff: ForwardDiff diff --git a/src/layers/hamiltonian.jl b/src/layers/hamiltonian.jl index 5268e43..b071f14 100644 --- a/src/layers/hamiltonian.jl +++ b/src/layers/hamiltonian.jl @@ -36,7 +36,8 @@ returns the time derivatives for position and momentum. [Nested Autodiff](https://lux.csail.mit.edu/stable/manual/nested_autodiff) for more information and known limitations. """ -@concrete struct HamiltonianNN{FST} <: AbstractLuxWrapperLayer{:model} +@concrete struct HamiltonianNN <: AbstractLuxWrapperLayer{:model} + fixed_state_type model autodiff end @@ -55,7 +56,7 @@ function HamiltonianNN{FST}(model; autodiff=nothing) where {FST} end end - return HamiltonianNN{FST}(model, autodiff) + return HamiltonianNN(Static.static(FST), model, autodiff) end function LuxCore.initialstates(rng::AbstractRNG, hnn::HamiltonianNN) @@ -69,8 +70,8 @@ function (hnn::HamiltonianNN)(x::AbstractVector, ps, st) return vec(y), stâ‚™ end -function (hnn::HamiltonianNN{FST})(x::AbstractArray{T, N}, ps, st) where {FST, T, N} - model = StatefulLuxLayer{FST}(hnn.model, ps, st.model) +function (hnn::HamiltonianNN)(x::AbstractArray{T, N}, ps, st) where {T, N} + model = StatefulLuxLayer{Static.known(hnn.fixed_state_type)}(hnn.model, ps, st.model) st.first_call && check_hamiltonian_layer(hnn.model, x, ps, st.model)