Skip to content

Commit

Permalink
Parameter Estimation works, Custom choice for Prior Distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Jul 22, 2023
1 parent f77d04a commit 7cb6141
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 167 deletions.
135 changes: 72 additions & 63 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using AdvancedHMC, ForwardDiff, LogDensityProblems, LinearAlgebra
using AdvancedHMC, ForwardDiff, LogDensityProblems, LinearAlgebra, Distributions

mutable struct LogTargetDensity{C, S}
dim::Int
prob::DiffEqBase.DEProblem
chain::C
st::S
dataset::Vector{Vector{Float64}}
priors::Vector{Tuple{Float64, Float64}}
priors::Vector{Distribution}
phystd::Vector{Float64}
l2std::Vector{Float64}
autodiff::Bool
Expand All @@ -30,7 +30,7 @@ mutable struct LogTargetDensity{C, S}
end

# Unnormalized log-posterior used by AdvancedHMC: physics-informed
# loglikelihood + log-priors over (NN weights, ODE params) + data L2 loglikelihood.
function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
    # NOTE(review): the diff carried two stale copies of this return (the second
    # was unreachable); this keeps the committed ordering of the three terms.
    return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
Expand Down Expand Up @@ -92,68 +92,75 @@ end
# ODE DU/DX
# du/dt of the neural surrogate at each time in `t`, as a matrix whose columns
# are time points: exact via ForwardDiff when `autodiff` is set, otherwise a
# one-sided finite difference with step sqrt(eps).
function NNodederi(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool)
    # exact derivative: one column per time point, collected into a matrix
    autodiff && return hcat(ForwardDiff.derivative.(ti -> phi(ti, θ), t)...)
    # forward-difference fallback with the usual sqrt(eps) step
    h = sqrt(eps(eltype(t)))
    return (phi(t .+ h, θ) - phi(t, θ)) ./ h
end

# physloglike over problem timespan
# physics loglikelihood over problem timespan
# Physics-informed loglikelihood over the problem timespan: evaluates the NN
# solution on a dt-grid, compares its time-derivative (NNodederi) against the
# ODE RHS f(u, p, t), and scores the residual of each state component under
# MvNormal with std Tar.phystd[i].
#
# NOTE(review): this body interleaves two diff revisions — `p` is assigned
# twice, `t` is set from the dataset then overwritten from the timespan grid,
# and `ode_params` is computed then clobbered by a prior-mean shift that
# indexes priors as tuples. Verify against the intended post-commit version.
function physloglikelihood(Tar::LogTargetDensity, θ)
p = Tar.prob.p
f = Tar.prob.f
# NOTE(review): duplicate assignment of `p` — diff residue
p = Tar.prob.p
# NOTE(review): dead store — `t` is rebuilt from the timespan grid below
t = copy(Tar.dataset[end])

allparams = Tar.priors
# priors for the ODE (inverse-problem) parameters; allparams[1] is the NN-weight prior
invparams = allparams[2:length(allparams)]
# NOTE(review): `invparam[1]` assumes tuple-style (mean, std) priors; with
# `priors::Vector{Distribution}` this indexing is stale — confirm
meaninv = [invparam[1] for invparam in invparams]
# parameter estimation chosen or not
if Tar.extraparams > 0
# trailing `extraparams` entries of θ are the ODE parameters (scalar when one)
ode_params = Tar.extraparams == 1 ?
θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
θ[((length(θ) - Tar.extraparams) + 1):length(θ)]
else
# no estimation: fall back to the problem's own parameters (or none)
ode_params = p == SciMLBase.NullParameters() ? [] : p
end

# train for NN derivative upon dataset as well as beyond but within timespan
autodiff = Tar.autodiff
dt = Tar.physdt
t = collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])

# extend the grid so it reaches the right endpoint of the timespan
if t[end] != Tar.prob.tspan[2]
append!(t, collect(Float64, t[end]:dt:Tar.prob.tspan[2]))
end

# NN solution on the grid, using only the NN-weight slice of θ
out = Tar(t, θ[1:(length(θ) - Tar.extraparams)])

# reject samples case: -Inf logdensity makes the sampler reject this proposal
if any(isinf, out[:, 1]) || any(isinf, ode_params)
return -Inf
end

# scalar-state problems (single u passed to f as a Float) vs vector states
if length(out[:, 1]) == 1
# shifted by prior mean
# NOTE(review): overwrites the `ode_params` computed above with an exp/log
# shift around the (tuple-style) prior means — likely stale diff residue
ode_params = exp.(θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + log.(meaninv))
physsol = [f(out[:, i][1],
ode_params,
t[i])
for i in 1:length(out[1, :])]
else
# shifted by prior mean (see NOTE above)
ode_params = exp.(θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + log.(meaninv))
physsol = [f(out[:, i],
ode_params,
t[i])
for i in 1:length(out[1, :])]
end
# columns are time points, rows are state components — matches nnsol below
physsol = hcat(physsol...)

# NN time-derivative on the same grid, as a matrix
nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

physlogprob = 0
# NOTE(review): `n` is unused below
n = length(out[1, :])
for i in 1:length(Tar.prob.u0)
# per-state Gaussian residual with std Tar.phystd[i]
physlogprob += logpdf(MvNormal(nnsol[i, :], Tar.phystd[i]), physsol[i, :])
end
return physlogprob
end

# Standard L2 losses training dataset
# L2 losses loglikelihood
function L2LossData(Tar::LogTargetDensity, θ)
# matrix(each row corresponds to vector u's rows)
nn = Tar(Tar.dataset[length(Tar.dataset)], θ[1:(length(θ) - Tar.extraparams)])
nn = Tar(Tar.dataset[end], θ[1:(length(θ) - Tar.extraparams)])

L2logprob = 0
n = length(nn[1, :])
for i in 1:length(Tar.prob.u0)
# can add l2std[i] for u[i]
# for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
L2logprob += logpdf(MvNormal(nn[i, :], Tar.l2std[i]), Tar.dataset[i])
end
Expand All @@ -163,21 +170,21 @@ end
# priors for NN parameters + ODE constants
"""
    priorweights(Tar::LogTargetDensity, θ)

Log-prior of the sampled vector `θ`. The NN-weight slice
`θ[1:(length(θ) - Tar.extraparams)]` is scored against the multivariate
weight prior `Tar.priors[1]`; each trailing ODE parameter is scored against
its own prior in `Tar.priors[2:end]`. Returns the weight-prior logpdf alone
when no ODE parameters are being estimated.

NOTE(review): SOURCE interleaved removed and added diff lines here (an
unclosed `return (logpdf(MvNormal(...)` from the old tuple-prior version made
the function unparseable); this reconstructs the committed Distribution-based
revision.
"""
function priorweights(Tar::LogTargetDensity, θ)
    allparams = Tar.priors
    # Vector of ode parameters priors (one Distribution per estimated parameter)
    invpriors = allparams[2:end]

    # prior over all NN weights (a single multivariate Distribution)
    nnwparams = allparams[1]

    if Tar.extraparams > 0
        # NOTE(review): `invpriors[length(θ) - i + 1]` pairs the LAST entry of θ
        # with the FIRST ODE prior (reversed order) — kept as committed; confirm
        # this pairing is intended.
        invlogpdf = sum(logpdf(invpriors[length(θ) - i + 1], θ[i])
                        for i in (length(θ) - Tar.extraparams + 1):length(θ);
                        init = 0.0)

        return (invlogpdf
                +
                logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)]))
    else
        # no parameter estimation: every entry of θ is an NN weight
        return logpdf(nnwparams, θ)
    end
end

Expand All @@ -193,7 +200,7 @@ function integratorchoice(Integrator, initial_ϵ; jitter_rate = 3.0,
end

function proposalchoice(Sampler, Integrator; n_steps = 50,
trajectory_length = 30)
trajectory_length = 30.0)
if Sampler == StaticTrajectory
Sampler(Integrator, n_steps)
elseif Sampler == AdvancedHMC.HMCDA
Expand All @@ -214,8 +221,8 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,
param = [],
autodiff = false, physdt = 1 / 20.0f0,
Proposal = StaticTrajectory,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.75,
Integrator = JitteredLeapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Integrator = Leapfrog,
Metric = DiagEuclideanMetric)

# NN parameter prior mean and variance(PriorsNN must be a tuple)
Expand All @@ -234,24 +241,29 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,
end

# adding ode parameter estimation
n = length(param)
priors = [priorsNNw]

# [i[1] for i in param]
if length(param) > 0
append!(initial_θ, zeros(length(param)))
append!(priors, param)
end
nparameters = length(initial_θ)
ninv = length(param)
priors = [MvNormal(priorsNNw[1] * ones(nparameters), priorsNNw[2] * ones(nparameters))]

if ninv > 0
# shift ode params(initialise ode params by prior means)
initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in ninv])
priors = vcat(priors, param)
nparameters += ninv
end

# Testing for Lux chains
ℓπ = LogTargetDensity(nparameters, prob, recon, st, dataset, priors,
phystd, l2std, autodiff, physdt, n)
phystd, l2std, autodiff, physdt, ninv)

# return physloglikelihood(ℓπ, initial_θ)
# return L2LossData(ℓπ, initial_θ)
# return priorweights(ℓπ, initial_θ)

# [add f(t,θ) for t being a number]
t0 = prob.tspan[1]
try
ℓπ(t0, initial_θ[1:(nparameters - n)])
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
Expand All @@ -264,15 +276,6 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,
metric = Metric(nparameters)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)

integrator = integratorchoice(Integrator, initial_ϵ)

proposal = proposalchoice(Proposal, integrator)

adaptor = Adaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

# parallel sampling option
if nchains != 1
# Cache to store the chains
Expand All @@ -281,6 +284,14 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,
samplesc = Vector{Any}(undef, nchains)

Threads.@threads for i in 1:nchains
# each chain has different initial parameter values(better posterior exploration)
initial_θ = randn(nparameters)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integrator, initial_ϵ)
proposal = proposalchoice(Proposal, integrator)
adaptor = Adaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

samples, stats = sample(hamiltonian, proposal, initial_θ, n_samples, adaptor;
progress = true, verbose = false)
samplesc[i] = samples
Expand All @@ -292,6 +303,13 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,

return chains, samplesc, statsc
else
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
println(initial_ϵ)
integrator = integratorchoice(Integrator, initial_ϵ)
proposal = proposalchoice(Proposal, integrator)
adaptor = Adaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

samples, stats = sample(hamiltonian, proposal, initial_θ, draw_samples, adaptor;
progress = true)
# return a chain(basic chain),samples and stats
Expand All @@ -302,16 +320,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,
end

# test for Lux chains
# check if parameter estimation works (no)
# check if parameter estimation works (no)
# fix predictions for ODEs depending upon 1, p in f(u, p, t)
# Lotka-Volterra parameter estimation
# Lotka-Volterra learning curve beyond L2 losses

# non vectorise call functions(noticed sampling time increase)
# function NNodederi(phi::odeByNN, t::Number, θ, autodiff::Bool)
# if autodiff
# ForwardDiff.jacobian(t -> phi(t, θ), t)
# else
# (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
# end
# end
# lotka volterra learn curve beyond l2 losses
Loading

0 comments on commit 7cb6141

Please sign in to comment.