Skip to content

Commit

Permalink
fix: update the Hamiltonian NN property
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 14, 2024
1 parent 2a240c0 commit 405925c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[weakdeps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Expand Down Expand Up @@ -65,6 +66,7 @@ Reexport = "0.2, 1"
SciMLBase = "2"
SciMLSensitivity = "7"
Setfield = "1.1.1"
Static = "1.1.1"
Statistics = "1.10"
StochasticDiffEq = "6.68.0"
Test = "1.10"
Expand Down
5 changes: 3 additions & 2 deletions src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ using SciMLSensitivity: SciMLSensitivity, AdjointLSS, BacksolveAdjoint, EnzymeVJ
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
ZygoteVJP
using Setfield: @set!
using Static: True, False

const CRC = ChainRulesCore

@reexport using ADTypes, Lux, Boltz

fixed_state_type(_) = true
# TODO: Update the signature
fixed_state_type(::Layers.HamiltonianNN{FST}) where {FST} = FST
fixed_state_type(::Layers.HamiltonianNN{True}) = true
fixed_state_type(::Layers.HamiltonianNN{False}) = false

include("ffjord.jl")
include("neural_de.jl")
Expand Down

0 comments on commit 405925c

Please sign in to comment.