From 4b6dd3f8945c1d53eb9701129231e2fb3a9dcc9b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 13 Jul 2023 16:50:15 -0400 Subject: [PATCH] WIP: Calibrate and ensemble endpoints This is relying on some PRs like https://github.com/SciML/SciMLBase.jl/pull/467 I need to figure out who to talk to in order to know what the inputs look like. --- src/SimulationService.jl | 53 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/src/SimulationService.jl b/src/SimulationService.jl index 1f86274..7e0a79e 100644 --- a/src/SimulationService.jl +++ b/src/SimulationService.jl @@ -471,16 +471,59 @@ end #-----------------------------------------------------------------------------# calibrate struct Calibrate <: Operation - # TODO + sys::ODESystem + timespan::Tuple{Float64, Float64} + priors::Any # ??? + data::Any # ??? +end + +Calibrate(o::OperationRequest) = Calibrate(o.sys, o.timespan, o.priors, o.data) + +function solve(o::Calibrate; callback) + prob = ODEProblem(o.sys, [], o.timespan, saveat=1) + + # what the data should be like + # o.data + tsave1 = collect(10.0:10.0:100.0) + sol_data1 = solve(prob, saveat = tsave1) + tsave2 = collect(10.0:13.5:100.0) + sol_data2 = solve(prob, saveat = tsave2) + data_with_t = [x => (tsave1, sol_data1[x]), z => (tsave2, sol_data2[z])] + + p_posterior = bayesian_datafit(prob, o.priors, data_with_t) + + df = DataFrame(last.(p_posterior), :auto) + rename!(df, Symbol.(first.(p_posterior))) + + df end -Calibrate(o::OperationRequest) = error("TODO") -solve(o::Calibrate; callback) = error("TODO") #-----------------------------------------------------------------------------# ensemble struct Ensemble <: Operation - # TODO + sys::Vector{ODESystem} + priors::Vector{Pair{Num,Any}} # Any = Distribution + train_datas::Any + ensem_datas::Any + t_forecast::Vector{Float64} + quantiles::Vector{Float64} end + Ensemble(o::OperationRequest) = error("TODO") -solve(o::Ensemble; callback) = error("TODO") + +function solve(o::Ensemble; callback) + probs = [ODEProblem(s, [], o.timespan) for s in sys] + ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3] + datas = [data_train,data_train,data_train] + enprobs = bayesian_ensemble(probs, ps, datas) + ensem_weights = ensemble_weights(sol, data_ensem) + + forecast_probs = [remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)] + fit_enprob = EnsembleProblem(forecast_probs) + sol = solve(fit_enprob; saveat = o.t_forecast); + + # Requires https://github.com/SciML/SciMLBase.jl/pull/467 + # weighted_ensem = WeightedEnsembleSolution(sol, ensem_weights; quantiles = o.quantiles) + # DataFrame(weighted_ensem) +end end # module