diff --git a/src/collocated_estim.jl b/src/collocated_estim.jl index 3902f74a2..0fe608e95 100644 --- a/src/collocated_estim.jl +++ b/src/collocated_estim.jl @@ -1,56 +1,14 @@ -# suggested extra loss function +# suggested extra loss function for ODE solver case function L2loss2(Tar::LogTargetDensity, θ) f = Tar.prob.f # parameter estimation chosen or not if Tar.extraparams > 0 - # deri_sol = deri_sol' autodiff = Tar.autodiff - # # Timepoints to enforce Physics - # dataset = Array(reduce(hcat, dataset)') - # t = dataset[end, :] - # û = dataset[1:(end - 1), :] - - # ode_params = Tar.extraparams == 1 ? - # θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] : - # θ[((length(θ) - Tar.extraparams) + 1):length(θ)] - - # if length(û[:, 1]) == 1 - # physsol = [f(û[:, i][1], - # ode_params, - # t[i]) - # for i in 1:length(û[1, :])] - # else - # physsol = [f(û[:, i], - # ode_params, - # t[i]) - # for i in 1:length(û[1, :])] - # end - # #form of NN output matrix output dim x n - # deri_physsol = reduce(hcat, physsol) - - # > for perfect deriv(basically gradient matching in case of an ODEFunction) - # in case of PDE or general ODE we would want to reduce residue of f(du,u,p,t) - # if length(û[:, 1]) == 1 - # deri_sol = [f(û[:, i][1], - # Tar.prob.p, - # t[i]) - # for i in 1:length(û[1, :])] - # else - # deri_sol = [f(û[:, i], - # Tar.prob.p, - # t[i]) - # for i in 1:length(û[1, :])] - # end - # deri_sol = reduce(hcat, deri_sol) - # deri_sol = reduce(hcat, derivatives) - # Timepoints to enforce Physics t = Tar.dataset[end] u1 = Tar.dataset[2] û = Tar.dataset[1] - # Tar(t, θ[1:(length(θ) - Tar.extraparams)])' - # nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff) @@ -71,24 +29,7 @@ function L2loss2(Tar::LogTargetDensity, θ) end #form of NN output matrix output dim x n deri_physsol = reduce(hcat, physsol) - - # if length(Tar.prob.u0) == 1 - # nnsol = [f(û[i], - # Tar.prob.p, - # t[i]) - # for i in 1:length(û[:, 1])] - # else - # nnsol = [f([û[i], u1[i]], - # Tar.prob.p, - # t[i]) - # for i in 1:length(û[:, 1])] - # end - # form of NN output matrix output dim x n - # nnsol = reduce(hcat, nnsol) - - # > Instead of dataset gradients trying NN derivatives with dataset collocation - # # convert to matrix as nnsol - + physlogprob = 0 for i in 1:length(Tar.prob.u0) # can add phystd[i] for u[i] @@ -102,64 +43,4 @@ function L2loss2(Tar::LogTargetDensity, θ) else return 0 end -end - -# PDE(DU,U,P,T)=0 - -# Derivated via Central Diff -# function calculate_derivatives2(dataset) -# x̂, time = dataset -# num_points = length(x̂) -# # Initialize an array to store the derivative values. -# derivatives = similar(x̂) - -# for i in 2:(num_points - 1) -# # Calculate the first-order derivative using central differences. -# Δt_forward = time[i + 1] - time[i] -# Δt_backward = time[i] - time[i - 1] - -# derivative = (x̂[i + 1] - x̂[i - 1]) / (Δt_forward + Δt_backward) - -# derivatives[i] = derivative -# end - -# # Derivatives at the endpoints can be calculated using forward or backward differences. -# derivatives[1] = (x̂[2] - x̂[1]) / (time[2] - time[1]) -# derivatives[end] = (x̂[end] - x̂[end - 1]) / (time[end] - time[end - 1]) -# return derivatives -# end - -function calderivatives(prob, dataset) - chainflux = Flux.Chain(Flux.Dense(1, 8, tanh), Flux.Dense(8, 8, tanh), - Flux.Dense(8, 2)) |> Flux.f64 - # chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> Flux.f64 - function loss(x, y) - # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1]) + - # Flux.mse.(prob.u0[2] .+ (prob.tspan[2] .- x)' .* chainflux(x)[2, :], y[2])) - # sum(Flux.mse.(prob.u0[1] .+ (prob.tspan[2] .- x)' .* chainflux(x)[1, :], y[1])) - sum(Flux.mse.(chainflux(x), y)) - end - optimizer = Flux.Optimise.ADAM(0.01) - epochs = 3000 - for epoch in 1:epochs - Flux.train!(loss, - Flux.params(chainflux), - [(dataset[end]', dataset[1:(end - 1)])], - optimizer) - end - - # A1 = (prob.u0' .+ - # (prob.tspan[2] .- (dataset[end]' .+ sqrt(eps(eltype(Float64)))))' .* - # chainflux(dataset[end]' .+ sqrt(eps(eltype(Float64))))') - - # A2 = (prob.u0' .+ - # (prob.tspan[2] .- (dataset[end]'))' .* - # chainflux(dataset[end]')') - - A1 = chainflux(dataset[end]' .+ sqrt(eps(eltype(dataset[end][1])))) - A2 = chainflux(dataset[end]') - - gradients = (A2 .- A1) ./ sqrt(eps(eltype(dataset[end][1]))) - - return gradients end \ No newline at end of file