Skip to content

Commit

Permalink
fix: BPINN ODE testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 13, 2024
1 parent 9bd5e28 commit 648b251
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ NN OUTPUT AT t,θ ~ phi(t,θ).
"""
function (f::LogTargetDensity{C, S})(t::AbstractVector,
θ) where {C <: Lux.AbstractLuxLayer, S}
θ = vector_to_parameters(θ, f.init_params)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
θ = vector_to_parameters(θ, f.init_params)
t_ = convert.(eltypeθ, adapt(typeθ, t'))
y, st = f.chain(t_, θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand All @@ -363,8 +363,8 @@ end

function (f::LogTargetDensity{C, S})(t::Number,
θ) where {C <: Lux.AbstractLuxLayer, S}
θ = vector_to_parameters(θ, f.init_params)
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
θ = vector_to_parameters(θ, f.init_params)
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
y, st = f.chain(t_, θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
Expand Down
2 changes: 1 addition & 1 deletion test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # Testing Code
# Testing Code
using Test, MCMCChains
using ForwardDiff, Distributions, OrdinaryDiffEq
using OptimizationOptimisers, AdvancedHMC, Lux
Expand Down

0 comments on commit 648b251

Please sign in to comment.