Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Sep 2, 2023
1 parent b814a93 commit 83acfed
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 80 deletions.
75 changes: 4 additions & 71 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ end

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
# my suggested Loss likelihood part
# +L2loss2(Tar, θ)
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
Expand Down Expand Up @@ -184,79 +182,14 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(nnsol[i, :],
LinearAlgebra.Diagonal(
map(abs2,
Tar.phystd[i] .*
ones(length(physsol[i, :]))
)
)
),
physsol[i, :])
LinearAlgebra.Diagonal(map(abs2,
Tar.phystd[i] .*
ones(length(physsol[i, :]))))),
physsol[i, :])
end
return physlogprob
end

# My suggested extra loss function
function L2loss2(Tar::LogTargetDensity, θ)
f = Tar.prob.f

# parameter estimation chosen or not
if Tar.extraparams > 0
dataset = Tar.dataset

# Timepoints to enforce Physics
dataset = Array(reduce(hcat, dataset)')
t = dataset[end, :]
= dataset[1:(end - 1), :]

ode_params = Tar.extraparams == 1 ?
θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

if length(û[:, 1]) == 1
physsol = [f(û[:, i][1],
ode_params,
t[i])
for i in 1:length(û[1, :])]
else
physsol = [f(û[:, i],
ode_params,
t[i])
for i in 1:length(û[1, :])]
end
#form of NN output matrix output dim x n
deri_physsol = reduce(hcat, physsol)

# OG deriv(basically gradient matching in case of an ODEFunction)
# in case of PDE or general ODE we would want to reduce residue of f(du,u,p,t)
if length(û[:, 1]) == 1
deri_sol = [f(û[:, i][1],
Tar.prob.p,
t[i])
for i in 1:length(û[1, :])]
else
deri_sol = [f(û[:, i],
Tar.prob.p,
t[i])
for i in 1:length(û[1, :])]
end
deri_sol = reduce(hcat, deri_sol)

physlogprob = 0
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(deri_physsol[i, :],
LinearAlgebra.Diagonal(map(abs2,
Tar.l2std[i] .*
ones(length(deri_sol[i, :]))))),
deri_sol[i, :])
end
return physlogprob
else
return 0
end
end

# L2 losses loglikelihood(needed mainly for ODE parameter estimation)
function L2LossData(Tar::LogTargetDensity, θ)
# check if dataset is provided
Expand Down
19 changes: 10 additions & 9 deletions test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chainflux1,
3.0),
param = [
LogNormal(9,
5),
0.5),
],
Metric = DiagEuclideanMetric)
Metric = DiagEuclideanMetric,
n_leapfrog = 30)

fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chainlux1,
dataset = dataset,
Expand Down Expand Up @@ -165,6 +166,7 @@ alg = NeuralPDE.BNNODE(chainlux1, dataset = dataset,
],
Metric = DiagEuclideanMetric,
n_leapfrog = 30)

sol2lux = solve(prob, alg)

# testing points
Expand All @@ -191,11 +193,10 @@ meanscurve2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean
@test abs(p - mean([fhsamples1[i][23] for i in 2000:2500])) < abs(0.2 * p)

#---------------------- solve() call
@test mean(abs.(x̂1 .- sol2flux.ensemblesol[1])) < 5e-1
@test mean(abs.(physsol1_1 .- sol2flux.ensemblesol[1])) < 5e-1
@test mean(abs.(x̂1 .- sol2lux.ensemblesol[1])) < 5e-1
@test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 5e-1

@test mean(abs.(x̂1 .- sol2flux.ensemblesol[1])) < 8e-2
@test mean(abs.(physsol1_1 .- sol2flux.ensemblesol[1])) < 8e-2
@test mean(abs.(x̂1 .- sol2lux.ensemblesol[1])) < 8e-2
@test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2
# ESTIMATED ODE PARAMETERS (NN1 AND NN2)
@test abs(p - sol2flux.estimated_ode_params[1]) < abs(0.1 * p)
@test abs(p - sol2lux.estimated_ode_params[1]) < abs(0.1 * p)
Expand Down Expand Up @@ -352,14 +353,14 @@ param1 = mean(i[62] for i in fhsampleslux22[1000:1500])

#-------------------------- solve() call
# (flux chain)
@test mean(abs.(physsol2 .- sol3flux_pestim.ensemblesol[1])) < 5e-2
@test mean(abs.(physsol2 .- sol3flux_pestim.ensemblesol[1])) < 8e-2

# estimated parameters(flux chain)
param1 = sol3flux_pestim.estimated_ode_params[1]
@test abs(param1 - p) < abs(0.35 * p)

# (lux chain)
@prob mean(abs.(physsol2 .- sol3lux_pestim.ensemblesol[1])) < 5e-2
@prob mean(abs.(physsol2 .- sol3lux_pestim.ensemblesol[1])) < 8e-2

# estimated parameters(lux chain)
param1 = sol3lux_pestim.estimated_ode_params[1]
Expand Down

0 comments on commit 83acfed

Please sign in to comment.