Skip to content

Commit

Permalink
changes from reviews-1
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Aug 7, 2023
1 parent 0bf8152 commit b23db49
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
matrix:
group:
#fixes 682
-ODEBPINN
- ODEBPINN

- NNPDE1
- NNPDE2
Expand Down
18 changes: 9 additions & 9 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ end
function generate_Tar(chain::Flux.Chain, init_params::Nothing)
θ, re = Flux.destructure(chain)
# find_good_stepsize,phasepoint takes only float64
θ = collect(Float64, θ)
return θ, re, nothing
end

Expand Down Expand Up @@ -193,9 +192,10 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
p = Tar.prob.p
dt = Tar.physdt
if isempty(Tar.dataset[end])
t = collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])
t = collect(eltype(dt), Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])
else
t = vcat(collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2]), Tar.dataset[end])
t = vcat(collect(eltype(dt), Tar.prob.tspan[1]:dt:Tar.prob.tspan[2]),
Tar.dataset[end])
end

# parameter estimation chosen or not
Expand Down Expand Up @@ -309,12 +309,11 @@ end

function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain;
dataset = [[]],
# dataset::Vector{Vector{Float64}}
init_params = nothing, nchains = 1,
draw_samples = 1000, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [],
autodiff = false, physdt = 1 / 20.0f0,
autodiff = false, physdt = 1 / 20.0,
Proposal = StaticTrajectory,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Integrator = Leapfrog,
Expand All @@ -338,12 +337,13 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain;
throw(error("number of chains must be greater than 1"))
end

#
# eltype(physdt) cause needs Float64 for find_good_stepsize
if chain isa Lux.AbstractExplicitLayer
# Lux chain(using component array later as vector_to_parameter need namedtuple,AHMC uses Float64)
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
# Lux chain(using component array later as vector_to_parameter need namedtuple)
initial_θ = collect(eltype(physdt),
vcat(ComponentArrays.ComponentArray(initial_nnθ)))
else
initial_θ = initial_nnθ
initial_θ = collect(eltype(physdt), initial_nnθ)
end

# adding ode parameter estimation
Expand Down

0 comments on commit b23db49

Please sign in to comment.