From 405925c606d913150339b4aa3ac27e88d34815e4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Sep 2024 23:35:40 -0400 Subject: [PATCH] fix: update the Hamiltonian NN property --- Project.toml | 2 ++ src/DiffEqFlux.jl | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f8f6a376d..1b6ee765b 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [weakdeps] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" @@ -65,6 +66,7 @@ Reexport = "0.2, 1" SciMLBase = "2" SciMLSensitivity = "7" Setfield = "1.1.1" +Static = "1.1.1" Statistics = "1.10" StochasticDiffEq = "6.68.0" Test = "1.10" diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index 4bb99c9e7..dc1b9d2a2 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -20,14 +20,15 @@ using SciMLSensitivity: SciMLSensitivity, AdjointLSS, BacksolveAdjoint, EnzymeVJ SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint, ZygoteVJP using Setfield: @set! +using Static: True, False const CRC = ChainRulesCore @reexport using ADTypes, Lux, Boltz fixed_state_type(_) = true -# TODO: Update the signature -fixed_state_type(::Layers.HamiltonianNN{FST}) where {FST} = FST +fixed_state_type(::Layers.HamiltonianNN{True}) = true +fixed_state_type(::Layers.HamiltonianNN{False}) = false include("ffjord.jl") include("neural_de.jl")