Skip to content

Commit

Permalink
changes from reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed May 4, 2024
1 parent 3bb93dd commit 2731063
Showing 1 changed file with 2 additions and 121 deletions.
123 changes: 2 additions & 121 deletions src/collocated_estim.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,14 @@
# suggested extra loss function
# suggested extra loss function for ODE solver case
function L2loss2(Tar::LogTargetDensity, θ)
f = Tar.prob.f

# parameter estimation chosen or not
if Tar.extraparams > 0
# deri_sol = deri_sol'
autodiff = Tar.autodiff
# # Timepoints to enforce Physics
# dataset = Array(reduce(hcat, dataset)')
# t = dataset[end, :]
# û = dataset[1:(end - 1), :]

# ode_params = Tar.extraparams == 1 ?
# θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
# θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

# if length(û[:, 1]) == 1
# physsol = [f(û[:, i][1],
# ode_params,
# t[i])
# for i in 1:length(û[1, :])]
# else
# physsol = [f(û[:, i],
# ode_params,
# t[i])
# for i in 1:length(û[1, :])]
# end
# #form of NN output matrix output dim x n
# deri_physsol = reduce(hcat, physsol)

# > for perfect deriv(basically gradient matching in case of an ODEFunction)
# in case of PDE or general ODE we would want to reduce residue of f(du,u,p,t)
# if length(û[:, 1]) == 1
# deri_sol = [f(û[:, i][1],
# Tar.prob.p,
# t[i])
# for i in 1:length(û[1, :])]
# else
# deri_sol = [f(û[:, i],
# Tar.prob.p,
# t[i])
# for i in 1:length(û[1, :])]
# end
# deri_sol = reduce(hcat, deri_sol)
# deri_sol = reduce(hcat, derivatives)

# Timepoints to enforce Physics
t = Tar.dataset[end]
u1 = Tar.dataset[2]
= Tar.dataset[1]
# Tar(t, θ[1:(length(θ) - Tar.extraparams)])'
#

nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

Expand All @@ -71,24 +29,7 @@ function L2loss2(Tar::LogTargetDensity, θ)
end
#form of NN output matrix output dim x n
deri_physsol = reduce(hcat, physsol)

# if length(Tar.prob.u0) == 1
# nnsol = [f(û[i],
# Tar.prob.p,
# t[i])
# for i in 1:length(û[:, 1])]
# else
# nnsol = [f([û[i], u1[i]],
# Tar.prob.p,
# t[i])
# for i in 1:length(û[:, 1])]
# end
# form of NN output matrix output dim x n
# nnsol = reduce(hcat, nnsol)

# > Instead of dataset gradients trying NN derivatives with dataset collocation
# # convert to matrix as nnsol


physlogprob = 0
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
Expand All @@ -102,64 +43,4 @@ function L2loss2(Tar::LogTargetDensity, θ)
else
return 0
end
end

# PDE(DU,U,P,T)=0

# Derivated via Central Diff
# function calculate_derivatives2(dataset)
# x̂, time = dataset
# num_points = length(x̂)
# # Initialize an array to store the derivative values.
# derivatives = similar(x̂)

# for i in 2:(num_points - 1)
# # Calculate the first-order derivative using central differences.
# Δt_forward = time[i + 1] - time[i]
# Δt_backward = time[i] - time[i - 1]

# derivative = (x̂[i + 1] - x̂[i - 1]) / (Δt_forward + Δt_backward)

# derivatives[i] = derivative
# end

# # Derivatives at the endpoints can be calculated using forward or backward differences.
# derivatives[1] = (x̂[2] - x̂[1]) / (time[2] - time[1])
# derivatives[end] = (x̂[end] - x̂[end - 1]) / (time[end] - time[end - 1])
# return derivatives
# end

function calderivatives(prob, dataset)
chainflux = Flux.Chain(Flux.Dense(1, 8, tanh), Flux.Dense(8, 8, tanh),
Flux.Dense(8, 2)) |> Flux.f64
# chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64
function loss(x, y)
# sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]) +
# Flux.mse.(prob.u0[2] .+ (prob.tspan[2] .- x)' .* chainflux(x)[2, :], y[2]))
# sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]))
sum(Flux.mse.(chainflux(x), y))
end
optimizer = Flux.Optimise.ADAM(0.01)
epochs = 3000
for epoch in 1:epochs
Flux.train!(loss,
Flux.params(chainflux),
[(dataset[end]', dataset[1:(end - 1)])],
optimizer)
end

# A1 = (prob.u0' .+
# (prob.tspan[2] .- (dataset[end]' .+ sqrt(eps(eltype(Float64)))))' .*
# chainflux(dataset[end]' .+ sqrt(eps(eltype(Float64))))')

# A2 = (prob.u0' .+
# (prob.tspan[2] .- (dataset[end]'))' .*
# chainflux(dataset[end]')')

A1 = chainflux(dataset[end]' .+ sqrt(eps(eltype(dataset[end][1]))))
A2 = chainflux(dataset[end]')

gradients = (A2 .- A1) ./ sqrt(eps(eltype(dataset[end][1])))

return gradients
end

0 comments on commit 2731063

Please sign in to comment.